mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-17 03:15:14 +02:00
feat: workspace-based multi-tenancy, replacing user as tenancy axis (#840)
Introduces `workspace` as the isolation boundary for config, flows,
library, and knowledge data. Removes `user` as a schema-level field
throughout the code, API specs, and tests; workspace provides the
same separation more cleanly, enforced at the trusted
flow.workspace layer rather than via client-supplied message fields.
Design
------
- IAM tech spec (docs/tech-specs/iam.md) documents current state,
proposed auth/access model, and migration direction.
- Data ownership model (docs/tech-specs/data-ownership-model.md)
captures the workspace/collection/flow hierarchy.
Schema + messaging
------------------
- Drop `user` field from AgentRequest/Step, GraphRagQuery,
DocumentRagQuery, Triples/Graph/Document/Row EmbeddingsRequest,
Sparql/Rows/Structured QueryRequest, ToolServiceRequest.
- Keep collection/workspace routing via flow.workspace at the
service layer.
- Translators updated to not serialise/deserialise user.
API specs
---------
- OpenAPI schemas and path examples cleaned of user fields.
- Websocket async-api messages updated.
- Removed the unused parameters/User.yaml.
Services + base
---------------
- Librarian, collection manager, knowledge, config: all operations
scoped by workspace. Config client API takes workspace as first
positional arg.
- `flow.workspace` is set at flow start time by the infrastructure;
  it is no longer passed through from clients.
- Tool service drops user-personalisation passthrough.
CLI + SDK
---------
- tg-init-workspace and workspace-aware import/export.
- All tg-* commands drop user args; accept --workspace.
- Python API/SDK (flow, socket_client, async_*, explainability,
library) drop user kwargs from every method signature.
MCP server
----------
- All tool endpoints drop user parameters; socket_manager no longer
keyed per user.
Flow service
------------
- Closure-based topic cleanup on flow stop: only delete topics
whose blueprint template was parameterised AND no remaining
live flow (across all workspaces) still resolves to that topic.
Four scopes fall out naturally from template analysis:
* {id} -> per-flow, deleted on stop
* {blueprint} -> per-blueprint, kept while any flow of the
same blueprint exists
* {workspace} -> per-workspace, kept while any flow in the
workspace exists
* literal -> global, never deleted (e.g. tg.request.librarian)
Fixes a bug where stopping a flow silently destroyed the global
librarian exchange, wedging all library operations until manual
restart.
RabbitMQ backend
----------------
- heartbeat=60, blocked_connection_timeout=300. Catches silently
dead connections (broker restart, orphaned channels, network
partitions) within ~2 heartbeat windows, so the consumer
reconnects and re-binds its queue rather than sitting forever
on a zombie connection.
Tests
-----
- Full test refresh: unit, integration, contract, provenance.
- Dropped user-field assertions and constructor kwargs across
~100 test files.
- Renamed user-collection isolation tests to workspace-collection.
This commit is contained in:
parent
9332089b3d
commit
d35473f7f7
377 changed files with 6868 additions and 5785 deletions
|
|
@ -58,7 +58,7 @@ class TestAgentStructuredQueryIntegration:
|
|||
async def test_agent_structured_query_basic_integration(self, agent_processor, structured_query_tool_config):
|
||||
"""Test basic agent integration with structured query tool"""
|
||||
# Arrange - Load tool configuration
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
# Create agent request
|
||||
request = AgentRequest(
|
||||
|
|
@ -66,7 +66,6 @@ class TestAgentStructuredQueryIntegration:
|
|||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -119,6 +118,7 @@ Args: {
|
|||
# Mock flow parameter in agent_processor.on_request
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -146,14 +146,13 @@ Args: {
|
|||
async def test_agent_structured_query_error_handling(self, agent_processor, structured_query_tool_config):
|
||||
"""Test agent handling of structured query errors"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="Find data from a table that doesn't exist using structured query.",
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -199,6 +198,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -221,14 +221,13 @@ Args: {
|
|||
async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
|
||||
"""Test agent using structured query in multi-step reasoning"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="First find all customers from California, then tell me how many orders they have made.",
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -279,6 +278,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -313,14 +313,13 @@ Args: {
|
|||
}
|
||||
}
|
||||
|
||||
await agent_processor.on_tools_config(tool_config_with_collection, "v1")
|
||||
await agent_processor.on_tools_config("default", tool_config_with_collection, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="Query the sales data for recent transactions.",
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -371,6 +370,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
@ -394,10 +394,10 @@ Args: {
|
|||
async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
|
||||
"""Test that structured query tool arguments are properly validated"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
# Check that the tool was registered with correct arguments
|
||||
tools = agent_processor.agent.tools
|
||||
tools = agent_processor.agents["default"].tools
|
||||
assert "structured-query" in tools
|
||||
|
||||
structured_tool = tools["structured-query"]
|
||||
|
|
@ -414,14 +414,13 @@ Args: {
|
|||
async def test_agent_structured_query_json_formatting(self, agent_processor, structured_query_tool_config):
|
||||
"""Test that structured query results are properly formatted for agent consumption"""
|
||||
# Arrange
|
||||
await agent_processor.on_tools_config(structured_query_tool_config, "v1")
|
||||
await agent_processor.on_tools_config("default", structured_query_tool_config, "v1")
|
||||
|
||||
request = AgentRequest(
|
||||
question="Get customer information and format it nicely.",
|
||||
state="",
|
||||
group=None,
|
||||
history=[],
|
||||
user="test_user"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
|
|
@ -482,6 +481,7 @@ Args: {
|
|||
|
||||
flow = MagicMock()
|
||||
flow.side_effect = flow_context
|
||||
flow.workspace = "default"
|
||||
|
||||
# Act
|
||||
await agent_processor.on_request(msg, consumer, flow)
|
||||
|
|
|
|||
|
|
@ -40,14 +40,13 @@ class TestEndToEndConfigurationFlow:
|
|||
|
||||
# Create a mock message to trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
# This should create TrustGraph with environment config
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('test_user', mock_message)
|
||||
|
||||
# Verify Cluster was created with correct hosts
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -144,13 +143,12 @@ class TestConfigurationPriorityEndToEnd:
|
|||
|
||||
# Trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('test_user', mock_message)
|
||||
|
||||
# Should use CLI parameters, not environment
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -201,7 +199,6 @@ class TestConfigurationPriorityEndToEnd:
|
|||
|
||||
# Mock query to trigger TrustGraph creation
|
||||
mock_query = MagicMock()
|
||||
mock_query.user = 'default_user'
|
||||
mock_query.collection = 'default_collection'
|
||||
mock_query.s = None
|
||||
mock_query.p = None
|
||||
|
|
@ -213,7 +210,7 @@ class TestConfigurationPriorityEndToEnd:
|
|||
mock_tg_instance.get_all.return_value = []
|
||||
processor.tg = mock_tg_instance
|
||||
|
||||
await processor.query_triples(mock_query)
|
||||
await processor.query_triples('default_user', mock_query)
|
||||
|
||||
# Should use defaults
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -244,13 +241,12 @@ class TestNoBackwardCompatibilityEndToEnd:
|
|||
|
||||
# Trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'legacy_user'
|
||||
mock_message.metadata.collection = 'legacy_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('legacy_user', mock_message)
|
||||
|
||||
# Should use defaults since old parameters are not recognized
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -302,13 +298,12 @@ class TestNoBackwardCompatibilityEndToEnd:
|
|||
|
||||
# Trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'precedence_user'
|
||||
mock_message.metadata.collection = 'precedence_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('precedence_user', mock_message)
|
||||
|
||||
# Should use new parameters, not old ones
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
@ -354,13 +349,12 @@ class TestMultipleHostsHandling:
|
|||
|
||||
# Trigger TrustGraph creation
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'single_user'
|
||||
mock_message.metadata.collection = 'single_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
# Mock collection_exists to return True
|
||||
with patch('trustgraph.direct.cassandra_kg.KnowledgeGraph.collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('single_user', mock_message)
|
||||
|
||||
# Single host should be converted to list
|
||||
mock_cluster.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ class TestCassandraIntegration:
|
|||
|
||||
# Create test message
|
||||
storage_message = Triples(
|
||||
metadata=Metadata(user="testuser", collection="testcol"),
|
||||
metadata=Metadata(collection="testcol"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/person1"),
|
||||
|
|
@ -178,7 +178,7 @@ class TestCassandraIntegration:
|
|||
|
||||
# Store test data for querying
|
||||
query_test_message = Triples(
|
||||
metadata=Metadata(user="testuser", collection="testcol"),
|
||||
metadata=Metadata(collection="testcol"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/alice"),
|
||||
|
|
@ -212,7 +212,6 @@ class TestCassandraIntegration:
|
|||
p=None, # None for wildcard
|
||||
o=None, # None for wildcard
|
||||
limit=10,
|
||||
user="testuser",
|
||||
collection="testcol"
|
||||
)
|
||||
s_results = await query_processor.query_triples(s_query)
|
||||
|
|
@ -232,7 +231,6 @@ class TestCassandraIntegration:
|
|||
p=Term(type=IRI, iri="http://example.org/knows"),
|
||||
o=None, # None for wildcard
|
||||
limit=10,
|
||||
user="testuser",
|
||||
collection="testcol"
|
||||
)
|
||||
p_results = await query_processor.query_triples(p_query)
|
||||
|
|
@ -259,7 +257,7 @@ class TestCassandraIntegration:
|
|||
# Create multiple coroutines for concurrent storage
|
||||
async def store_person_data(person_id, name, age, department):
|
||||
message = Triples(
|
||||
metadata=Metadata(user="concurrent_test", collection="people"),
|
||||
metadata=Metadata(collection="people"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri=f"http://example.org/{person_id}"),
|
||||
|
|
@ -329,7 +327,7 @@ class TestCassandraIntegration:
|
|||
|
||||
# Create a knowledge graph about a company
|
||||
company_graph = Triples(
|
||||
metadata=Metadata(user="integration_test", collection="company"),
|
||||
metadata=Metadata(collection="company"),
|
||||
triples=[
|
||||
# People and their types
|
||||
Triple(
|
||||
|
|
|
|||
|
|
@ -99,7 +99,6 @@ class TestDocumentRagIntegration:
|
|||
# Act
|
||||
result = await document_rag.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit
|
||||
)
|
||||
|
|
@ -110,7 +109,6 @@ class TestDocumentRagIntegration:
|
|||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
|
||||
limit=doc_limit,
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
|
|
@ -278,14 +276,12 @@ class TestDocumentRagIntegration:
|
|||
# Act
|
||||
await document_rag.query(
|
||||
f"query from {user} in {collection}",
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['user'] == user
|
||||
assert call_args.kwargs['collection'] == collection
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -353,6 +349,5 @@ class TestDocumentRagIntegration:
|
|||
# Assert
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['user'] == "trustgraph"
|
||||
assert call_args.kwargs['collection'] == "default"
|
||||
assert call_args.kwargs['limit'] == 20
|
||||
|
|
|
|||
|
|
@ -107,7 +107,6 @@ class TestDocumentRagStreaming:
|
|||
# Act
|
||||
result = await document_rag_streaming.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=10,
|
||||
streaming=True,
|
||||
|
|
@ -141,7 +140,6 @@ class TestDocumentRagStreaming:
|
|||
# Act - Non-streaming
|
||||
non_streaming_result = await document_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
streaming=False
|
||||
|
|
@ -155,7 +153,6 @@ class TestDocumentRagStreaming:
|
|||
|
||||
streaming_result = await document_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
streaming=True,
|
||||
|
|
@ -178,7 +175,6 @@ class TestDocumentRagStreaming:
|
|||
# Act
|
||||
result = await document_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=5,
|
||||
streaming=True,
|
||||
|
|
@ -200,7 +196,6 @@ class TestDocumentRagStreaming:
|
|||
# Arrange & Act
|
||||
result = await document_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=5,
|
||||
streaming=True,
|
||||
|
|
@ -223,7 +218,6 @@ class TestDocumentRagStreaming:
|
|||
# Act
|
||||
result = await document_rag_streaming.query(
|
||||
query="unknown topic",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=10,
|
||||
streaming=True,
|
||||
|
|
@ -247,7 +241,6 @@ class TestDocumentRagStreaming:
|
|||
with pytest.raises(Exception) as exc_info:
|
||||
await document_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=5,
|
||||
streaming=True,
|
||||
|
|
@ -272,7 +265,6 @@ class TestDocumentRagStreaming:
|
|||
# Act
|
||||
result = await document_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=limit,
|
||||
streaming=True,
|
||||
|
|
@ -300,7 +292,6 @@ class TestDocumentRagStreaming:
|
|||
# Act
|
||||
await document_rag_streaming.query(
|
||||
query="test query",
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=10,
|
||||
streaming=True,
|
||||
|
|
@ -309,5 +300,4 @@ class TestDocumentRagStreaming:
|
|||
|
||||
# Assert - Verify user/collection were passed to document embeddings client
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['user'] == user
|
||||
assert call_args.kwargs['collection'] == collection
|
||||
|
|
|
|||
|
|
@ -146,7 +146,6 @@ class TestGraphRagIntegration:
|
|||
# Act
|
||||
response = await graph_rag.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit,
|
||||
|
|
@ -163,7 +162,6 @@ class TestGraphRagIntegration:
|
|||
call_args = mock_graph_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
assert call_args.kwargs['limit'] == entity_limit
|
||||
assert call_args.kwargs['user'] == user
|
||||
assert call_args.kwargs['collection'] == collection
|
||||
|
||||
# 3. Should query triples to build knowledge subgraph
|
||||
|
|
@ -204,7 +202,6 @@ class TestGraphRagIntegration:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
entity_limit=config["entity_limit"],
|
||||
triple_limit=config["triple_limit"]
|
||||
|
|
@ -224,7 +221,6 @@ class TestGraphRagIntegration:
|
|||
with pytest.raises(Exception) as exc_info:
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -247,7 +243,6 @@ class TestGraphRagIntegration:
|
|||
# Act
|
||||
response = await graph_rag.query(
|
||||
query="unknown topic",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
explain_callback=collect_provenance
|
||||
)
|
||||
|
|
@ -267,7 +262,6 @@ class TestGraphRagIntegration:
|
|||
# First query
|
||||
await graph_rag.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -277,7 +271,6 @@ class TestGraphRagIntegration:
|
|||
# Second identical query
|
||||
await graph_rag.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -289,26 +282,27 @@ class TestGraphRagIntegration:
|
|||
assert second_call_count >= 0 # Should complete without errors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_multi_user_isolation(self, graph_rag, mock_graph_embeddings_client):
|
||||
"""Test that different users/collections are properly isolated"""
|
||||
async def test_graph_rag_multi_collection_isolation(self, graph_rag, mock_graph_embeddings_client):
|
||||
"""Test that different collections propagate through to the embeddings query.
|
||||
|
||||
Workspace isolation is enforced by flow.workspace at the service
|
||||
boundary — not by parameters on GraphRag.query — so this test
|
||||
verifies collection routing only.
|
||||
"""
|
||||
# Arrange
|
||||
query = "test query"
|
||||
user1, collection1 = "user1", "collection1"
|
||||
user2, collection2 = "user2", "collection2"
|
||||
collection1 = "collection1"
|
||||
collection2 = "collection2"
|
||||
|
||||
# Act
|
||||
await graph_rag.query(query=query, user=user1, collection=collection1)
|
||||
await graph_rag.query(query=query, user=user2, collection=collection2)
|
||||
await graph_rag.query(query=query, collection=collection1)
|
||||
await graph_rag.query(query=query, collection=collection2)
|
||||
|
||||
# Assert - Both users should have separate queries
|
||||
# Assert - Each call propagated its collection
|
||||
assert mock_graph_embeddings_client.query.call_count == 2
|
||||
|
||||
# Verify first call
|
||||
first_call = mock_graph_embeddings_client.query.call_args_list[0]
|
||||
assert first_call.kwargs['user'] == user1
|
||||
assert first_call.kwargs['collection'] == collection1
|
||||
|
||||
# Verify second call
|
||||
second_call = mock_graph_embeddings_client.query.call_args_list[1]
|
||||
assert second_call.kwargs['user'] == user2
|
||||
assert second_call.kwargs['collection'] == collection2
|
||||
|
|
|
|||
|
|
@ -116,7 +116,6 @@ class TestGraphRagStreaming:
|
|||
# Act - query() returns response, provenance via callback
|
||||
response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collector.collect,
|
||||
|
|
@ -154,7 +153,6 @@ class TestGraphRagStreaming:
|
|||
# Act - Non-streaming
|
||||
non_streaming_response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
streaming=False
|
||||
)
|
||||
|
|
@ -167,7 +165,6 @@ class TestGraphRagStreaming:
|
|||
|
||||
streaming_response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -189,7 +186,6 @@ class TestGraphRagStreaming:
|
|||
# Act
|
||||
response = await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
|
|
@ -209,7 +205,6 @@ class TestGraphRagStreaming:
|
|||
# Arrange & Act
|
||||
response = await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=None # No callback provided
|
||||
|
|
@ -231,7 +226,6 @@ class TestGraphRagStreaming:
|
|||
# Act
|
||||
response = await graph_rag_streaming.query(
|
||||
query="unknown topic",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
|
|
@ -253,7 +247,6 @@ class TestGraphRagStreaming:
|
|||
with pytest.raises(Exception) as exc_info:
|
||||
await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
|
|
@ -273,7 +266,6 @@ class TestGraphRagStreaming:
|
|||
# Act
|
||||
await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit,
|
||||
|
|
|
|||
|
|
@ -171,7 +171,6 @@ async def test_export_no_message_loss_integration(mock_backend):
|
|||
triples_obj = Triples(
|
||||
metadata=Metadata(
|
||||
id=f"export-msg-{i}",
|
||||
user=msg_data["metadata"]["user"],
|
||||
collection=msg_data["metadata"]["collection"],
|
||||
),
|
||||
triples=to_subgraph(msg_data["triples"]),
|
||||
|
|
|
|||
|
|
@ -97,7 +97,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
return Chunk(
|
||||
metadata=Metadata(
|
||||
id="doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
),
|
||||
chunk=b"Machine Learning is a subset of Artificial Intelligence. Neural Networks are used in Machine Learning to process complex patterns."
|
||||
|
|
@ -247,7 +246,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -305,7 +303,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -375,7 +372,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
sample_triples = Triples(
|
||||
metadata=Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
),
|
||||
triples=[
|
||||
|
|
@ -390,11 +386,14 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_triples
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.workspace = "test_workspace"
|
||||
|
||||
# Act
|
||||
await processor.on_triples(mock_msg, None, None)
|
||||
await processor.on_triples(mock_msg, None, mock_flow)
|
||||
|
||||
# Assert
|
||||
mock_cassandra_store.add_triples.assert_called_once_with(sample_triples)
|
||||
mock_cassandra_store.add_triples.assert_called_once_with("test_workspace", sample_triples)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_store_graph_embeddings_storage(self, mock_cassandra_store):
|
||||
|
|
@ -407,7 +406,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
sample_embeddings = GraphEmbeddings(
|
||||
metadata=Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
),
|
||||
entities=[
|
||||
|
|
@ -421,11 +419,14 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_embeddings
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.workspace = "test_workspace"
|
||||
|
||||
# Act
|
||||
await processor.on_graph_embeddings(mock_msg, None, None)
|
||||
await processor.on_graph_embeddings(mock_msg, None, mock_flow)
|
||||
|
||||
# Assert
|
||||
mock_cassandra_store.add_graph_embeddings.assert_called_once_with(sample_embeddings)
|
||||
mock_cassandra_store.add_graph_embeddings.assert_called_once_with("test_workspace", sample_embeddings)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_pipeline_coordination(self, definitions_processor, relationships_processor,
|
||||
|
|
@ -553,7 +554,7 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
)
|
||||
|
||||
sample_chunk = Chunk(
|
||||
metadata=Metadata(id="test", user="user", collection="collection"),
|
||||
metadata=Metadata(id="test", collection="collection"),
|
||||
chunk=b"Test chunk"
|
||||
)
|
||||
|
||||
|
|
@ -580,7 +581,7 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
# Arrange
|
||||
large_chunk_batch = [
|
||||
Chunk(
|
||||
metadata=Metadata(id=f"doc-{i}", user="user", collection="collection"),
|
||||
metadata=Metadata(id=f"doc-{i}", collection="collection"),
|
||||
chunk=f"Document {i} contains machine learning and AI content.".encode("utf-8")
|
||||
)
|
||||
for i in range(100) # Large batch
|
||||
|
|
@ -617,7 +618,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
# Arrange
|
||||
original_metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -646,9 +646,7 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
entity_contexts_call = entity_contexts_producer.send.call_args[0][0]
|
||||
|
||||
assert triples_call.metadata.id == "test-doc-123"
|
||||
assert triples_call.metadata.user == "test_user"
|
||||
assert triples_call.metadata.collection == "test_collection"
|
||||
|
||||
assert entity_contexts_call.metadata.id == "test-doc-123"
|
||||
assert entity_contexts_call.metadata.user == "test_user"
|
||||
assert entity_contexts_call.metadata.collection == "test_collection"
|
||||
|
|
@ -72,7 +72,7 @@ class TestNLPQueryServiceIntegration:
|
|||
)
|
||||
|
||||
# Set up schemas
|
||||
proc.schemas = sample_schemas
|
||||
proc.schemas = {"default": dict(sample_schemas)}
|
||||
|
||||
# Mock the client method
|
||||
proc.client = MagicMock()
|
||||
|
|
@ -94,6 +94,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -173,6 +174,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -229,7 +231,7 @@ class TestNLPQueryServiceIntegration:
|
|||
}
|
||||
|
||||
# Act - Update configuration
|
||||
await integration_processor.on_schema_config(new_schema_config, "v2")
|
||||
await integration_processor.on_schema_config("default", new_schema_config, "v2")
|
||||
|
||||
# Arrange - Test query using new schema
|
||||
request = QuestionToStructuredQueryRequest(
|
||||
|
|
@ -243,6 +245,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -272,7 +275,7 @@ class TestNLPQueryServiceIntegration:
|
|||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
||||
# Assert
|
||||
assert "inventory" in integration_processor.schemas
|
||||
assert "inventory" in integration_processor.schemas["default"]
|
||||
response_call = flow_response.send.call_args
|
||||
response = response_call[0][0]
|
||||
assert response.detected_schemas == ["inventory"]
|
||||
|
|
@ -293,6 +296,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -334,7 +338,7 @@ class TestNLPQueryServiceIntegration:
|
|||
graphql_generation_template="custom-graphql-generator"
|
||||
)
|
||||
|
||||
custom_processor.schemas = sample_schemas
|
||||
custom_processor.schemas = {"default": dict(sample_schemas)}
|
||||
custom_processor.client = MagicMock()
|
||||
|
||||
request = QuestionToStructuredQueryRequest(
|
||||
|
|
@ -348,6 +352,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -394,7 +399,7 @@ class TestNLPQueryServiceIntegration:
|
|||
] + [SchemaField(name=f"field_{j}", type="string") for j in range(5)]
|
||||
)
|
||||
|
||||
integration_processor.schemas.update(large_schema_set)
|
||||
integration_processor.schemas["default"].update(large_schema_set)
|
||||
|
||||
request = QuestionToStructuredQueryRequest(
|
||||
question="Show me data from table_05 and table_12",
|
||||
|
|
@ -407,6 +412,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -462,6 +468,7 @@ class TestNLPQueryServiceIntegration:
|
|||
msg.properties.return_value = {"id": f"concurrent-test-{i}"}
|
||||
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
@ -532,6 +539,7 @@ class TestNLPQueryServiceIntegration:
|
|||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
|
|
|
|||
|
|
@ -185,6 +185,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
return context
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -197,20 +198,21 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Act
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Assert
|
||||
assert len(processor.schemas) == 2
|
||||
assert "customer_records" in processor.schemas
|
||||
assert "product_catalog" in processor.schemas
|
||||
|
||||
ws_schemas = processor.schemas["default"]
|
||||
assert len(ws_schemas) == 2
|
||||
assert "customer_records" in ws_schemas
|
||||
assert "product_catalog" in ws_schemas
|
||||
|
||||
# Verify customer schema
|
||||
customer_schema = processor.schemas["customer_records"]
|
||||
customer_schema = ws_schemas["customer_records"]
|
||||
assert customer_schema.name == "customer_records"
|
||||
assert len(customer_schema.fields) == 4
|
||||
|
||||
|
||||
# Verify product schema
|
||||
product_schema = processor.schemas["product_catalog"]
|
||||
product_schema = ws_schemas["product_catalog"]
|
||||
assert product_schema.name == "product_catalog"
|
||||
assert len(product_schema.fields) == 4
|
||||
|
||||
|
|
@ -237,12 +239,11 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create realistic customer data chunk
|
||||
metadata = Metadata(
|
||||
id="customer-doc-001",
|
||||
user="integration_test",
|
||||
collection="test_documents",
|
||||
)
|
||||
|
||||
|
|
@ -304,12 +305,11 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create realistic product data chunk
|
||||
metadata = Metadata(
|
||||
id="product-doc-001",
|
||||
user="integration_test",
|
||||
collection="test_documents",
|
||||
)
|
||||
|
||||
|
|
@ -368,7 +368,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create multiple test chunks
|
||||
chunks_data = [
|
||||
|
|
@ -382,7 +382,6 @@ class TestObjectExtractionServiceIntegration:
|
|||
for chunk_id, text in chunks_data:
|
||||
metadata = Metadata(
|
||||
id=chunk_id,
|
||||
user="concurrent_test",
|
||||
collection="test_collection",
|
||||
)
|
||||
chunk = Chunk(metadata=metadata, chunk=text.encode('utf-8'))
|
||||
|
|
@ -431,19 +430,21 @@ class TestObjectExtractionServiceIntegration:
|
|||
"customer_records": integration_config["schema"]["customer_records"]
|
||||
}
|
||||
}
|
||||
await processor.on_schema_config(initial_config, version=1)
|
||||
|
||||
assert len(processor.schemas) == 1
|
||||
assert "customer_records" in processor.schemas
|
||||
assert "product_catalog" not in processor.schemas
|
||||
|
||||
await processor.on_schema_config("default", initial_config, version=1)
|
||||
|
||||
ws_schemas = processor.schemas["default"]
|
||||
assert len(ws_schemas) == 1
|
||||
assert "customer_records" in ws_schemas
|
||||
assert "product_catalog" not in ws_schemas
|
||||
|
||||
# Act - Reload with full configuration
|
||||
await processor.on_schema_config(integration_config, version=2)
|
||||
|
||||
await processor.on_schema_config("default", integration_config, version=2)
|
||||
|
||||
# Assert
|
||||
assert len(processor.schemas) == 2
|
||||
assert "customer_records" in processor.schemas
|
||||
assert "product_catalog" in processor.schemas
|
||||
ws_schemas = processor.schemas["default"]
|
||||
assert len(ws_schemas) == 2
|
||||
assert "customer_records" in ws_schemas
|
||||
assert "product_catalog" in ws_schemas
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_resilience_integration(self, integration_config):
|
||||
|
|
@ -474,13 +475,14 @@ class TestObjectExtractionServiceIntegration:
|
|||
return AsyncMock()
|
||||
|
||||
failing_flow.side_effect = failing_context_router
|
||||
failing_flow.workspace = "default"
|
||||
processor.flow = failing_flow
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create test chunk
|
||||
metadata = Metadata(id="error-test", user="test", collection="test")
|
||||
metadata = Metadata(id="error-test", collection="test")
|
||||
chunk = Chunk(metadata=metadata, chunk=b"Some text that will fail to process")
|
||||
|
||||
mock_msg = MagicMock()
|
||||
|
|
@ -510,12 +512,11 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
await processor.on_schema_config(integration_config, version=1)
|
||||
await processor.on_schema_config("default", integration_config, version=1)
|
||||
|
||||
# Create chunk with rich metadata
|
||||
original_metadata = Metadata(
|
||||
id="metadata-test-chunk",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -544,6 +545,5 @@ class TestObjectExtractionServiceIntegration:
|
|||
assert extracted_obj is not None
|
||||
|
||||
# Verify metadata propagation
|
||||
assert extracted_obj.metadata.user == "test_user"
|
||||
assert extracted_obj.metadata.collection == "test_collection"
|
||||
assert "metadata-test-chunk" in extracted_obj.metadata.id # Should include source reference
|
||||
|
|
@ -87,6 +87,7 @@ class TestPromptStreaming:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
return context
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -109,7 +110,7 @@ class TestPromptStreaming:
|
|||
def prompt_processor_streaming(self, mock_prompt_manager):
|
||||
"""Create Prompt processor with streaming support"""
|
||||
processor = MagicMock()
|
||||
processor.manager = mock_prompt_manager
|
||||
processor.managers = {"default": mock_prompt_manager}
|
||||
processor.config_key = "prompt"
|
||||
|
||||
# Bind the actual on_request method
|
||||
|
|
@ -248,6 +249,7 @@ class TestPromptStreaming:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
|
||||
request = PromptRequest(
|
||||
id="test_prompt",
|
||||
|
|
@ -341,6 +343,7 @@ class TestPromptStreaming:
|
|||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
context.workspace = "default"
|
||||
|
||||
request = PromptRequest(
|
||||
id="test_prompt",
|
||||
|
|
|
|||
|
|
@ -84,7 +84,6 @@ class TestGraphRagStreamingProtocol:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
|
|
@ -108,7 +107,6 @@ class TestGraphRagStreamingProtocol:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -137,7 +135,6 @@ class TestGraphRagStreamingProtocol:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -162,7 +159,6 @@ class TestGraphRagStreamingProtocol:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -188,7 +184,6 @@ class TestGraphRagStreamingProtocol:
|
|||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -267,7 +262,6 @@ class TestDocumentRagStreamingProtocol:
|
|||
# Act
|
||||
await document_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
|
|
@ -290,7 +284,6 @@ class TestDocumentRagStreamingProtocol:
|
|||
# Act
|
||||
await document_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
@ -314,7 +307,6 @@ class TestDocumentRagStreamingProtocol:
|
|||
# Act
|
||||
await document_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
|
|
|
|||
|
|
@ -14,6 +14,17 @@ from trustgraph.storage.rows.cassandra.write import Processor
|
|||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
|
||||
|
||||
class _MockFlowDefault:
|
||||
"""Mock Flow with default workspace for testing."""
|
||||
workspace = "default"
|
||||
name = "default"
|
||||
id = "test-processor"
|
||||
|
||||
|
||||
mock_flow_default = _MockFlowDefault()
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestRowsCassandraIntegration:
|
||||
"""Integration tests for Cassandra row storage with unified table"""
|
||||
|
|
@ -125,14 +136,13 @@ class TestRowsCassandraIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
assert "customer_records" in processor.schemas
|
||||
await processor.on_schema_config("default", config, version=1)
|
||||
assert "customer_records" in processor.schemas["default"]
|
||||
|
||||
# Step 2: Process an ExtractedObject
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="doc-001",
|
||||
user="test_user",
|
||||
collection="import_2024",
|
||||
),
|
||||
schema_name="customer_records",
|
||||
|
|
@ -149,7 +159,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify Cassandra interactions
|
||||
assert mock_cluster.connect.called
|
||||
|
|
@ -158,7 +168,7 @@ class TestRowsCassandraIntegration:
|
|||
keyspace_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "CREATE KEYSPACE" in str(call)]
|
||||
assert len(keyspace_calls) == 1
|
||||
assert "test_user" in str(keyspace_calls[0])
|
||||
assert "default" in str(keyspace_calls[0])
|
||||
|
||||
# Verify unified table creation (rows table, not per-schema table)
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -209,12 +219,12 @@ class TestRowsCassandraIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
assert len(processor.schemas) == 2
|
||||
await processor.on_schema_config("default", config, version=1)
|
||||
assert len(processor.schemas["default"]) == 2
|
||||
|
||||
# Process objects for different schemas
|
||||
product_obj = ExtractedObject(
|
||||
metadata=Metadata(id="p1", user="shop", collection="catalog"),
|
||||
metadata=Metadata(id="p1", collection="catalog"),
|
||||
schema_name="products",
|
||||
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
|
||||
confidence=0.9,
|
||||
|
|
@ -222,7 +232,7 @@ class TestRowsCassandraIntegration:
|
|||
)
|
||||
|
||||
order_obj = ExtractedObject(
|
||||
metadata=Metadata(id="o1", user="shop", collection="sales"),
|
||||
metadata=Metadata(id="o1", collection="sales"),
|
||||
schema_name="orders",
|
||||
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
|
||||
confidence=0.85,
|
||||
|
|
@ -233,7 +243,7 @@ class TestRowsCassandraIntegration:
|
|||
for obj in [product_obj, order_obj]:
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = obj
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# All data goes into the same unified rows table
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -256,18 +266,20 @@ class TestRowsCassandraIntegration:
|
|||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
# Schema with multiple indexed fields
|
||||
processor.schemas["indexed_data"] = RowSchema(
|
||||
name="indexed_data",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True),
|
||||
Field(name="status", type="string", size=50, indexed=True),
|
||||
Field(name="description", type="string", size=200) # Not indexed
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"indexed_data": RowSchema(
|
||||
name="indexed_data",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True),
|
||||
Field(name="status", type="string", size=50, indexed=True),
|
||||
Field(name="description", type="string", size=200) # Not indexed
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test"),
|
||||
metadata=Metadata(id="t1", collection="test"),
|
||||
schema_name="indexed_data",
|
||||
values=[{
|
||||
"id": "123",
|
||||
|
|
@ -282,7 +294,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Should have 3 data inserts (one per indexed field: id, category, status)
|
||||
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -342,13 +354,12 @@ class TestRowsCassandraIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
await processor.on_schema_config("default", config, version=1)
|
||||
|
||||
# Process batch object with multiple values
|
||||
batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="batch-001",
|
||||
user="test_user",
|
||||
collection="batch_import",
|
||||
),
|
||||
schema_name="batch_customers",
|
||||
|
|
@ -376,7 +387,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify unified table creation
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -396,14 +407,16 @@ class TestRowsCassandraIntegration:
|
|||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["empty_test"] = RowSchema(
|
||||
name="empty_test",
|
||||
fields=[Field(name="id", type="string", size=50, primary=True)]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"empty_test": RowSchema(
|
||||
name="empty_test",
|
||||
fields=[Field(name="id", type="string", size=50, primary=True)]
|
||||
)
|
||||
}
|
||||
|
||||
# Process empty batch object
|
||||
empty_obj = ExtractedObject(
|
||||
metadata=Metadata(id="empty-1", user="test", collection="empty"),
|
||||
metadata=Metadata(id="empty-1", collection="empty"),
|
||||
schema_name="empty_test",
|
||||
values=[], # Empty batch
|
||||
confidence=1.0,
|
||||
|
|
@ -413,7 +426,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = empty_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Should not create any data insert statements for empty batch
|
||||
# (partition registration may still happen)
|
||||
|
|
@ -428,17 +441,19 @@ class TestRowsCassandraIntegration:
|
|||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["map_test"] = RowSchema(
|
||||
name="map_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="name", type="string", size=100),
|
||||
Field(name="count", type="integer", size=0)
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"map_test": RowSchema(
|
||||
name="map_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="name", type="string", size=100),
|
||||
Field(name="count", type="integer", size=0)
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test"),
|
||||
metadata=Metadata(id="t1", collection="test"),
|
||||
schema_name="map_test",
|
||||
values=[{"id": "123", "name": "Test Item", "count": "42"}],
|
||||
confidence=0.9,
|
||||
|
|
@ -448,7 +463,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify insert uses map for data
|
||||
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
|
|
@ -473,16 +488,18 @@ class TestRowsCassandraIntegration:
|
|||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["partition_test"] = RowSchema(
|
||||
name="partition_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True)
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
"partition_test": RowSchema(
|
||||
name="partition_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True)
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="my_collection"),
|
||||
metadata=Metadata(id="t1", collection="my_collection"),
|
||||
schema_name="partition_test",
|
||||
values=[{"id": "123", "category": "test"}],
|
||||
confidence=0.9,
|
||||
|
|
@ -492,7 +509,7 @@ class TestRowsCassandraIntegration:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify partition registration
|
||||
partition_inserts = [call for call in mock_session.execute.call_args_list
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_schema_configuration_and_generation(self, processor, sample_schema_config):
|
||||
"""Test schema configuration loading and GraphQL schema generation"""
|
||||
# Load schema configuration
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
|
||||
# Verify schemas were loaded
|
||||
assert len(processor.schemas) == 2
|
||||
|
|
@ -181,7 +181,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_cassandra_connection_and_table_creation(self, processor, sample_schema_config):
|
||||
"""Test Cassandra connection and dynamic table creation"""
|
||||
# Load schema configuration
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
|
||||
# Connect to Cassandra
|
||||
processor.connect_cassandra()
|
||||
|
|
@ -218,7 +218,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_data_insertion_and_graphql_query(self, processor, sample_schema_config):
|
||||
"""Test inserting data and querying via GraphQL"""
|
||||
# Load schema and connect
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Setup test data
|
||||
|
|
@ -292,7 +292,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_graphql_query_with_filters(self, processor, sample_schema_config):
|
||||
"""Test GraphQL queries with filtering on indexed fields"""
|
||||
# Setup (reuse previous setup)
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
keyspace = "test_user"
|
||||
|
|
@ -353,7 +353,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_graphql_error_handling(self, processor, sample_schema_config):
|
||||
"""Test GraphQL error handling for invalid queries"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
|
||||
# Test invalid field query
|
||||
invalid_query = '''
|
||||
|
|
@ -386,7 +386,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_message_processing_integration(self, processor, sample_schema_config):
|
||||
"""Test full message processing workflow"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Create mock message
|
||||
|
|
@ -432,7 +432,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_concurrent_queries(self, processor, sample_schema_config):
|
||||
"""Test handling multiple concurrent GraphQL queries"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Create multiple query tasks
|
||||
|
|
@ -476,7 +476,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(initial_config, version=1)
|
||||
await processor.on_schema_config("default", initial_config, version=1)
|
||||
assert len(processor.schemas) == 1
|
||||
assert "simple" in processor.schemas
|
||||
|
||||
|
|
@ -500,7 +500,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(updated_config, version=2)
|
||||
await processor.on_schema_config("default", updated_config, version=2)
|
||||
|
||||
# Verify updated schemas
|
||||
assert len(processor.schemas) == 2
|
||||
|
|
@ -518,7 +518,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
async def test_large_result_set_handling(self, processor, sample_schema_config):
|
||||
"""Test handling of large query result sets"""
|
||||
# Setup
|
||||
await processor.on_schema_config(sample_schema_config, version=1)
|
||||
await processor.on_schema_config("default", sample_schema_config, version=1)
|
||||
processor.connect_cassandra()
|
||||
|
||||
keyspace = "large_test_user"
|
||||
|
|
@ -601,7 +601,7 @@ class TestObjectsGraphQLQueryPerformance:
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(schema_config, version=1)
|
||||
await processor.on_schema_config("default", schema_config, version=1)
|
||||
|
||||
# Measure query execution time
|
||||
start_time = time.time()
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ class TestStructuredQueryServiceIntegration:
|
|||
# Arrange - Create realistic query request
|
||||
request = StructuredQueryRequest(
|
||||
question="Show me all customers from California who have made purchases over $500",
|
||||
user="trustgraph",
|
||||
collection="default"
|
||||
)
|
||||
|
||||
|
|
@ -126,7 +125,6 @@ class TestStructuredQueryServiceIntegration:
|
|||
assert "orders" in objects_call_args.query
|
||||
assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string
|
||||
assert objects_call_args.variables["state"] == "California"
|
||||
assert objects_call_args.user == "trustgraph"
|
||||
assert objects_call_args.collection == "default"
|
||||
|
||||
# Verify response
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue