feat: workspace-based multi-tenancy, replacing user as tenancy axis (#840)

Introduces `workspace` as the isolation boundary for config, flows,
library, and knowledge data. Removes `user` as a schema-level field
throughout the code, API specs, and tests; workspace provides the
same separation more cleanly, enforced through the trusted
flow.workspace attribute rather than through client-supplied
message fields.

Design
------
- IAM tech spec (docs/tech-specs/iam.md) documents current state,
  proposed auth/access model, and migration direction.
- Data ownership model (docs/tech-specs/data-ownership-model.md)
  captures the workspace/collection/flow hierarchy.

Schema + messaging
------------------
- Drop `user` field from AgentRequest/Step, GraphRagQuery,
  DocumentRagQuery, Triples/Graph/Document/Row EmbeddingsRequest,
  Sparql/Rows/Structured QueryRequest, ToolServiceRequest.
- Keep collection/workspace routing via flow.workspace at the
  service layer.
- Translators no longer serialise/deserialise user (before/after
  sketch below).
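
  Before/after sketch of the message change (fields as exercised
  in the updated tests; other request types follow the same
  pattern):

      # old: tenancy arrived as a client-supplied message field
      request = AgentRequest(
          question="...", state="", group=None, history=[],
          user="test_user",
      )

      # new: no user field; tenancy comes from flow.workspace,
      # set server-side when the flow starts
      request = AgentRequest(
          question="...", state="", group=None, history=[],
      )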

API specs
---------
- OpenAPI schemas and path examples cleaned of user fields.
- WebSocket AsyncAPI messages updated.
- Removed the unused parameters/User.yaml.

Services + base
---------------
- Librarian, collection manager, knowledge, config: all operations
  scoped by workspace. Config client API takes workspace as the
  first positional arg (example below).
- `flow.workspace` is set at flow start time by the infrastructure,
  no longer passed through from clients.
- Tool service drops user-personalisation passthrough.
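
  Handler and client signatures take the workspace as the first
  positional argument; sketches drawn from the updated tests
  (variable names abbreviated):

      await processor.on_tools_config("default", tool_config, "v1")
      await store.add_triples("test_workspace", triples)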

CLI + SDK
---------
- New tg-init-workspace command; workspace-aware import/export.
- All tg-* commands drop user args and accept --workspace.
- Python API/SDK (flow, socket_client, async_*, explainability,
  library) drop user kwargs from every method signature (call
  shape sketched below).
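
  Resulting SDK call shape (drawn from the updated tests; the
  collection kwarg still routes within the workspace):

      response = await graph_rag.query(
          query="...",
          collection="default",
      )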

MCP server
----------
- All tool endpoints drop user parameters; socket_manager no longer
  keyed per user.

Flow service
------------
- Closure-based topic cleanup on flow stop: only delete topics
  whose blueprint template was parameterised AND no remaining
  live flow (across all workspaces) still resolves to that topic.
  Four scopes fall out naturally from template analysis (sketched
  below):
    * {id} -> per-flow, deleted on stop
    * {blueprint} -> per-blueprint, kept while any flow of the
      same blueprint exists
    * {workspace} -> per-workspace, kept while any flow in the
      workspace exists
    * literal -> global, never deleted (e.g. tg.request.librarian)
  Fixes a bug where stopping a flow silently destroyed the global
  librarian exchange, wedging all library operations until a
  manual restart.
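
  A minimal sketch of the scope classification (illustrative; the
  function name and structure are not the actual flow-service
  code):

      def topic_scope(template: str) -> str:
          """Classify a blueprint topic template for cleanup."""
          if "{id}" in template:
              return "per-flow"       # deleted on flow stop
          if "{blueprint}" in template:
              return "per-blueprint"  # kept while the blueprint has live flows
          if "{workspace}" in template:
              return "per-workspace"  # kept while the workspace has live flows
          return "global"             # literal topic, never deleted

      assert topic_scope("tg.request.{id}") == "per-flow"
      assert topic_scope("tg.request.librarian") == "global"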

RabbitMQ backend
----------------
- Set heartbeat=60 and blocked_connection_timeout=300 (sketched
  below). This catches silently dead connections (broker restarts,
  orphaned channels, network partitions) within ~2 heartbeat
  windows, so the consumer reconnects and re-binds its queue
  rather than sitting forever on a zombie connection.
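
  A sketch of how these settings map onto pika connection
  parameters (assuming a pika-based client; host name and
  surrounding code are illustrative):

      import pika

      params = pika.ConnectionParameters(
          host="rabbitmq",                 # assumed broker host
          heartbeat=60,                    # broker/client probe every 60s
          blocked_connection_timeout=300,  # abort blocked publishes after 5 min
      )
      connection = pika.BlockingConnection(params)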

Tests
-----
- Full test refresh: unit, integration, contract, provenance.
- Dropped user-field assertions and constructor kwargs across
  ~100 test files.
- Renamed user-collection isolation tests to workspace-collection.
commit d35473f7f7 (parent 9332089b3d)
Author: cybermaggedon
Date:   2026-04-21 23:23:01 +01:00 (committed via GitHub)

377 changed files with 6868 additions and 5785 deletions


@ -58,7 +58,7 @@ class TestAgentStructuredQueryIntegration:
async def test_agent_structured_query_basic_integration(self, agent_processor, structured_query_tool_config):
"""Test basic agent integration with structured query tool"""
# Arrange - Load tool configuration
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
# Create agent request
request = AgentRequest(
@ -66,7 +66,6 @@ class TestAgentStructuredQueryIntegration:
state="",
group=None,
history=[],
user="test_user"
)
msg = MagicMock()
@ -119,6 +118,7 @@ Args: {
# Mock flow parameter in agent_processor.on_request
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -146,14 +146,13 @@ Args: {
async def test_agent_structured_query_error_handling(self, agent_processor, structured_query_tool_config):
"""Test agent handling of structured query errors"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
request = AgentRequest(
question="Find data from a table that doesn't exist using structured query.",
state="",
group=None,
history=[],
user="test_user"
)
msg = MagicMock()
@ -199,6 +198,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -221,14 +221,13 @@ Args: {
async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
"""Test agent using structured query in multi-step reasoning"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
request = AgentRequest(
question="First find all customers from California, then tell me how many orders they have made.",
state="",
group=None,
history=[],
user="test_user"
)
msg = MagicMock()
@ -279,6 +278,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -313,14 +313,13 @@ Args: {
}
}
await agent_processor.on_tools_config(tool_config_with_collection, "v1")
await agent_processor.on_tools_config("default", tool_config_with_collection, "v1")
request = AgentRequest(
question="Query the sales data for recent transactions.",
state="",
group=None,
history=[],
user="test_user"
)
msg = MagicMock()
@ -371,6 +370,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)
@ -394,10 +394,10 @@ Args: {
async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
"""Test that structured query tool arguments are properly validated"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
# Check that the tool was registered with correct arguments
tools = agent_processor.agent.tools
tools = agent_processor.agents["default"].tools
assert "structured-query" in tools
structured_tool = tools["structured-query"]
@ -414,14 +414,13 @@ Args: {
async def test_agent_structured_query_json_formatting(self, agent_processor, structured_query_tool_config):
"""Test that structured query results are properly formatted for agent consumption"""
# Arrange
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
request = AgentRequest(
question="Get customer information and format it nicely.",
state="",
group=None,
history=[],
user="test_user"
)
msg = MagicMock()
@ -482,6 +481,7 @@ Args: {
flow = MagicMock()
flow.side_effect = flow_context
flow.workspace = "default"
# Act
await agent_processor.on_request(msg, consumer, flow)


@ -40,14 +40,13 @@ class TestEndToEndConfigurationFlow:
# Create a mock message to trigger TrustGraph creation
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_message.triples = []
# Mock collection_exists to return True
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
# This should create TrustGraph with environment config
await processor.store_triples(mock_message)
await processor.store_triples('test_user', mock_message)
# Verify Cluster was created with correct hosts
mock_cluster.assert_called_once()
@ -144,13 +143,12 @@ class TestConfigurationPriorityEndToEnd:
# Trigger TrustGraph creation
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_message.triples = []
# Mock collection_exists to return True
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples('test_user', mock_message)
# Should use CLI parameters, not environment
mock_cluster.assert_called_once()
@ -201,7 +199,6 @@ class TestConfigurationPriorityEndToEnd:
# Mock query to trigger TrustGraph creation
mock_query = MagicMock()
mock_query.user = 'default_user'
mock_query.collection = 'default_collection'
mock_query.s = None
mock_query.p = None
@ -213,7 +210,7 @@ class TestConfigurationPriorityEndToEnd:
mock_tg_instance.get_all.return_value = []
processor.tg = mock_tg_instance
await processor.query_triples(mock_query)
await processor.query_triples('default_user', mock_query)
# Should use defaults
mock_cluster.assert_called_once()
@ -244,13 +241,12 @@ class TestNoBackwardCompatibilityEndToEnd:
# Trigger TrustGraph creation
mock_message = MagicMock()
mock_message.metadata.user = 'legacy_user'
mock_message.metadata.collection = 'legacy_collection'
mock_message.triples = []
# Mock collection_exists to return True
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples('legacy_user', mock_message)
# Should use defaults since old parameters are not recognized
mock_cluster.assert_called_once()
@ -302,13 +298,12 @@ class TestNoBackwardCompatibilityEndToEnd:
# Trigger TrustGraph creation
mock_message = MagicMock()
mock_message.metadata.user = 'precedence_user'
mock_message.metadata.collection = 'precedence_collection'
mock_message.triples = []
# Mock collection_exists to return True
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples('precedence_user', mock_message)
# Should use new parameters, not old ones
mock_cluster.assert_called_once()
@ -354,13 +349,12 @@ class TestMultipleHostsHandling:
# Trigger TrustGraph creation
mock_message = MagicMock()
mock_message.metadata.user = 'single_user'
mock_message.metadata.collection = 'single_collection'
mock_message.triples = []
# Mock collection_exists to return True
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples('single_user', mock_message)
# Single host should be converted to list
mock_cluster.assert_called_once()


@ -115,7 +115,7 @@ class TestCassandraIntegration:
# Create test message
storage_message = Triples(
metadata=Metadata(user="testuser", collection="testcol"),
metadata=Metadata(collection="testcol"),
triples=[
Triple(
s=Term(type=IRI, iri="http://example.org/person1"),
@ -178,7 +178,7 @@ class TestCassandraIntegration:
# Store test data for querying
query_test_message = Triples(
metadata=Metadata(user="testuser", collection="testcol"),
metadata=Metadata(collection="testcol"),
triples=[
Triple(
s=Term(type=IRI, iri="http://example.org/alice"),
@ -212,7 +212,6 @@ class TestCassandraIntegration:
p=None, # None for wildcard
o=None, # None for wildcard
limit=10,
user="testuser",
collection="testcol"
)
s_results = await query_processor.query_triples(s_query)
@ -232,7 +231,6 @@ class TestCassandraIntegration:
p=Term(type=IRI, iri="http://example.org/knows"),
o=None, # None for wildcard
limit=10,
user="testuser",
collection="testcol"
)
p_results = await query_processor.query_triples(p_query)
@ -259,7 +257,7 @@ class TestCassandraIntegration:
# Create multiple coroutines for concurrent storage
async def store_person_data(person_id, name, age, department):
message = Triples(
metadata=Metadata(user="concurrent_test", collection="people"),
metadata=Metadata(collection="people"),
triples=[
Triple(
s=Term(type=IRI, iri=f"http://example.org/{person_id}"),
@ -329,7 +327,7 @@ class TestCassandraIntegration:
# Create a knowledge graph about a company
company_graph = Triples(
metadata=Metadata(user="integration_test", collection="company"),
metadata=Metadata(collection="company"),
triples=[
# People and their types
Triple(


@ -99,7 +99,6 @@ class TestDocumentRagIntegration:
# Act
result = await document_rag.query(
query=query,
user=user,
collection=collection,
doc_limit=doc_limit
)
@ -110,7 +109,6 @@ class TestDocumentRagIntegration:
mock_doc_embeddings_client.query.assert_called_once_with(
vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
limit=doc_limit,
user=user,
collection=collection
)
@ -278,14 +276,12 @@ class TestDocumentRagIntegration:
# Act
await document_rag.query(
f"query from {user} in {collection}",
user=user,
collection=collection
)
# Assert
mock_doc_embeddings_client.query.assert_called_once()
call_args = mock_doc_embeddings_client.query.call_args
assert call_args.kwargs['user'] == user
assert call_args.kwargs['collection'] == collection
@pytest.mark.asyncio
@ -353,6 +349,5 @@ class TestDocumentRagIntegration:
# Assert
mock_doc_embeddings_client.query.assert_called_once()
call_args = mock_doc_embeddings_client.query.call_args
assert call_args.kwargs['user'] == "trustgraph"
assert call_args.kwargs['collection'] == "default"
assert call_args.kwargs['limit'] == 20


@ -107,7 +107,6 @@ class TestDocumentRagStreaming:
# Act
result = await document_rag_streaming.query(
query=query,
user="test_user",
collection="test_collection",
doc_limit=10,
streaming=True,
@ -141,7 +140,6 @@ class TestDocumentRagStreaming:
# Act - Non-streaming
non_streaming_result = await document_rag_streaming.query(
query=query,
user=user,
collection=collection,
doc_limit=doc_limit,
streaming=False
@ -155,7 +153,6 @@ class TestDocumentRagStreaming:
streaming_result = await document_rag_streaming.query(
query=query,
user=user,
collection=collection,
doc_limit=doc_limit,
streaming=True,
@ -178,7 +175,6 @@ class TestDocumentRagStreaming:
# Act
result = await document_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
doc_limit=5,
streaming=True,
@ -200,7 +196,6 @@ class TestDocumentRagStreaming:
# Arrange & Act
result = await document_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
doc_limit=5,
streaming=True,
@ -223,7 +218,6 @@ class TestDocumentRagStreaming:
# Act
result = await document_rag_streaming.query(
query="unknown topic",
user="test_user",
collection="test_collection",
doc_limit=10,
streaming=True,
@ -247,7 +241,6 @@ class TestDocumentRagStreaming:
with pytest.raises(Exception) as exc_info:
await document_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
doc_limit=5,
streaming=True,
@ -272,7 +265,6 @@ class TestDocumentRagStreaming:
# Act
result = await document_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
doc_limit=limit,
streaming=True,
@ -300,7 +292,6 @@ class TestDocumentRagStreaming:
# Act
await document_rag_streaming.query(
query="test query",
user=user,
collection=collection,
doc_limit=10,
streaming=True,
@ -309,5 +300,4 @@ class TestDocumentRagStreaming:
# Assert - Verify user/collection were passed to document embeddings client
call_args = mock_doc_embeddings_client.query.call_args
assert call_args.kwargs['user'] == user
assert call_args.kwargs['collection'] == collection


@ -146,7 +146,6 @@ class TestGraphRagIntegration:
# Act
response = await graph_rag.query(
query=query,
user=user,
collection=collection,
entity_limit=entity_limit,
triple_limit=triple_limit,
@ -163,7 +162,6 @@ class TestGraphRagIntegration:
call_args = mock_graph_embeddings_client.query.call_args
assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
assert call_args.kwargs['limit'] == entity_limit
assert call_args.kwargs['user'] == user
assert call_args.kwargs['collection'] == collection
# 3. Should query triples to build knowledge subgraph
@ -204,7 +202,6 @@ class TestGraphRagIntegration:
# Act
await graph_rag.query(
query=query,
user="test_user",
collection="test_collection",
entity_limit=config["entity_limit"],
triple_limit=config["triple_limit"]
@ -224,7 +221,6 @@ class TestGraphRagIntegration:
with pytest.raises(Exception) as exc_info:
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection"
)
@ -247,7 +243,6 @@ class TestGraphRagIntegration:
# Act
response = await graph_rag.query(
query="unknown topic",
user="test_user",
collection="test_collection",
explain_callback=collect_provenance
)
@ -267,7 +262,6 @@ class TestGraphRagIntegration:
# First query
await graph_rag.query(
query=query,
user="test_user",
collection="test_collection"
)
@ -277,7 +271,6 @@ class TestGraphRagIntegration:
# Second identical query
await graph_rag.query(
query=query,
user="test_user",
collection="test_collection"
)
@ -289,26 +282,27 @@ class TestGraphRagIntegration:
assert second_call_count >= 0 # Should complete without errors
@pytest.mark.asyncio
async def test_graph_rag_multi_user_isolation(self, graph_rag, mock_graph_embeddings_client):
"""Test that different users/collections are properly isolated"""
async def test_graph_rag_multi_collection_isolation(self, graph_rag, mock_graph_embeddings_client):
"""Test that different collections propagate through to the embeddings query.
Workspace isolation is enforced by flow.workspace at the service
boundary, not by parameters on GraphRag.query, so this test
verifies collection routing only.
"""
# Arrange
query = "test query"
user1, collection1 = "user1", "collection1"
user2, collection2 = "user2", "collection2"
collection1 = "collection1"
collection2 = "collection2"
# Act
await graph_rag.query(query=query, user=user1, collection=collection1)
await graph_rag.query(query=query, user=user2, collection=collection2)
await graph_rag.query(query=query, collection=collection1)
await graph_rag.query(query=query, collection=collection2)
# Assert - Both users should have separate queries
# Assert - Each call propagated its collection
assert mock_graph_embeddings_client.query.call_count == 2
# Verify first call
first_call = mock_graph_embeddings_client.query.call_args_list[0]
assert first_call.kwargs['user'] == user1
assert first_call.kwargs['collection'] == collection1
# Verify second call
second_call = mock_graph_embeddings_client.query.call_args_list[1]
assert second_call.kwargs['user'] == user2
assert second_call.kwargs['collection'] == collection2


@ -116,7 +116,6 @@ class TestGraphRagStreaming:
# Act - query() returns response, provenance via callback
response = await graph_rag_streaming.query(
query=query,
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collector.collect,
@ -154,7 +153,6 @@ class TestGraphRagStreaming:
# Act - Non-streaming
non_streaming_response = await graph_rag_streaming.query(
query=query,
user=user,
collection=collection,
streaming=False
)
@ -167,7 +165,6 @@ class TestGraphRagStreaming:
streaming_response = await graph_rag_streaming.query(
query=query,
user=user,
collection=collection,
streaming=True,
chunk_callback=collect
@ -189,7 +186,6 @@ class TestGraphRagStreaming:
# Act
response = await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=callback
@ -209,7 +205,6 @@ class TestGraphRagStreaming:
# Arrange & Act
response = await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=None # No callback provided
@ -231,7 +226,6 @@ class TestGraphRagStreaming:
# Act
response = await graph_rag_streaming.query(
query="unknown topic",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=callback
@ -253,7 +247,6 @@ class TestGraphRagStreaming:
with pytest.raises(Exception) as exc_info:
await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=callback
@ -273,7 +266,6 @@ class TestGraphRagStreaming:
# Act
await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
entity_limit=entity_limit,
triple_limit=triple_limit,


@ -171,7 +171,6 @@ async def test_export_no_message_loss_integration(mock_backend):
triples_obj = Triples(
metadata=Metadata(
id=f"export-msg-{i}",
user=msg_data["metadata"]["user"],
collection=msg_data["metadata"]["collection"],
),
triples=to_subgraph(msg_data["triples"]),


@ -97,7 +97,6 @@ class TestKnowledgeGraphPipelineIntegration:
return Chunk(
metadata=Metadata(
id="doc-123",
user="test_user",
collection="test_collection",
),
chunk=b"Machine Learning is a subset of Artificial Intelligence. Neural Networks are used in Machine Learning to process complex patterns."
@ -247,7 +246,6 @@ class TestKnowledgeGraphPipelineIntegration:
# Arrange
metadata = Metadata(
id="test-doc",
user="test_user",
collection="test_collection",
)
@ -305,7 +303,6 @@ class TestKnowledgeGraphPipelineIntegration:
# Arrange
metadata = Metadata(
id="test-doc",
user="test_user",
collection="test_collection",
)
@ -375,7 +372,6 @@ class TestKnowledgeGraphPipelineIntegration:
sample_triples = Triples(
metadata=Metadata(
id="test-doc",
user="test_user",
collection="test_collection",
),
triples=[
@ -390,11 +386,14 @@ class TestKnowledgeGraphPipelineIntegration:
mock_msg = MagicMock()
mock_msg.value.return_value = sample_triples
mock_flow = MagicMock()
mock_flow.workspace = "test_workspace"
# Act
await processor.on_triples(mock_msg, None, None)
await processor.on_triples(mock_msg, None, mock_flow)
# Assert
mock_cassandra_store.add_triples.assert_called_once_with(sample_triples)
mock_cassandra_store.add_triples.assert_called_once_with("test_workspace", sample_triples)
@pytest.mark.asyncio
async def test_knowledge_store_graph_embeddings_storage(self, mock_cassandra_store):
@ -407,7 +406,6 @@ class TestKnowledgeGraphPipelineIntegration:
sample_embeddings = GraphEmbeddings(
metadata=Metadata(
id="test-doc",
user="test_user",
collection="test_collection",
),
entities=[
@ -421,11 +419,14 @@ class TestKnowledgeGraphPipelineIntegration:
mock_msg = MagicMock()
mock_msg.value.return_value = sample_embeddings
mock_flow = MagicMock()
mock_flow.workspace = "test_workspace"
# Act
await processor.on_graph_embeddings(mock_msg, None, None)
await processor.on_graph_embeddings(mock_msg, None, mock_flow)
# Assert
mock_cassandra_store.add_graph_embeddings.assert_called_once_with(sample_embeddings)
mock_cassandra_store.add_graph_embeddings.assert_called_once_with("test_workspace", sample_embeddings)
@pytest.mark.asyncio
async def test_end_to_end_pipeline_coordination(self, definitions_processor, relationships_processor,
@ -553,7 +554,7 @@ class TestKnowledgeGraphPipelineIntegration:
)
sample_chunk = Chunk(
metadata=Metadata(id="test", user="user", collection="collection"),
metadata=Metadata(id="test", collection="collection"),
chunk=b"Test chunk"
)
@ -580,7 +581,7 @@ class TestKnowledgeGraphPipelineIntegration:
# Arrange
large_chunk_batch = [
Chunk(
metadata=Metadata(id=f"doc-{i}", user="user", collection="collection"),
metadata=Metadata(id=f"doc-{i}", collection="collection"),
chunk=f"Document {i} contains machine learning and AI content.".encode("utf-8")
)
for i in range(100) # Large batch
@ -617,7 +618,6 @@ class TestKnowledgeGraphPipelineIntegration:
# Arrange
original_metadata = Metadata(
id="test-doc-123",
user="test_user",
collection="test_collection",
)
@ -646,9 +646,7 @@ class TestKnowledgeGraphPipelineIntegration:
entity_contexts_call = entity_contexts_producer.send.call_args[0][0]
assert triples_call.metadata.id == "test-doc-123"
assert triples_call.metadata.user == "test_user"
assert triples_call.metadata.collection == "test_collection"
assert entity_contexts_call.metadata.id == "test-doc-123"
assert entity_contexts_call.metadata.user == "test_user"
assert entity_contexts_call.metadata.collection == "test_collection"


@ -72,7 +72,7 @@ class TestNLPQueryServiceIntegration:
)
# Set up schemas
proc.schemas = sample_schemas
proc.schemas = {"default": dict(sample_schemas)}
# Mock the client method
proc.client = MagicMock()
@ -94,6 +94,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -173,6 +174,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -229,7 +231,7 @@ class TestNLPQueryServiceIntegration:
}
# Act - Update configuration
await integration_processor.on_schema_config(new_schema_config, "v2")
await integration_processor.on_schema_config("default", new_schema_config, "v2")
# Arrange - Test query using new schema
request = QuestionToStructuredQueryRequest(
@ -243,6 +245,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -272,7 +275,7 @@ class TestNLPQueryServiceIntegration:
await integration_processor.on_message(msg, consumer, flow)
# Assert
assert "inventory" in integration_processor.schemas
assert "inventory" in integration_processor.schemas["default"]
response_call = flow_response.send.call_args
response = response_call[0][0]
assert response.detected_schemas == ["inventory"]
@ -293,6 +296,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -334,7 +338,7 @@ class TestNLPQueryServiceIntegration:
graphql_generation_template="custom-graphql-generator"
)
custom_processor.schemas = sample_schemas
custom_processor.schemas = {"default": dict(sample_schemas)}
custom_processor.client = MagicMock()
request = QuestionToStructuredQueryRequest(
@ -348,6 +352,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -394,7 +399,7 @@ class TestNLPQueryServiceIntegration:
] + [SchemaField(name=f"field_{j}", type="string") for j in range(5)]
)
integration_processor.schemas.update(large_schema_set)
integration_processor.schemas["default"].update(large_schema_set)
request = QuestionToStructuredQueryRequest(
question="Show me data from table_05 and table_12",
@ -407,6 +412,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -462,6 +468,7 @@ class TestNLPQueryServiceIntegration:
msg.properties.return_value = {"id": f"concurrent-test-{i}"}
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response
@ -532,6 +539,7 @@ class TestNLPQueryServiceIntegration:
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
flow_response = AsyncMock()
flow.return_value = flow_response


@ -185,6 +185,7 @@ class TestObjectExtractionServiceIntegration:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
return context
@pytest.mark.asyncio
@ -197,20 +198,21 @@ class TestObjectExtractionServiceIntegration:
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Act
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Assert
assert len(processor.schemas) == 2
assert "customer_records" in processor.schemas
assert "product_catalog" in processor.schemas
ws_schemas = processor.schemas["default"]
assert len(ws_schemas) == 2
assert "customer_records" in ws_schemas
assert "product_catalog" in ws_schemas
# Verify customer schema
customer_schema = processor.schemas["customer_records"]
customer_schema = ws_schemas["customer_records"]
assert customer_schema.name == "customer_records"
assert len(customer_schema.fields) == 4
# Verify product schema
product_schema = processor.schemas["product_catalog"]
product_schema = ws_schemas["product_catalog"]
assert product_schema.name == "product_catalog"
assert len(product_schema.fields) == 4
@ -237,12 +239,11 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create realistic customer data chunk
metadata = Metadata(
id="customer-doc-001",
user="integration_test",
collection="test_documents",
)
@ -304,12 +305,11 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create realistic product data chunk
metadata = Metadata(
id="product-doc-001",
user="integration_test",
collection="test_documents",
)
@ -368,7 +368,7 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create multiple test chunks
chunks_data = [
@ -382,7 +382,6 @@ class TestObjectExtractionServiceIntegration:
for chunk_id, text in chunks_data:
metadata = Metadata(
id=chunk_id,
user="concurrent_test",
collection="test_collection",
)
chunk = Chunk(metadata=metadata, chunk=text.encode('utf-8'))
@ -431,19 +430,21 @@ class TestObjectExtractionServiceIntegration:
"customer_records": integration_config["schema"]["customer_records"]
}
}
await processor.on_schema_config(initial_config, version=1)
assert len(processor.schemas) == 1
assert "customer_records" in processor.schemas
assert "product_catalog" not in processor.schemas
await processor.on_schema_config("default", initial_config, version=1)
ws_schemas = processor.schemas["default"]
assert len(ws_schemas) == 1
assert "customer_records" in ws_schemas
assert "product_catalog" not in ws_schemas
# Act - Reload with full configuration
await processor.on_schema_config(integration_config, version=2)
await processor.on_schema_config("default", integration_config, version=2)
# Assert
assert len(processor.schemas) == 2
assert "customer_records" in processor.schemas
assert "product_catalog" in processor.schemas
ws_schemas = processor.schemas["default"]
assert len(ws_schemas) == 2
assert "customer_records" in ws_schemas
assert "product_catalog" in ws_schemas
@pytest.mark.asyncio
async def test_error_resilience_integration(self, integration_config):
@ -474,13 +475,14 @@ class TestObjectExtractionServiceIntegration:
return AsyncMock()
failing_flow.side_effect = failing_context_router
failing_flow.workspace = "default"
processor.flow = failing_flow
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create test chunk
metadata = Metadata(id="error-test", user="test", collection="test")
metadata = Metadata(id="error-test", collection="test")
chunk = Chunk(metadata=metadata, chunk=b"Some text that will fail to process")
mock_msg = MagicMock()
@ -510,12 +512,11 @@ class TestObjectExtractionServiceIntegration:
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
await processor.on_schema_config(integration_config, version=1)
await processor.on_schema_config("default", integration_config, version=1)
# Create chunk with rich metadata
original_metadata = Metadata(
id="metadata-test-chunk",
user="test_user",
collection="test_collection",
)
@ -544,6 +545,5 @@ class TestObjectExtractionServiceIntegration:
assert extracted_obj is not None
# Verify metadata propagation
assert extracted_obj.metadata.user == "test_user"
assert extracted_obj.metadata.collection == "test_collection"
assert "metadata-test-chunk" in extracted_obj.metadata.id # Should include source reference


@ -87,6 +87,7 @@ class TestPromptStreaming:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
return context
@pytest.fixture
@ -109,7 +110,7 @@ class TestPromptStreaming:
def prompt_processor_streaming(self, mock_prompt_manager):
"""Create Prompt processor with streaming support"""
processor = MagicMock()
processor.manager = mock_prompt_manager
processor.managers = {"default": mock_prompt_manager}
processor.config_key = "prompt"
# Bind the actual on_request method
@ -248,6 +249,7 @@ class TestPromptStreaming:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
request = PromptRequest(
id="test_prompt",
@ -341,6 +343,7 @@ class TestPromptStreaming:
return AsyncMock()
context.side_effect = context_router
context.workspace = "default"
request = PromptRequest(
id="test_prompt",


@ -84,7 +84,6 @@ class TestGraphRagStreamingProtocol:
# Act
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=callback
@ -108,7 +107,6 @@ class TestGraphRagStreamingProtocol:
# Act
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
@ -137,7 +135,6 @@ class TestGraphRagStreamingProtocol:
# Act
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
@ -162,7 +159,6 @@ class TestGraphRagStreamingProtocol:
# Act
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
@ -188,7 +184,6 @@ class TestGraphRagStreamingProtocol:
# Act
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
@ -267,7 +262,6 @@ class TestDocumentRagStreamingProtocol:
# Act
await document_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=callback
@ -290,7 +284,6 @@ class TestDocumentRagStreamingProtocol:
# Act
await document_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
@ -314,7 +307,6 @@ class TestDocumentRagStreamingProtocol:
# Act
await document_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect

View file

@ -14,6 +14,17 @@ from trustgraph.storage.rows.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class _MockFlowDefault:
"""Mock Flow with default workspace for testing."""
workspace = "default"
name = "default"
id = "test-processor"
mock_flow_default = _MockFlowDefault()
@pytest.mark.integration
class TestRowsCassandraIntegration:
"""Integration tests for Cassandra row storage with unified table"""
@ -125,14 +136,13 @@ class TestRowsCassandraIntegration:
}
}
await processor.on_schema_config(config, version=1)
assert "customer_records" in processor.schemas
await processor.on_schema_config("default", config, version=1)
assert "customer_records" in processor.schemas["default"]
# Step 2: Process an ExtractedObject
test_obj = ExtractedObject(
metadata=Metadata(
id="doc-001",
user="test_user",
collection="import_2024",
),
schema_name="customer_records",
@ -149,7 +159,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify Cassandra interactions
assert mock_cluster.connect.called
@ -158,7 +168,7 @@ class TestRowsCassandraIntegration:
keyspace_calls = [call for call in mock_session.execute.call_args_list
if "CREATE KEYSPACE" in str(call)]
assert len(keyspace_calls) == 1
assert "test_user" in str(keyspace_calls[0])
assert "default" in str(keyspace_calls[0])
# Verify unified table creation (rows table, not per-schema table)
table_calls = [call for call in mock_session.execute.call_args_list
@ -209,12 +219,12 @@ class TestRowsCassandraIntegration:
}
}
await processor.on_schema_config(config, version=1)
assert len(processor.schemas) == 2
await processor.on_schema_config("default", config, version=1)
assert len(processor.schemas["default"]) == 2
# Process objects for different schemas
product_obj = ExtractedObject(
metadata=Metadata(id="p1", user="shop", collection="catalog"),
metadata=Metadata(id="p1", collection="catalog"),
schema_name="products",
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
confidence=0.9,
@ -222,7 +232,7 @@ class TestRowsCassandraIntegration:
)
order_obj = ExtractedObject(
metadata=Metadata(id="o1", user="shop", collection="sales"),
metadata=Metadata(id="o1", collection="sales"),
schema_name="orders",
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
confidence=0.85,
@ -233,7 +243,7 @@ class TestRowsCassandraIntegration:
for obj in [product_obj, order_obj]:
msg = MagicMock()
msg.value.return_value = obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# All data goes into the same unified rows table
table_calls = [call for call in mock_session.execute.call_args_list
@ -256,18 +266,20 @@ class TestRowsCassandraIntegration:
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
# Schema with multiple indexed fields
processor.schemas["indexed_data"] = RowSchema(
name="indexed_data",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True),
Field(name="status", type="string", size=50, indexed=True),
Field(name="description", type="string", size=200) # Not indexed
]
)
processor.schemas["default"] = {
"indexed_data": RowSchema(
name="indexed_data",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True),
Field(name="status", type="string", size=50, indexed=True),
Field(name="description", type="string", size=200) # Not indexed
]
)
}
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test"),
metadata=Metadata(id="t1", collection="test"),
schema_name="indexed_data",
values=[{
"id": "123",
@ -282,7 +294,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Should have 3 data inserts (one per indexed field: id, category, status)
rows_insert_calls = [call for call in mock_session.execute.call_args_list
@ -342,13 +354,12 @@ class TestRowsCassandraIntegration:
}
}
await processor.on_schema_config(config, version=1)
await processor.on_schema_config("default", config, version=1)
# Process batch object with multiple values
batch_obj = ExtractedObject(
metadata=Metadata(
id="batch-001",
user="test_user",
collection="batch_import",
),
schema_name="batch_customers",
@ -376,7 +387,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = batch_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify unified table creation
table_calls = [call for call in mock_session.execute.call_args_list
@ -396,14 +407,16 @@ class TestRowsCassandraIntegration:
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["empty_test"] = RowSchema(
name="empty_test",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
processor.schemas["default"] = {
"empty_test": RowSchema(
name="empty_test",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
}
# Process empty batch object
empty_obj = ExtractedObject(
metadata=Metadata(id="empty-1", user="test", collection="empty"),
metadata=Metadata(id="empty-1", collection="empty"),
schema_name="empty_test",
values=[], # Empty batch
confidence=1.0,
@ -413,7 +426,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = empty_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Should not create any data insert statements for empty batch
# (partition registration may still happen)
@ -428,17 +441,19 @@ class TestRowsCassandraIntegration:
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["map_test"] = RowSchema(
name="map_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="name", type="string", size=100),
Field(name="count", type="integer", size=0)
]
)
processor.schemas["default"] = {
"map_test": RowSchema(
name="map_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="name", type="string", size=100),
Field(name="count", type="integer", size=0)
]
)
}
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test"),
metadata=Metadata(id="t1", collection="test"),
schema_name="map_test",
values=[{"id": "123", "name": "Test Item", "count": "42"}],
confidence=0.9,
@ -448,7 +463,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify insert uses map for data
rows_insert_calls = [call for call in mock_session.execute.call_args_list
@ -473,16 +488,18 @@ class TestRowsCassandraIntegration:
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["partition_test"] = RowSchema(
name="partition_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True)
]
)
processor.schemas["default"] = {
"partition_test": RowSchema(
name="partition_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True)
]
)
}
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="my_collection"),
metadata=Metadata(id="t1", collection="my_collection"),
schema_name="partition_test",
values=[{"id": "123", "category": "test"}],
confidence=0.9,
@ -492,7 +509,7 @@ class TestRowsCassandraIntegration:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify partition registration
partition_inserts = [call for call in mock_session.execute.call_args_list


@ -154,7 +154,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_schema_configuration_and_generation(self, processor, sample_schema_config):
"""Test schema configuration loading and GraphQL schema generation"""
# Load schema configuration
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
# Verify schemas were loaded
assert len(processor.schemas) == 2
@ -181,7 +181,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_cassandra_connection_and_table_creation(self, processor, sample_schema_config):
"""Test Cassandra connection and dynamic table creation"""
# Load schema configuration
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
# Connect to Cassandra
processor.connect_cassandra()
@ -218,7 +218,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_data_insertion_and_graphql_query(self, processor, sample_schema_config):
"""Test inserting data and querying via GraphQL"""
# Load schema and connect
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
# Setup test data
@ -292,7 +292,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_graphql_query_with_filters(self, processor, sample_schema_config):
"""Test GraphQL queries with filtering on indexed fields"""
# Setup (reuse previous setup)
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
keyspace = "test_user"
@ -353,7 +353,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_graphql_error_handling(self, processor, sample_schema_config):
"""Test GraphQL error handling for invalid queries"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
# Test invalid field query
invalid_query = '''
@ -386,7 +386,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_message_processing_integration(self, processor, sample_schema_config):
"""Test full message processing workflow"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
# Create mock message
@ -432,7 +432,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_concurrent_queries(self, processor, sample_schema_config):
"""Test handling multiple concurrent GraphQL queries"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
# Create multiple query tasks
@ -476,7 +476,7 @@ class TestObjectsGraphQLQueryIntegration:
}
}
await processor.on_schema_config(initial_config, version=1)
await processor.on_schema_config("default", initial_config, version=1)
assert len(processor.schemas) == 1
assert "simple" in processor.schemas
@ -500,7 +500,7 @@ class TestObjectsGraphQLQueryIntegration:
}
}
await processor.on_schema_config(updated_config, version=2)
await processor.on_schema_config("default", updated_config, version=2)
# Verify updated schemas
assert len(processor.schemas) == 2
@ -518,7 +518,7 @@ class TestObjectsGraphQLQueryIntegration:
async def test_large_result_set_handling(self, processor, sample_schema_config):
"""Test handling of large query result sets"""
# Setup
await processor.on_schema_config(sample_schema_config, version=1)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
keyspace = "large_test_user"
@ -601,7 +601,7 @@ class TestObjectsGraphQLQueryPerformance:
}
}
await processor.on_schema_config(schema_config, version=1)
await processor.on_schema_config("default", schema_config, version=1)
# Measure query execution time
start_time = time.time()


@ -42,7 +42,6 @@ class TestStructuredQueryServiceIntegration:
# Arrange - Create realistic query request
request = StructuredQueryRequest(
question="Show me all customers from California who have made purchases over $500",
user="trustgraph",
collection="default"
)
@ -126,7 +125,6 @@ class TestStructuredQueryServiceIntegration:
assert "orders" in objects_call_args.query
assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string
assert objects_call_args.variables["state"] == "California"
assert objects_call_args.user == "trustgraph"
assert objects_call_args.collection == "default"
# Verify response