mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-17 19:05:13 +02:00
IAM tech spec: Auth and access management current state and proposed
changes. Support for separate workspaces Addition of workspace CLI support for test purposes
This commit is contained in:
parent
48da6c5f8b
commit
db05427d0e
219 changed files with 4875 additions and 2616 deletions
|
|
@ -58,7 +58,7 @@ class TestAgentStructuredQueryIntegration:
|
|||
async def test_agent_structured_query_basic_integration(self, agent_processor, structured_query_tool_config):
|
||||
"""Test basic agent integration with structured query tool"""
|
||||
# Arrange - Load tool configuration
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
# Create agent request
|
||||
request = AgentRequest(
|
||||
|
|
@ -119,6 +119,7 @@ Args: {
|
|||
# Mock flow parameter in agent_processor.on_request
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -146,7 +147,7 @@ Args: {
|
|||
async def test_agent_structured_query_error_handling(self, agent_processor, structured_query_tool_config):
|
||||
"""Test agent handling of structured query errors"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="Find data from a table that doesn't exist using structured query.",
|
||||
|
|
@ -199,6 +200,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -221,7 +223,7 @@ Args: {
|
|||
async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
|
||||
"""Test agent using structured query in multi-step reasoning"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="First find all customers from California, then tell me how many orders they have made.",
|
||||
|
|
@ -279,6 +281,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -313,7 +316,7 @@ Args: {
|
|||
}
|
||||
}
|
||||
|
||||
await agent_processor.on_tools_config(tool_config_with_collection, "v1")
|
||||
await agent_processor.on_tools_config("default", tool_config_with_collection, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="Query the sales data for recent transactions.",
|
||||
|
|
@ -371,6 +374,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -394,10 +398,10 @@ Args: {
|
|||
async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
|
||||
"""Test that structured query tool arguments are properly validated"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
# Check that the tool was registered with correct arguments
|
||||
tools = agent_processor.agent.tools
|
||||
tools = agent_processor.agents["default"].tools
|
||||
assert "structured-query" in tools
|
||||
|
||||
structured_tool = tools["structured-query"]
|
||||
|
|
@ -414,7 +418,7 @@ Args: {
|
|||
async def test_agent_structured_query_json_formatting(self, agent_processor, structured_query_tool_config):
|
||||
"""Test that structured query results are properly formatted for agent consumption"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="Get customer information and format it nicely.",
|
||||
|
|
@ -482,6 +486,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ class TestNLPQueryServiceIntegration:
|
|||
)
|
||||
|
||||
# Set up schemas
|
||||
proc.schemas = sample_schemas
|
||||
proc.schemas = {"default": dict(sample_schemas)}
|
||||
|
||||
# Mock the client method
|
||||
proc.client = MagicMock()
|
||||
|
|
@ -94,6 +94,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -173,6 +174,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -229,7 +231,7 @@ class TestNLPQueryServiceIntegration:
|
|||
}
|
||||
|
||||
# Act - Update configuration
|
||||
await integration_processor.on_schema_config(new_schema_config, "v2")
|
||||
await integration_processor.on_schema_config("default", new_schema_config, "v2")
|
||||
|
||||
# Arrange - Test query using new schema
|
||||
request = QuestionToStructuredQueryRequest(
|
||||
|
|
@ -243,6 +245,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -272,7 +275,7 @@ class TestNLPQueryServiceIntegration:
|
|||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
||||
# Assert
|
||||
assert "inventory" in integration_processor.schemas
|
||||
assert "inventory" in integration_processor.schemas["default"]
|
||||
response_call = flow_response.send.call_args
|
||||
response = response_call[0][0]
|
||||
assert response.detected_schemas == ["inventory"]
|
||||
|
|
@ -293,6 +296,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -334,7 +338,7 @@ class TestNLPQueryServiceIntegration:
|
|||
graphql_generation_template="custom-graphql-generator"
|
||||
)
|
||||
|
||||
custom_processor.schemas = sample_schemas
|
||||
custom_processor.schemas = {"default": dict(sample_schemas)}
|
||||
custom_processor.client = MagicMock()
|
||||
|
||||
request = QuestionToStructuredQueryRequest(
|
||||
|
|
@ -348,6 +352,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -394,7 +399,7 @@ class TestNLPQueryServiceIntegration:
|
|||
] + [SchemaField(name=f"field_{j}", type="string") for j in range(5)]
|
||||
)
|
||||
|
||||
integration_processor.schemas.update(large_schema_set)
|
||||
integration_processor.schemas["default"].update(large_schema_set)
|
||||
|
||||
request = QuestionToStructuredQueryRequest(
|
||||
question="Show me data from table_05 and table_12",
|
||||
|
|
@ -407,6 +412,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -462,6 +468,7 @@ class TestNLPQueryServiceIntegration:
|
|||
msg.properties.return_value = {"id": f"concurrent-test-{i}"}
|
||||
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -532,6 +539,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
|
|||
|
|
@ -185,6 +185,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
return context
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -197,20 +198,21 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Act
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Assert
|
||||
assert len(processor.schemas) == 2
|
||||
assert "customer_records" in processor.schemas
|
||||
assert "product_catalog" in processor.schemas
|
||||
|
||||
ws_schemas = processor.schemas["default"]
|
||||
assert len(ws_schemas) == 2
|
||||
assert "customer_records" in ws_schemas
|
||||
assert "product_catalog" in ws_schemas
|
||||
|
||||
# Verify customer schema
|
||||
customer_schema = processor.schemas["customer_records"]
|
||||
customer_schema = ws_schemas["customer_records"]
|
||||
assert customer_schema.name == "customer_records"
|
||||
assert len(customer_schema.fields) == 4
|
||||
|
||||
|
||||
# Verify product schema
|
||||
product_schema = processor.schemas["product_catalog"]
|
||||
product_schema = ws_schemas["product_catalog"]
|
||||
assert product_schema.name == "product_catalog"
|
||||
assert len(product_schema.fields) == 4
|
||||
|
||||
|
|
@ -237,7 +239,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create realistic customer data chunk
|
||||
metadata = Metadata(
|
||||
|
|
@ -304,7 +306,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create realistic product data chunk
|
||||
metadata = Metadata(
|
||||
|
|
@ -368,7 +370,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create multiple test chunks
|
||||
chunks_data = [
|
||||
|
|
@ -431,19 +433,21 @@ class TestObjectExtractionServiceIntegration:
|
|||
"customer_records": integration_config["schema"]["customer_records"]
|
||||
}
|
||||
}
|
||||
await processor.on_schema_config(initial_config, version=1)
|
||||
|
||||
assert len(processor.schemas) == 1
|
||||
assert "customer_records" in processor.schemas
|
||||
assert "product_catalog" not in processor.schemas
|
||||
|
||||
await processor.on_schema_config("default", initial_config, version=1)
|
||||
|
||||
ws_schemas = processor.schemas["default"]
|
||||
assert len(ws_schemas) == 1
|
||||
assert "customer_records" in ws_schemas
|
||||
assert "product_catalog" not in ws_schemas
|
||||
|
||||
# Act - Reload with full configuration
|
||||
await processor.on_schema_config(integration_config, version=2)
|
||||
|
||||
await processor.on_schema_config("default", integration_config, version=2)
|
||||
|
||||
# Assert
|
||||
assert len(processor.schemas) == 2
|
||||
assert "customer_records" in processor.schemas
|
||||
assert "product_catalog" in processor.schemas
|
||||
ws_schemas = processor.schemas["default"]
|
||||
assert len(ws_schemas) == 2
|
||||
assert "customer_records" in ws_schemas
|
||||
assert "product_catalog" in ws_schemas
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_resilience_integration(self, integration_config):
|
||||
|
|
@ -474,10 +478,11 @@ class TestObjectExtractionServiceIntegration:
|
|||
return AsyncMock()
|
||||
|
||||
failing_flow.side_effect = failing_context_router
|
||||
failing_flow.workspace = "default"
|
||||
processor.flow = failing_flow
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create test chunk
|
||||
metadata = Metadata(id="error-test", user="test", collection="test")
|
||||
|
|
@ -510,7 +515,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create chunk with rich metadata
|
||||
original_metadata = Metadata(
|
||||
|
|
|
|||
|
|
@ -87,6 +87,7 @@ class TestPromptStreaming:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
return context
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -109,7 +110,7 @@ class TestPromptStreaming:
|
|||
def prompt_processor_streaming(self, mock_prompt_manager):
|
||||
"""Create Prompt processor with streaming support"""
|
||||
processor = MagicMock()
|
||||
processor.manager = mock_prompt_manager
|
||||
processor.managers = {"default": mock_prompt_manager}
|
||||
processor.config_key = "prompt"
|
||||
|
||||
# Bind the actual on_request method
|
||||
|
|
@ -248,6 +249,7 @@ class TestPromptStreaming:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
|
||||
request = PromptRequest(
|
||||
id="test_prompt",
|
||||
|
|
@ -341,6 +343,7 @@ class TestPromptStreaming:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
|
||||
request = PromptRequest(
|
||||
id="test_prompt",
|
||||
|
|
|
|||
|
|
@ -14,6 +14,17 @@ from trustgraph.storage.rows.cassandra.write import Processor
|
|||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
|
||||
|
||||
class _MockFlowDefault:
|
||||
"""Mock Flow with default workspace for testing."""
|
||||
workspace = "default"
|
||||
name = "default"
|
||||
id = "test-processor"
|
||||
|
||||
|
||||
mock_flow_default = _MockFlowDefault()
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestRowsCassandraIntegration:
|
||||
"""Integration tests for Cassandra row storage with unified table"""
|
||||
|
|
@ -125,8 +136,8 @@ class TestRowsCassandraIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
assert "customer_records" in processor.schemas
|
||||
await processor.on_schema_config("default", config, version=1)
|
||||
assert "customer_records" in processor.schemas["default"]
|
||||
|
||||
# Step 2: Process an ExtractedObject
|
||||
test_obj = ExtractedObject(
|
||||
|
|
@ -149,7 +160,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify Cassandra interactions
|
||||
assert mock_cluster.connect.called
|
||||
|
|
@ -209,8 +220,8 @@ class TestRowsCassandraIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
assert len(processor.schemas) == 2
|
||||
await processor.on_schema_config("default", config, version=1)
|
||||
assert len(processor.schemas["default"]) == 2
|
||||
|
||||
# Process objects for different schemas
|
||||
product_obj = ExtractedObject(
|
||||
|
|
@ -233,7 +244,7 @@ class TestRowsCassandraIntegration:
|
|||
for obj in [product_obj, order_obj]:
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = obj
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# All data goes into the same unified rows table
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -256,15 +267,17 @@ class TestRowsCassandraIntegration:
|
|||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
# Schema with multiple indexed fields
|
||||
processor.schemas["indexed_data"] = RowSchema(
|
||||
name="indexed_data",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True),
|
||||
Field(name="status", type="string", size=50, indexed=True),
|
||||
Field(name="description", type="string", size=200) # Not indexed
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"indexed_data": RowSchema(
|
||||
name="indexed_data",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True),
|
||||
Field(name="status", type="string", size=50, indexed=True),
|
||||
Field(name="description", type="string", size=200) # Not indexed
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test"),
|
||||
|
|
@ -282,7 +295,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Should have 3 data inserts (one per indexed field: id, category, status)
|
||||
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -342,7 +355,7 @@ class TestRowsCassandraIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
await processor.on_schema_config("default", config, version=1)
|
||||
|
||||
# Process batch object with multiple values
|
||||
batch_obj = ExtractedObject(
|
||||
|
|
@ -376,7 +389,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify unified table creation
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -396,10 +409,12 @@ class TestRowsCassandraIntegration:
|
|||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["empty_test"] = RowSchema(
|
||||
name="empty_test",
|
||||
fields=[Field(name="id", type="string", size=50, primary=True)]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"empty_test": RowSchema(
|
||||
name="empty_test",
|
||||
fields=[Field(name="id", type="string", size=50, primary=True)]
|
||||
)
|
||||
}
|
||||
|
||||
# Process empty batch object
|
||||
empty_obj = ExtractedObject(
|
||||
|
|
@ -413,7 +428,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = empty_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Should not create any data insert statements for empty batch
|
||||
# (partition registration may still happen)
|
||||
|
|
@ -428,14 +443,16 @@ class TestRowsCassandraIntegration:
|
|||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["map_test"] = RowSchema(
|
||||
name="map_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="name", type="string", size=100),
|
||||
Field(name="count", type="integer", size=0)
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"map_test": RowSchema(
|
||||
name="map_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="name", type="string", size=100),
|
||||
Field(name="count", type="integer", size=0)
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test"),
|
||||
|
|
@ -448,7 +465,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify insert uses map for data
|
||||
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -473,13 +490,15 @@ class TestRowsCassandraIntegration:
|
|||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["partition_test"] = RowSchema(
|
||||
name="partition_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True)
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"partition_test": RowSchema(
|
||||
name="partition_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True)
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="my_collection"),
|
||||
|
|
@ -492,7 +511,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify partition registration
|
||||
partition_inserts = [call for call in mock_session.execute.call_args_list
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_schema_configuration_and_generation(self, processor, sample_schema_config):
|
||||
"""Test schema configuration loading and GraphQL schema generation"""
|
||||
# Load schema configuration
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
|
||||
# Verify schemas were loaded
|
||||
assert len(processor.schemas) == 2
|
||||
|
|
@ -181,7 +181,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_cassandra_connection_and_table_creation(self, processor, sample_schema_config):
|
||||
"""Test Cassandra connection and dynamic table creation"""
|
||||
# Load schema configuration
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
|
||||
# Connect to Cassandra
|
||||
processor.connect_cassandra()
|
||||
|
|
@ -218,7 +218,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_data_insertion_and_graphql_query(self, processor, sample_schema_config):
|
||||
"""Test inserting data and querying via GraphQL"""
|
||||
# Load schema and connect
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Setup test data
|
||||
|
|
@ -292,7 +292,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_graphql_query_with_filters(self, processor, sample_schema_config):
|
||||
"""Test GraphQL queries with filtering on indexed fields"""
|
||||
# Setup (reuse previous setup)
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
keyspace = "test_user"
|
||||
|
|
@ -353,7 +353,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_graphql_error_handling(self, processor, sample_schema_config):
|
||||
"""Test GraphQL error handling for invalid queries"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
|
||||
# Test invalid field query
|
||||
invalid_query = '''
|
||||
|
|
@ -386,7 +386,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_message_processing_integration(self, processor, sample_schema_config):
|
||||
"""Test full message processing workflow"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Create mock message
|
||||
|
|
@ -432,7 +432,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_concurrent_queries(self, processor, sample_schema_config):
|
||||
"""Test handling multiple concurrent GraphQL queries"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Create multiple query tasks
|
||||
|
|
@ -476,7 +476,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(initial_config, version=1)
|
||||
await processor.on_schema_config("default", initial_config, version=1)
|
||||
assert len(processor.schemas) == 1
|
||||
assert "simple" in processor.schemas
|
||||
|
||||
|
|
@ -500,7 +500,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(updated_config, version=2)
|
||||
await processor.on_schema_config("default", updated_config, version=2)
|
||||
|
||||
# Verify updated schemas
|
||||
assert len(processor.schemas) == 2
|
||||
|
|
@ -518,7 +518,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_large_result_set_handling(self, processor, sample_schema_config):
|
||||
"""Test handling of large query result sets"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
keyspace = "large_test_user"
|
||||
|
|
@ -601,7 +601,7 @@ class TestObjectsGraphQLQueryPerformance:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(schema_config, version=1)
|
||||
await processor.on_schema_config("default", schema_config, version=1)
|
||||
|
||||
# Measure query execution time
|
||||
start_time = time.time()
|
||||
|
|
|
|||
|
|
@ -37,6 +37,9 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup mock agent manager
|
||||
mock_agent_instance = AsyncMock()
|
||||
mock_agent_manager_class.return_value = mock_agent_instance
|
||||
mock_agent_instance.tools = {}
|
||||
mock_agent_instance.additional_context = ""
|
||||
processor.agents["default"] = mock_agent_instance
|
||||
|
||||
# Mock react to call think and observe callbacks
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
|
||||
|
|
@ -58,6 +61,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup flow mock
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
|
||||
|
|
@ -129,6 +133,9 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup mock agent manager
|
||||
mock_agent_instance = AsyncMock()
|
||||
mock_agent_manager_class.return_value = mock_agent_instance
|
||||
mock_agent_instance.tools = {}
|
||||
mock_agent_instance.additional_context = ""
|
||||
processor.agents["default"] = mock_agent_instance
|
||||
|
||||
# Mock react to return Final directly
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
|
||||
|
|
@ -148,6 +155,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup flow mock
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
|
||||
|
|
|
|||
|
|
@ -241,7 +241,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def mock_invoke(name, params):
|
||||
async def mock_invoke(workspace, name, params):
|
||||
return "tool result"
|
||||
|
||||
svc.invoke_tool = mock_invoke
|
||||
|
|
@ -260,6 +260,7 @@ class TestToolServiceOnRequest:
|
|||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
flow_callable.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}')
|
||||
|
|
@ -280,7 +281,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def mock_invoke(name, params):
|
||||
async def mock_invoke(workspace, name, params):
|
||||
return {"data": [1, 2, 3]}
|
||||
|
||||
svc.invoke_tool = mock_invoke
|
||||
|
|
@ -298,6 +299,7 @@ class TestToolServiceOnRequest:
|
|||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
flow_callable.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
|
||||
|
|
@ -317,7 +319,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def failing_invoke(name, params):
|
||||
async def failing_invoke(workspace, name, params):
|
||||
raise RuntimeError("tool broke")
|
||||
|
||||
svc.invoke_tool = failing_invoke
|
||||
|
|
@ -330,6 +332,7 @@ class TestToolServiceOnRequest:
|
|||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
flow_callable.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
|
||||
|
|
@ -350,7 +353,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def rate_limited(name, params):
|
||||
async def rate_limited(workspace, name, params):
|
||||
raise TooManyRequests("slow down")
|
||||
|
||||
svc.invoke_tool = rate_limited
|
||||
|
|
@ -362,6 +365,7 @@ class TestToolServiceOnRequest:
|
|||
flow = MagicMock()
|
||||
flow.producer = {"response": AsyncMock()}
|
||||
flow.name = "test-flow"
|
||||
flow.workspace = "default"
|
||||
|
||||
with pytest.raises(TooManyRequests):
|
||||
await svc.on_request(msg, MagicMock(), flow)
|
||||
|
|
@ -376,7 +380,8 @@ class TestToolServiceOnRequest:
|
|||
|
||||
received = {}
|
||||
|
||||
async def capture_invoke(name, params):
|
||||
async def capture_invoke(workspace, name, params):
|
||||
received["workspace"] = workspace
|
||||
received["name"] = name
|
||||
received["params"] = params
|
||||
return "ok"
|
||||
|
|
@ -390,6 +395,7 @@ class TestToolServiceOnRequest:
|
|||
flow = lambda name: mock_pub
|
||||
flow.producer = {"response": mock_pub}
|
||||
flow.name = "f"
|
||||
flow.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(
|
||||
|
|
|
|||
|
|
@ -1,17 +1,14 @@
|
|||
"""
|
||||
Tests for AsyncProcessor config notify pattern:
|
||||
- register_config_handler with types filtering
|
||||
- on_config_notify version comparison and type matching
|
||||
- fetch_config with short-lived client
|
||||
- fetch_and_apply_config retry logic
|
||||
- on_config_notify version comparison, type/workspace matching
|
||||
- fetch_and_apply_config retry logic over per-workspace fetches
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, Mock
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
|
||||
|
||||
# Patch heavy dependencies before importing AsyncProcessor
|
||||
@pytest.fixture
|
||||
def processor():
|
||||
"""Create an AsyncProcessor with mocked dependencies."""
|
||||
|
|
@ -68,6 +65,13 @@ class TestRegisterConfigHandler:
|
|||
assert len(processor.config_handlers) == 2
|
||||
|
||||
|
||||
def _notify_msg(version, changes):
|
||||
"""Build a Mock config-notify message with given version and changes dict."""
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=version, changes=changes)
|
||||
return msg
|
||||
|
||||
|
||||
class TestOnConfigNotify:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -77,9 +81,7 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=3, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(3, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
|
@ -91,9 +93,7 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=5, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(5, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
|
@ -105,9 +105,7 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["schema"])
|
||||
|
||||
msg = _notify_msg(2, {"schema": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
|
@ -121,40 +119,36 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
# Mock fetch_config
|
||||
mock_config = {"prompt": {"key": "value"}}
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
return_value={"key": "value"},
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_called_once_with(mock_config, 2)
|
||||
handler.assert_called_once_with(
|
||||
"default", {"prompt": {"key": "value"}}, 2
|
||||
)
|
||||
assert processor.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_without_types_always_called(self, processor):
|
||||
async def test_handler_without_types_ignored_on_notify(self, processor):
|
||||
"""Handlers registered without types never fire on notifications."""
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler) # No types = all
|
||||
processor.register_config_handler(handler) # No types
|
||||
|
||||
mock_config = {"anything": {}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["whatever"])
|
||||
msg = _notify_msg(2, {"whatever": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_called_once_with(mock_config, 2)
|
||||
handler.assert_not_called()
|
||||
# Version still advances past the notify
|
||||
assert processor.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_handlers_type_filtering(self, processor):
|
||||
|
|
@ -168,156 +162,149 @@ class TestOnConfigNotify:
|
|||
processor.register_config_handler(schema_handler, types=["schema"])
|
||||
processor.register_config_handler(all_handler)
|
||||
|
||||
mock_config = {"prompt": {}}
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
return_value={},
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
prompt_handler.assert_called_once()
|
||||
prompt_handler.assert_called_once_with(
|
||||
"default", {"prompt": {}}, 2
|
||||
)
|
||||
schema_handler.assert_not_called()
|
||||
all_handler.assert_called_once()
|
||||
all_handler.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_types_invokes_all(self, processor):
|
||||
"""Empty types list (startup signal) should invoke all handlers."""
|
||||
async def test_multi_workspace_notify_invokes_handler_per_ws(
|
||||
self, processor
|
||||
):
|
||||
"""Notify affecting multiple workspaces invokes handler once per workspace."""
|
||||
processor.config_version = 1
|
||||
|
||||
h1 = AsyncMock()
|
||||
h2 = AsyncMock()
|
||||
processor.register_config_handler(h1, types=["prompt"])
|
||||
processor.register_config_handler(h2, types=["schema"])
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
mock_config = {}
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
return_value={},
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=[])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["ws1", "ws2"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
h1.assert_called_once()
|
||||
h2.assert_called_once()
|
||||
assert handler.call_count == 2
|
||||
called_workspaces = {c.args[0] for c in handler.call_args_list}
|
||||
assert called_workspaces == {"ws1", "ws2"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_failure_handled(self, processor):
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler)
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("Connection failed")
|
||||
side_effect=RuntimeError("Connection failed"),
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["default"]})
|
||||
# Should not raise
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
||||
|
||||
class TestFetchConfig:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_returns_config_and_version(self, processor):
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.config = {"prompt": {"key": "val"}}
|
||||
mock_resp.version = 42
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
config, version = await processor.fetch_config()
|
||||
|
||||
assert config == {"prompt": {"key": "val"}}
|
||||
assert version == 42
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_raises_on_error_response(self, processor):
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = Mock(message="not found")
|
||||
mock_resp.config = {}
|
||||
mock_resp.version = 0
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="Config error"):
|
||||
await processor.fetch_config()
|
||||
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_stops_client_on_exception(self, processor):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.side_effect = TimeoutError("timeout")
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
with pytest.raises(TimeoutError):
|
||||
await processor.fetch_config()
|
||||
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestFetchAndApplyConfig:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_applies_config_to_all_handlers(self, processor):
|
||||
h1 = AsyncMock()
|
||||
h2 = AsyncMock()
|
||||
processor.register_config_handler(h1, types=["prompt"])
|
||||
processor.register_config_handler(h2, types=["schema"])
|
||||
async def test_applies_config_per_workspace(self, processor):
|
||||
"""Startup fetch invokes handler once per workspace affected."""
|
||||
h = AsyncMock()
|
||||
processor.register_config_handler(h, types=["prompt"])
|
||||
|
||||
mock_client = AsyncMock()
|
||||
|
||||
async def fake_fetch_all(client, config_type):
|
||||
return {
|
||||
"ws1": {"k": "v1"},
|
||||
"ws2": {"k": "v2"},
|
||||
}, 10
|
||||
|
||||
mock_config = {"prompt": {}, "schema": {}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 10)
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_all_workspaces',
|
||||
new=fake_fetch_all,
|
||||
):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
# On startup, all handlers are invoked regardless of type
|
||||
h1.assert_called_once_with(mock_config, 10)
|
||||
h2.assert_called_once_with(mock_config, 10)
|
||||
assert h.call_count == 2
|
||||
call_map = {c.args[0]: c.args[1] for c in h.call_args_list}
|
||||
assert call_map["ws1"] == {"prompt": {"k": "v1"}}
|
||||
assert call_map["ws2"] == {"prompt": {"k": "v2"}}
|
||||
assert processor.config_version == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_failure(self, processor):
|
||||
call_count = 0
|
||||
mock_config = {"prompt": {}}
|
||||
async def test_handler_without_types_skipped_at_startup(self, processor):
|
||||
"""Handlers registered without types fetch nothing at startup."""
|
||||
typed = AsyncMock()
|
||||
untyped = AsyncMock()
|
||||
processor.register_config_handler(typed, types=["prompt"])
|
||||
processor.register_config_handler(untyped)
|
||||
|
||||
async def mock_fetch():
|
||||
mock_client = AsyncMock()
|
||||
|
||||
async def fake_fetch_all(client, config_type):
|
||||
return {"default": {}}, 1
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_all_workspaces',
|
||||
new=fake_fetch_all,
|
||||
):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
typed.assert_called_once()
|
||||
untyped.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_failure(self, processor):
|
||||
h = AsyncMock()
|
||||
processor.register_config_handler(h, types=["prompt"])
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def fake_fetch_all(client, config_type):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
raise RuntimeError("not ready")
|
||||
return mock_config, 5
|
||||
return {"default": {"k": "v"}}, 5
|
||||
|
||||
with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \
|
||||
patch('asyncio.sleep', new_callable=AsyncMock):
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_all_workspaces',
|
||||
new=fake_fetch_all,
|
||||
), patch('asyncio.sleep', new_callable=AsyncMock):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
assert call_count == 3
|
||||
assert processor.config_version == 5
|
||||
h.assert_called_once_with(
|
||||
"default", {"prompt": {"k": "v"}}, 5
|
||||
)
|
||||
|
|
|
|||
|
|
@ -40,10 +40,11 @@ def test_flow_initialization_calls_registered_specs():
|
|||
spec_two = MagicMock()
|
||||
processor = MagicMock(specifications=[spec_one, spec_two])
|
||||
|
||||
flow = Flow("processor-1", "flow-a", processor, {"answer": 42})
|
||||
flow = Flow("processor-1", "flow-a", "default", processor, {"answer": 42})
|
||||
|
||||
assert flow.id == "processor-1"
|
||||
assert flow.name == "flow-a"
|
||||
assert flow.workspace == "default"
|
||||
assert flow.producer == {}
|
||||
assert flow.consumer == {}
|
||||
assert flow.parameter == {}
|
||||
|
|
@ -54,7 +55,7 @@ def test_flow_initialization_calls_registered_specs():
|
|||
def test_flow_start_and_stop_visit_all_consumers():
|
||||
consumer_one = AsyncMock()
|
||||
consumer_two = AsyncMock()
|
||||
flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {})
|
||||
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
|
||||
flow.consumer = {"one": consumer_one, "two": consumer_two}
|
||||
|
||||
asyncio.run(flow.start())
|
||||
|
|
@ -67,7 +68,7 @@ def test_flow_start_and_stop_visit_all_consumers():
|
|||
|
||||
|
||||
def test_flow_call_returns_values_in_priority_order():
|
||||
flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {})
|
||||
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
|
||||
flow.producer["shared"] = "producer-value"
|
||||
flow.consumer["consumer-only"] = "consumer-value"
|
||||
flow.consumer["shared"] = "consumer-value"
|
||||
|
|
|
|||
|
|
@ -172,10 +172,10 @@ class TestFlowParameterSpecs(IsolatedAsyncioTestCase):
|
|||
flow_defn = {'config': 'test-config'}
|
||||
|
||||
# Act
|
||||
await processor.start_flow(flow_name, flow_defn)
|
||||
await processor.start_flow("default", flow_name, flow_defn)
|
||||
|
||||
# Assert - Flow should be created with access to processor specifications
|
||||
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
|
||||
mock_flow_class.assert_called_once_with('test-processor', flow_name, "default", processor, flow_defn)
|
||||
|
||||
# The flow should have access to the processor's specifications
|
||||
# (The exact mechanism depends on Flow implementation)
|
||||
|
|
|
|||
|
|
@ -78,11 +78,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
flow_name = 'test-flow'
|
||||
flow_defn = {'config': 'test-config'}
|
||||
|
||||
await processor.start_flow(flow_name, flow_defn)
|
||||
await processor.start_flow("default", flow_name, flow_defn)
|
||||
|
||||
assert flow_name in processor.flows
|
||||
assert ("default", flow_name) in processor.flows
|
||||
mock_flow_class.assert_called_once_with(
|
||||
'test-processor', flow_name, processor, flow_defn
|
||||
'test-processor', flow_name, "default", processor, flow_defn
|
||||
)
|
||||
mock_flow.start.assert_called_once()
|
||||
|
||||
|
|
@ -103,11 +103,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
mock_flow_class.return_value = mock_flow
|
||||
|
||||
flow_name = 'test-flow'
|
||||
await processor.start_flow(flow_name, {'config': 'test-config'})
|
||||
await processor.start_flow("default", flow_name, {'config': 'test-config'})
|
||||
|
||||
await processor.stop_flow(flow_name)
|
||||
await processor.stop_flow("default", flow_name)
|
||||
|
||||
assert flow_name not in processor.flows
|
||||
assert ("default", flow_name) not in processor.flows
|
||||
mock_flow.stop.assert_called_once()
|
||||
|
||||
@with_async_processor_patches
|
||||
|
|
@ -120,7 +120,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
await processor.stop_flow('non-existent-flow')
|
||||
await processor.stop_flow("default", 'non-existent-flow')
|
||||
|
||||
assert processor.flows == {}
|
||||
|
||||
|
|
@ -146,11 +146,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
await processor.on_configure_flows("default", config_data, version=1)
|
||||
|
||||
assert 'test-flow' in processor.flows
|
||||
assert ("default", 'test-flow') in processor.flows
|
||||
mock_flow_class.assert_called_once_with(
|
||||
'test-processor', 'test-flow', processor,
|
||||
'test-processor', 'test-flow', "default", processor,
|
||||
{'config': 'test-config'}
|
||||
)
|
||||
mock_flow.start.assert_called_once()
|
||||
|
|
@ -171,7 +171,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
await processor.on_configure_flows("default", config_data, version=1)
|
||||
|
||||
assert processor.flows == {}
|
||||
|
||||
|
|
@ -189,7 +189,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
'other-data': 'some-value'
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
await processor.on_configure_flows("default", config_data, version=1)
|
||||
|
||||
assert processor.flows == {}
|
||||
|
||||
|
|
@ -216,7 +216,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data1, version=1)
|
||||
await processor.on_configure_flows("default", config_data1, version=1)
|
||||
|
||||
config_data2 = {
|
||||
'processor:test-processor': {
|
||||
|
|
@ -224,12 +224,12 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data2, version=2)
|
||||
await processor.on_configure_flows("default", config_data2, version=2)
|
||||
|
||||
assert 'flow1' not in processor.flows
|
||||
assert ("default", 'flow1') not in processor.flows
|
||||
mock_flow1.stop.assert_called_once()
|
||||
|
||||
assert 'flow2' in processor.flows
|
||||
assert ("default", 'flow2') in processor.flows
|
||||
mock_flow2.start.assert_called_once()
|
||||
|
||||
@with_async_processor_patches
|
||||
|
|
|
|||
|
|
@ -109,7 +109,8 @@ class TestListConfigItems:
|
|||
url='http://custom.com',
|
||||
config_type='prompt',
|
||||
format_type='json',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_list_main_uses_defaults(self):
|
||||
|
|
@ -128,7 +129,8 @@ class TestListConfigItems:
|
|||
url='http://localhost:8088/',
|
||||
config_type='prompt',
|
||||
format_type='text',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -196,7 +198,8 @@ class TestGetConfigItem:
|
|||
config_type='prompt',
|
||||
key='template-1',
|
||||
format_type='json',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -253,7 +256,8 @@ class TestPutConfigItem:
|
|||
config_type='prompt',
|
||||
key='new-template',
|
||||
value='Custom prompt: {input}',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_put_main_with_stdin_arg(self):
|
||||
|
|
@ -278,7 +282,8 @@ class TestPutConfigItem:
|
|||
config_type='prompt',
|
||||
key='stdin-template',
|
||||
value=stdin_content,
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_put_main_mutually_exclusive_args(self):
|
||||
|
|
@ -334,7 +339,8 @@ class TestDeleteConfigItem:
|
|||
url='http://custom.com',
|
||||
config_type='prompt',
|
||||
key='old-template',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -163,7 +163,8 @@ ex:mary ex:knows ex:bob .
|
|||
# Verify Api was created with correct parameters
|
||||
mock_api_class.assert_called_once_with(
|
||||
url="http://test.example.com/",
|
||||
token="test-token"
|
||||
token="test-token",
|
||||
workspace="default"
|
||||
)
|
||||
|
||||
# Verify bulk client was obtained
|
||||
|
|
|
|||
|
|
@ -145,7 +145,8 @@ class TestSetToolStructuredQuery:
|
|||
group=None,
|
||||
state=None,
|
||||
applicable_states=None,
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_set_main_structured_query_no_arguments_needed(self):
|
||||
|
|
@ -326,7 +327,8 @@ class TestSetToolRowEmbeddingsQuery:
|
|||
group=None,
|
||||
state=None,
|
||||
applicable_states=None,
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_valid_types_includes_row_embeddings_query(self):
|
||||
|
|
@ -471,7 +473,7 @@ class TestShowToolsStructuredQuery:
|
|||
|
||||
show_main()
|
||||
|
||||
mock_show.assert_called_once_with(url='http://custom.com', token=None)
|
||||
mock_show.assert_called_once_with(url='http://custom.com', token=None, workspace='default')
|
||||
|
||||
|
||||
class TestShowToolsRowEmbeddingsQuery:
|
||||
|
|
|
|||
|
|
@ -28,10 +28,12 @@ def mock_flow_config():
|
|||
"""Mock flow configuration."""
|
||||
mock_config = Mock()
|
||||
mock_config.flows = {
|
||||
"test-flow": {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test-triples-queue"},
|
||||
"graph-embeddings-store": {"flow": "test-ge-queue"}
|
||||
"test-user": {
|
||||
"test-flow": {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test-triples-queue"},
|
||||
"graph-embeddings-store": {"flow": "test-ge-queue"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -214,11 +214,11 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config_data, 1)
|
||||
await processor.on_schema_config("default", config_data, 1)
|
||||
|
||||
assert 'customers' in processor.schemas
|
||||
assert processor.schemas['customers'].name == 'customers'
|
||||
assert len(processor.schemas['customers'].fields) == 3
|
||||
assert 'customers' in processor.schemas["default"]
|
||||
assert processor.schemas["default"]['customers'].name == 'customers'
|
||||
assert len(processor.schemas["default"]['customers'].fields) == 3
|
||||
|
||||
async def test_on_schema_config_handles_missing_type(self):
|
||||
"""Test that missing schema type is handled gracefully"""
|
||||
|
|
@ -236,9 +236,9 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
'other_type': {}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config_data, 1)
|
||||
await processor.on_schema_config("default", config_data, 1)
|
||||
|
||||
assert processor.schemas == {}
|
||||
assert processor.schemas.get("default", {}) == {}
|
||||
|
||||
async def test_on_message_drops_unknown_collection(self):
|
||||
"""Test that messages for unknown collections are dropped"""
|
||||
|
|
@ -325,14 +325,16 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
|
||||
# Set up schema
|
||||
processor.schemas['customers'] = RowSchema(
|
||||
name='customers',
|
||||
description='Customer records',
|
||||
fields=[
|
||||
Field(name='id', type='text', primary=True),
|
||||
Field(name='name', type='text', indexed=True),
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
'customers': RowSchema(
|
||||
name='customers',
|
||||
description='Customer records',
|
||||
fields=[
|
||||
Field(name='id', type='text', primary=True),
|
||||
Field(name='name', type='text', indexed=True),
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
|
|
@ -372,6 +374,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
return MagicMock()
|
||||
|
||||
mock_flow = MagicMock(side_effect=flow_factory)
|
||||
mock_flow.workspace = "default"
|
||||
|
||||
await processor.on_message(mock_msg, MagicMock(), mock_flow)
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,12 @@ _real_config_loader = ConfigReceiver.config_loader
|
|||
ConfigReceiver.config_loader = Mock()
|
||||
|
||||
|
||||
def _notify(version, changes):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=version, changes=changes)
|
||||
return msg
|
||||
|
||||
|
||||
class TestConfigReceiver:
|
||||
"""Test cases for ConfigReceiver class"""
|
||||
|
||||
|
|
@ -47,98 +53,70 @@ class TestConfigReceiver:
|
|||
assert handler2 in config_receiver.flow_handlers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_new_version(self):
|
||||
"""Test on_config_notify triggers fetch for newer version"""
|
||||
async def test_on_config_notify_new_version_fetches_per_workspace(self):
|
||||
"""Notify with newer version fetches each affected workspace."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 1
|
||||
|
||||
# Mock fetch_and_apply
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with newer version
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=2, types=["flow"])
|
||||
async def mock_fetch(workspace, retry=False):
|
||||
fetch_calls.append(workspace)
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
config_receiver.fetch_and_apply_workspace = mock_fetch
|
||||
|
||||
assert len(fetch_calls) == 1
|
||||
msg = _notify(2, {"flow": ["ws1", "ws2"]})
|
||||
await config_receiver.on_config_notify(msg, None, None)
|
||||
|
||||
assert set(fetch_calls) == {"ws1", "ws2"}
|
||||
assert config_receiver.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_old_version_ignored(self):
|
||||
"""Test on_config_notify ignores older versions"""
|
||||
"""Older-version notifies are ignored."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 5
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with older version
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=3, types=["flow"])
|
||||
async def mock_fetch(workspace, retry=False):
|
||||
fetch_calls.append(workspace)
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
config_receiver.fetch_and_apply_workspace = mock_fetch
|
||||
|
||||
assert len(fetch_calls) == 0
|
||||
msg = _notify(3, {"flow": ["ws1"]})
|
||||
await config_receiver.on_config_notify(msg, None, None)
|
||||
|
||||
assert fetch_calls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_irrelevant_types_ignored(self):
|
||||
"""Test on_config_notify ignores types the gateway doesn't care about"""
|
||||
"""Notifies without flow changes advance version but skip fetch."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 1
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with non-flow type
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
async def mock_fetch(workspace, retry=False):
|
||||
fetch_calls.append(workspace)
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
config_receiver.fetch_and_apply_workspace = mock_fetch
|
||||
|
||||
# Version should be updated but no fetch
|
||||
assert len(fetch_calls) == 0
|
||||
msg = _notify(2, {"prompt": ["ws1"]})
|
||||
await config_receiver.on_config_notify(msg, None, None)
|
||||
|
||||
assert fetch_calls == []
|
||||
assert config_receiver.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_flow_type_triggers_fetch(self):
|
||||
"""Test on_config_notify fetches for flow-related types"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 1
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
for type_name in ["flow"]:
|
||||
fetch_calls.clear()
|
||||
config_receiver.config_version = 1
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=2, types=[type_name])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_exception_handling(self):
|
||||
"""Test on_config_notify handles exceptions gracefully"""
|
||||
"""on_config_notify swallows exceptions from message decode."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Create notify message that causes an exception
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.side_effect = Exception("Test exception")
|
||||
|
||||
|
|
@ -146,19 +124,18 @@ class TestConfigReceiver:
|
|||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_new_flows(self):
|
||||
"""Test fetch_and_apply starts new flows"""
|
||||
async def test_fetch_and_apply_workspace_starts_new_flows(self):
|
||||
"""fetch_and_apply_workspace starts newly-configured flows."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock _create_config_client to return a mock client
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1"}',
|
||||
"flow2": '{"name": "test_flow_2"}'
|
||||
"flow2": '{"name": "test_flow_2"}',
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -167,36 +144,39 @@ class TestConfigReceiver:
|
|||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
start_flow_calls = []
|
||||
async def mock_start_flow(id, flow):
|
||||
start_flow_calls.append((id, flow))
|
||||
|
||||
async def mock_start_flow(workspace, id, flow):
|
||||
start_flow_calls.append((workspace, id, flow))
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert config_receiver.config_version == 5
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow1" in config_receiver.flows["default"]
|
||||
assert "flow2" in config_receiver.flows["default"]
|
||||
assert len(start_flow_calls) == 2
|
||||
assert all(c[0] == "default" for c in start_flow_calls)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_removed_flows(self):
|
||||
"""Test fetch_and_apply stops removed flows"""
|
||||
async def test_fetch_and_apply_workspace_stops_removed_flows(self):
|
||||
"""fetch_and_apply_workspace stops flows no longer configured."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"}
|
||||
"default": {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"},
|
||||
}
|
||||
}
|
||||
|
||||
# Config now only has flow1
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1"}'
|
||||
"flow1": '{"name": "test_flow_1"}',
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -205,20 +185,22 @@ class TestConfigReceiver:
|
|||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
stop_flow_calls = []
|
||||
async def mock_stop_flow(id, flow):
|
||||
stop_flow_calls.append((id, flow))
|
||||
|
||||
async def mock_stop_flow(workspace, id, flow):
|
||||
stop_flow_calls.append((workspace, id, flow))
|
||||
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" not in config_receiver.flows
|
||||
assert "flow1" in config_receiver.flows["default"]
|
||||
assert "flow2" not in config_receiver.flows["default"]
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0][0] == "flow2"
|
||||
assert stop_flow_calls[0][:2] == ("default", "flow2")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_no_flows(self):
|
||||
"""Test fetch_and_apply with empty config"""
|
||||
async def test_fetch_and_apply_workspace_with_no_flows(self):
|
||||
"""Empty workspace config clears any local flow state."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
|
|
@ -231,88 +213,100 @@ class TestConfigReceiver:
|
|||
mock_client.request.return_value = mock_resp
|
||||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert config_receiver.flows == {}
|
||||
assert config_receiver.flows.get("default", {}) == {}
|
||||
assert config_receiver.config_version == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_with_handlers(self):
|
||||
"""Test start_flow method with multiple handlers"""
|
||||
"""start_flow fans out to every registered flow handler."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler1 = Mock()
|
||||
handler1.start_flow = Mock()
|
||||
handler1.start_flow = AsyncMock()
|
||||
handler2 = Mock()
|
||||
handler2.start_flow = Mock()
|
||||
handler2.start_flow = AsyncMock()
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
await config_receiver.start_flow("default", "flow1", flow_data)
|
||||
|
||||
handler1.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler1.start_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
handler2.start_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_with_handler_exception(self):
|
||||
"""Test start_flow method handles handler exceptions"""
|
||||
"""Handler exceptions in start_flow do not propagate."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler = Mock()
|
||||
handler.start_flow = Mock(side_effect=Exception("Handler error"))
|
||||
handler.start_flow = AsyncMock(side_effect=Exception("Handler error"))
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
await config_receiver.start_flow("default", "flow1", flow_data)
|
||||
|
||||
handler.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler.start_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow_with_handlers(self):
|
||||
"""Test stop_flow method with multiple handlers"""
|
||||
"""stop_flow fans out to every registered flow handler."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler1 = Mock()
|
||||
handler1.stop_flow = Mock()
|
||||
handler1.stop_flow = AsyncMock()
|
||||
handler2 = Mock()
|
||||
handler2.stop_flow = Mock()
|
||||
handler2.stop_flow = AsyncMock()
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
await config_receiver.stop_flow("default", "flow1", flow_data)
|
||||
|
||||
handler1.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler1.stop_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
handler2.stop_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow_with_handler_exception(self):
|
||||
"""Test stop_flow method handles handler exceptions"""
|
||||
"""Handler exceptions in stop_flow do not propagate."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler = Mock()
|
||||
handler.stop_flow = Mock(side_effect=Exception("Handler error"))
|
||||
handler.stop_flow = AsyncMock(side_effect=Exception("Handler error"))
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
await config_receiver.stop_flow("default", "flow1", flow_data)
|
||||
|
||||
handler.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler.stop_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@patch('asyncio.create_task')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -329,25 +323,25 @@ class TestConfigReceiver:
|
|||
mock_create_task.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_mixed_flow_operations(self):
|
||||
"""Test fetch_and_apply with mixed add/remove operations"""
|
||||
async def test_fetch_and_apply_workspace_mixed_flow_operations(self):
|
||||
"""fetch_and_apply_workspace adds, keeps and removes flows in one pass."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"}
|
||||
"default": {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"},
|
||||
}
|
||||
}
|
||||
|
||||
# Config removes flow1, keeps flow2, adds flow3
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow2": '{"name": "test_flow_2"}',
|
||||
"flow3": '{"name": "test_flow_3"}'
|
||||
"flow3": '{"name": "test_flow_3"}',
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -358,20 +352,22 @@ class TestConfigReceiver:
|
|||
start_calls = []
|
||||
stop_calls = []
|
||||
|
||||
async def mock_start_flow(id, flow):
|
||||
start_calls.append((id, flow))
|
||||
async def mock_stop_flow(id, flow):
|
||||
stop_calls.append((id, flow))
|
||||
async def mock_start_flow(workspace, id, flow):
|
||||
start_calls.append((workspace, id, flow))
|
||||
|
||||
async def mock_stop_flow(workspace, id, flow):
|
||||
stop_calls.append((workspace, id, flow))
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert "flow1" not in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow3" in config_receiver.flows
|
||||
ws_flows = config_receiver.flows["default"]
|
||||
assert "flow1" not in ws_flows
|
||||
assert "flow2" in ws_flows
|
||||
assert "flow3" in ws_flows
|
||||
assert len(start_calls) == 1
|
||||
assert start_calls[0][0] == "flow3"
|
||||
assert start_calls[0][:2] == ("default", "flow3")
|
||||
assert len(stop_calls) == 1
|
||||
assert stop_calls[0][0] == "flow1"
|
||||
assert stop_calls[0][:2] == ("default", "flow1")
|
||||
|
|
|
|||
|
|
@ -72,10 +72,10 @@ class TestDispatcherManager:
|
|||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await manager.start_flow("flow1", flow_data)
|
||||
|
||||
assert "flow1" in manager.flows
|
||||
assert manager.flows["flow1"] == flow_data
|
||||
await manager.start_flow("default", "flow1", flow_data)
|
||||
|
||||
assert ("default", "flow1") in manager.flows
|
||||
assert manager.flows[("default", "flow1")] == flow_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow(self):
|
||||
|
|
@ -86,11 +86,11 @@ class TestDispatcherManager:
|
|||
|
||||
# Pre-populate with a flow
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
manager.flows["flow1"] = flow_data
|
||||
|
||||
await manager.stop_flow("flow1", flow_data)
|
||||
|
||||
assert "flow1" not in manager.flows
|
||||
manager.flows[("default", "flow1")] = flow_data
|
||||
|
||||
await manager.stop_flow("default", "flow1", flow_data)
|
||||
|
||||
assert ("default", "flow1") not in manager.flows
|
||||
|
||||
def test_dispatch_global_service_returns_wrapper(self):
|
||||
"""Test dispatch_global_service returns DispatcherWrapper"""
|
||||
|
|
@ -275,12 +275,12 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
|
|
@ -290,7 +290,7 @@ class TestDispatcherManager:
|
|||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
mock_dispatchers.__contains__.return_value = True
|
||||
|
||||
|
||||
params = {"flow": "test_flow", "kind": "triples"}
|
||||
result = await manager.process_flow_import("ws", "running", params)
|
||||
|
||||
|
|
@ -326,12 +326,12 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers:
|
||||
mock_dispatchers.__contains__.return_value = False
|
||||
|
||||
|
|
@ -348,12 +348,12 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.export_dispatchers') as mock_dispatchers, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
|
|
@ -404,7 +404,7 @@ class TestDispatcherManager:
|
|||
params = {"flow": "test_flow", "kind": "agent"}
|
||||
result = await manager.process_flow_service("data", "responder", params)
|
||||
|
||||
manager.invoke_flow_service.assert_called_once_with("data", "responder", "test_flow", "agent")
|
||||
manager.invoke_flow_service.assert_called_once_with("data", "responder", "default", "test_flow", "agent")
|
||||
assert result == "flow_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -415,14 +415,14 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Add flow to the flows dictionary
|
||||
manager.flows["test_flow"] = {"services": {"agent": {}}}
|
||||
|
||||
manager.flows[("default", "test_flow")] = {"services": {"agent": {}}}
|
||||
|
||||
# Pre-populate with existing dispatcher
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="cached_result")
|
||||
manager.dispatchers[("test_flow", "agent")] = mock_dispatcher
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
manager.dispatchers[("default", "test_flow", "agent")] = mock_dispatcher
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
assert result == "cached_result"
|
||||
|
|
@ -435,7 +435,7 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"agent": {
|
||||
"request": "agent_request_queue",
|
||||
|
|
@ -443,7 +443,7 @@ class TestDispatcherManager:
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_dispatchers:
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
|
|
@ -452,23 +452,23 @@ class TestDispatcherManager:
|
|||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
mock_dispatchers.__contains__.return_value = True
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
|
||||
# Verify dispatcher was created with correct parameters
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
backend=mock_backend,
|
||||
request_queue="agent_request_queue",
|
||||
response_queue="agent_response_queue",
|
||||
timeout=120,
|
||||
consumer="api-gateway-test_flow-agent-request",
|
||||
subscriber="api-gateway-test_flow-agent-request"
|
||||
consumer="api-gateway-default-test_flow-agent-request",
|
||||
subscriber="api-gateway-default-test_flow-agent-request"
|
||||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
|
||||
|
||||
# Verify dispatcher was cached
|
||||
assert manager.dispatchers[("test_flow", "agent")] == mock_dispatcher
|
||||
assert manager.dispatchers[("default", "test_flow", "agent")] == mock_dispatcher
|
||||
assert result == "new_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -479,26 +479,26 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"text-load": {"flow": "text_load_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \
|
||||
patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers:
|
||||
mock_rr_dispatchers.__contains__.return_value = False
|
||||
mock_sender_dispatchers.__contains__.return_value = True
|
||||
|
||||
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.start = AsyncMock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="sender_result")
|
||||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "text-load")
|
||||
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "text-load")
|
||||
|
||||
# Verify dispatcher was created with correct parameters
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
backend=mock_backend,
|
||||
|
|
@ -506,9 +506,9 @@ class TestDispatcherManager:
|
|||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
|
||||
|
||||
# Verify dispatcher was cached
|
||||
assert manager.dispatchers[("test_flow", "text-load")] == mock_dispatcher
|
||||
assert manager.dispatchers[("default", "test_flow", "text-load")] == mock_dispatcher
|
||||
assert result == "sender_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -519,7 +519,7 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid flow"):
|
||||
await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent")
|
||||
await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_unsupported_kind_by_flow(self):
|
||||
|
|
@ -529,14 +529,14 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow without agent interface
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"text-completion": {"request": "req", "response": "resp"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with pytest.raises(RuntimeError, match="This kind not supported by flow"):
|
||||
await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_invalid_kind(self):
|
||||
|
|
@ -546,7 +546,7 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow with interface but unsupported kind
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"invalid-kind": {"request": "req", "response": "resp"}
|
||||
}
|
||||
|
|
@ -558,7 +558,7 @@ class TestDispatcherManager:
|
|||
mock_sender_dispatchers.__contains__.return_value = False
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid kind"):
|
||||
await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind")
|
||||
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "invalid-kind")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self):
|
||||
|
|
@ -608,7 +608,7 @@ class TestDispatcherManager:
|
|||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"agent": {
|
||||
"request": "agent_request_queue",
|
||||
|
|
@ -630,7 +630,7 @@ class TestDispatcherManager:
|
|||
mock_rr_dispatchers.__contains__.return_value = True
|
||||
|
||||
results = await asyncio.gather(*[
|
||||
manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
for _ in range(5)
|
||||
])
|
||||
|
||||
|
|
@ -638,5 +638,5 @@ class TestDispatcherManager:
|
|||
"Dispatcher class instantiated more than once — duplicate consumer bug"
|
||||
)
|
||||
assert mock_dispatcher.start.call_count == 1
|
||||
assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher
|
||||
assert manager.dispatchers[("default", "test_flow", "agent")] is mock_dispatcher
|
||||
assert all(r == "result" for r in results)
|
||||
|
|
@ -239,7 +239,7 @@ def _make_processor(tools=None):
|
|||
agent = MagicMock()
|
||||
agent.tools = tools or {}
|
||||
agent.additional_context = ""
|
||||
processor.agent = agent
|
||||
processor.agents = {"default": agent}
|
||||
processor.aggregator = MagicMock()
|
||||
|
||||
return processor
|
||||
|
|
@ -254,6 +254,7 @@ def _make_flow():
|
|||
return producers[name]
|
||||
|
||||
flow = MagicMock(side_effect=factory)
|
||||
flow.workspace = "default"
|
||||
return flow
|
||||
|
||||
|
||||
|
|
@ -299,7 +300,7 @@ class TestAgentReactDagStructure:
|
|||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agent = processor.agent
|
||||
service.agents = processor.agents
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
|
|
@ -433,7 +434,7 @@ class TestAgentPlanDagStructure:
|
|||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agent = processor.agent
|
||||
service.agents = processor.agents
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
|
|
@ -537,7 +538,7 @@ class TestAgentSupervisorDagStructure:
|
|||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agent = processor.agent
|
||||
service.agents = processor.agents
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
|
|
|
|||
|
|
@ -91,11 +91,10 @@ class TestRowsGraphQLQueryLogic:
|
|||
"""Test parsing of schema configuration"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.schema_builders = {}
|
||||
processor.graphql_schemas = {}
|
||||
processor.config_key = "schema"
|
||||
processor.schema_builder = MagicMock()
|
||||
processor.schema_builder.clear = MagicMock()
|
||||
processor.schema_builder.add_schema = MagicMock()
|
||||
processor.schema_builder.build = MagicMock(return_value=MagicMock())
|
||||
processor.query_cassandra = MagicMock()
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Create test config
|
||||
|
|
@ -129,11 +128,11 @@ class TestRowsGraphQLQueryLogic:
|
|||
}
|
||||
|
||||
# Process config
|
||||
await processor.on_schema_config(schema_config, version=1)
|
||||
await processor.on_schema_config("default", schema_config, version=1)
|
||||
|
||||
# Verify schema was loaded
|
||||
assert "customer" in processor.schemas
|
||||
schema = processor.schemas["customer"]
|
||||
assert "customer" in processor.schemas["default"]
|
||||
schema = processor.schemas["default"]["customer"]
|
||||
assert schema.name == "customer"
|
||||
assert len(schema.fields) == 3
|
||||
|
||||
|
|
@ -147,24 +146,26 @@ class TestRowsGraphQLQueryLogic:
|
|||
status_field = next(f for f in schema.fields if f.name == "status")
|
||||
assert status_field.enum_values == ["active", "inactive"]
|
||||
|
||||
# Verify schema builder was called
|
||||
processor.schema_builder.add_schema.assert_called_once()
|
||||
processor.schema_builder.build.assert_called_once()
|
||||
# Verify per-workspace schema builder was created and graphql schema built
|
||||
assert "default" in processor.schema_builders
|
||||
assert "default" in processor.graphql_schemas
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graphql_context_handling(self):
|
||||
"""Test GraphQL execution context setup"""
|
||||
processor = MagicMock()
|
||||
processor.graphql_schema = AsyncMock()
|
||||
graphql_schema = AsyncMock()
|
||||
processor.graphql_schemas = {"default": graphql_schema}
|
||||
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
|
||||
|
||||
# Mock schema execution
|
||||
mock_result = MagicMock()
|
||||
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
|
||||
mock_result.errors = None
|
||||
processor.graphql_schema.execute.return_value = mock_result
|
||||
graphql_schema.execute.return_value = mock_result
|
||||
|
||||
result = await processor.execute_graphql_query(
|
||||
workspace="default",
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
operation_name=None,
|
||||
|
|
@ -173,8 +174,8 @@ class TestRowsGraphQLQueryLogic:
|
|||
)
|
||||
|
||||
# Verify schema.execute was called with correct context
|
||||
processor.graphql_schema.execute.assert_called_once()
|
||||
call_args = processor.graphql_schema.execute.call_args
|
||||
graphql_schema.execute.assert_called_once()
|
||||
call_args = graphql_schema.execute.call_args
|
||||
|
||||
# Verify context was passed
|
||||
context = call_args[1]['context_value']
|
||||
|
|
@ -190,7 +191,8 @@ class TestRowsGraphQLQueryLogic:
|
|||
async def test_error_handling_graphql_errors(self):
|
||||
"""Test GraphQL error handling and conversion"""
|
||||
processor = MagicMock()
|
||||
processor.graphql_schema = AsyncMock()
|
||||
graphql_schema = AsyncMock()
|
||||
processor.graphql_schemas = {"default": graphql_schema}
|
||||
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
|
||||
|
||||
# Create a simple object to simulate GraphQL error
|
||||
|
|
@ -212,9 +214,10 @@ class TestRowsGraphQLQueryLogic:
|
|||
mock_result = MagicMock()
|
||||
mock_result.data = None
|
||||
mock_result.errors = [mock_error]
|
||||
processor.graphql_schema.execute.return_value = mock_result
|
||||
graphql_schema.execute.return_value = mock_result
|
||||
|
||||
result = await processor.execute_graphql_query(
|
||||
workspace="default",
|
||||
query='{ customers { invalid_field } }',
|
||||
variables={},
|
||||
operation_name=None,
|
||||
|
|
@ -259,6 +262,7 @@ class TestRowsGraphQLQueryLogic:
|
|||
|
||||
# Mock flow
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.workspace = "default"
|
||||
mock_response_flow = AsyncMock()
|
||||
mock_flow.return_value = mock_response_flow
|
||||
|
||||
|
|
@ -267,6 +271,7 @@ class TestRowsGraphQLQueryLogic:
|
|||
|
||||
# Verify query was executed
|
||||
processor.execute_graphql_query.assert_called_once_with(
|
||||
workspace="default",
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
operation_name=None,
|
||||
|
|
|
|||
|
|
@ -286,11 +286,11 @@ class TestNLPQueryProcessor:
|
|||
}
|
||||
|
||||
# Act
|
||||
await processor.on_schema_config(config, "v1")
|
||||
await processor.on_schema_config("default", config, "v1")
|
||||
|
||||
# Assert
|
||||
assert "test_schema" in processor.schemas
|
||||
schema = processor.schemas["test_schema"]
|
||||
assert "test_schema" in processor.schemas["default"]
|
||||
schema = processor.schemas["default"]["test_schema"]
|
||||
assert schema.name == "test_schema"
|
||||
assert schema.description == "Test schema"
|
||||
assert len(schema.fields) == 2
|
||||
|
|
@ -308,10 +308,10 @@ class TestNLPQueryProcessor:
|
|||
}
|
||||
|
||||
# Act
|
||||
await processor.on_schema_config(config, "v1")
|
||||
await processor.on_schema_config("default", config, "v1")
|
||||
|
||||
# Assert - bad schema should be ignored
|
||||
assert "bad_schema" not in processor.schemas
|
||||
assert "bad_schema" not in processor.schemas.get("default", {})
|
||||
|
||||
def test_processor_initialization(self, mock_pulsar_client):
|
||||
"""Test processor initialization with correct specifications"""
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ def service(mock_schemas):
|
|||
taskgroup=MagicMock(),
|
||||
id="test-processor"
|
||||
)
|
||||
service.schemas = mock_schemas
|
||||
service.schemas = {"default": dict(mock_schemas)}
|
||||
return service
|
||||
|
||||
|
||||
|
|
@ -109,6 +109,7 @@ def service(mock_schemas):
|
|||
def mock_flow():
|
||||
"""Create mock flow with prompt service"""
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
prompt_request_flow = AsyncMock()
|
||||
flow.return_value.request = prompt_request_flow
|
||||
return flow, prompt_request_flow
|
||||
|
|
|
|||
|
|
@ -17,6 +17,17 @@ from trustgraph.storage.rows.cassandra.write import Processor
|
|||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
|
||||
|
||||
class _MockFlowDefault:
|
||||
"""Mock Flow with default workspace for testing."""
|
||||
workspace = "default"
|
||||
name = "default"
|
||||
id = "test-processor"
|
||||
|
||||
|
||||
mock_flow_default = _MockFlowDefault()
|
||||
|
||||
class TestRowsCassandraStorageLogic:
|
||||
"""Test business logic for unified table implementation"""
|
||||
|
||||
|
|
@ -145,11 +156,11 @@ class TestRowsCassandraStorageLogic:
|
|||
}
|
||||
|
||||
# Process configuration
|
||||
await processor.on_schema_config(config, version=1)
|
||||
await processor.on_schema_config("default", config, version=1)
|
||||
|
||||
# Verify schema was loaded
|
||||
assert "customer_records" in processor.schemas
|
||||
schema = processor.schemas["customer_records"]
|
||||
assert "customer_records" in processor.schemas["default"]
|
||||
schema = processor.schemas["default"]["customer_records"]
|
||||
assert schema.name == "customer_records"
|
||||
assert len(schema.fields) == 3
|
||||
|
||||
|
|
@ -165,14 +176,16 @@ class TestRowsCassandraStorageLogic:
|
|||
"""Test that row processing stores data as map<text, text>"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
description="Test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="value", type="string", size=100)
|
||||
]
|
||||
)
|
||||
"default": {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
description="Test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="value", type="string", size=100)
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.registered_partitions = set()
|
||||
|
|
@ -205,7 +218,7 @@ class TestRowsCassandraStorageLogic:
|
|||
msg.value.return_value = test_obj
|
||||
|
||||
# Process object
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify insert was executed
|
||||
mock_async_execute.assert_called()
|
||||
|
|
@ -230,14 +243,16 @@ class TestRowsCassandraStorageLogic:
|
|||
"""Test that row is written once per indexed field"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"multi_index_schema": RowSchema(
|
||||
name="multi_index_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="status", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
"default": {
|
||||
"multi_index_schema": RowSchema(
|
||||
name="multi_index_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="status", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.registered_partitions = set()
|
||||
|
|
@ -267,7 +282,7 @@ class TestRowsCassandraStorageLogic:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Should have 3 inserts (one per indexed field: id, category, status)
|
||||
assert mock_async_execute.call_count == 3
|
||||
|
|
@ -290,13 +305,15 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
"""Test processing of batch ExtractedObjects"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"batch_schema": RowSchema(
|
||||
name="batch_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string")
|
||||
]
|
||||
)
|
||||
"default": {
|
||||
"batch_schema": RowSchema(
|
||||
name="batch_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string")
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.registered_partitions = set()
|
||||
|
|
@ -331,7 +348,7 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Should have 3 inserts (one per row, one index per row since only primary key)
|
||||
assert mock_async_execute.call_count == 3
|
||||
|
|
@ -349,10 +366,12 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
"""Test processing of empty batch ExtractedObjects"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"empty_schema": RowSchema(
|
||||
name="empty_schema",
|
||||
fields=[Field(name="id", type="string", primary=True)]
|
||||
)
|
||||
"default": {
|
||||
"empty_schema": RowSchema(
|
||||
name="empty_schema",
|
||||
fields=[Field(name="id", type="string", primary=True)]
|
||||
)
|
||||
}
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.registered_partitions = set()
|
||||
|
|
@ -381,7 +400,7 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = empty_batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify no insert calls for empty batch
|
||||
processor.session.execute.assert_not_called()
|
||||
|
|
@ -446,19 +465,21 @@ class TestPartitionRegistration:
|
|||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.schemas = {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
"default": {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
|
||||
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema")
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema", "default")
|
||||
|
||||
# Should have 2 inserts (one per index: id, category)
|
||||
assert processor.session.execute.call_count == 2
|
||||
|
|
@ -473,7 +494,7 @@ class TestPartitionRegistration:
|
|||
processor.session = MagicMock()
|
||||
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
|
||||
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema")
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema", "default")
|
||||
|
||||
# Should not execute any CQL since already registered
|
||||
processor.session.execute.assert_not_called()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue