feat: workspace-based multi-tenancy, replacing user as tenancy axis (#840)

Introduces `workspace` as the isolation boundary for config, flows,
library, and knowledge data. Removes `user` as a schema-level field
throughout the code, API specs, and tests; workspace provides the
same separation more cleanly through the trusted flow.workspace
attribute, set server-side, rather than through client-supplied
message fields.

Design
------
- IAM tech spec (docs/tech-specs/iam.md) documents current state,
  proposed auth/access model, and migration direction.
- Data ownership model (docs/tech-specs/data-ownership-model.md)
  captures the workspace/collection/flow hierarchy.

Schema + messaging
------------------
- Drop `user` field from AgentRequest/Step, GraphRagQuery,
  DocumentRagQuery, Triples/Graph/Document/Row EmbeddingsRequest,
  Sparql/Rows/Structured QueryRequest, ToolServiceRequest.
- Keep collection/workspace routing via flow.workspace at the
  service layer.
- Translators updated to not serialise/deserialise user.
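
As an illustration of the trimmed request shape (a hedged sketch only:
field names are drawn from the request constructors exercised in the
updated tests, not from the real Pulsar schema in trustgraph.schema):

```python
from dataclasses import dataclass, asdict

# Hypothetical mirror of the trimmed AgentRequest; the real schema
# lives in trustgraph.schema and may carry additional fields.
@dataclass
class AgentRequest:
    question: str
    collection: str = "default"
    streaming: bool = False
    session_id: str = ""
    # note: no `user` field -- tenancy now rides on flow.workspace,
    # which clients cannot set

req = AgentRequest(question="What is 2 + 2?")
assert "user" not in asdict(req)
```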

API specs
---------
- OpenAPI schemas and path examples cleaned of user fields.
- Websocket async-api messages updated.
- Removed the unused parameters/User.yaml.

Services + base
---------------
- Librarian, collection manager, knowledge, config: all operations
  scoped by workspace. Config client API takes workspace as first
  positional arg.
- `flow.workspace` set at flow start time by the infrastructure;
  no longer pass-through from clients.
- Tool service drops user-personalisation passthrough.
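
The workspace-first calling convention can be sketched with a toy
in-memory store (hypothetical helper, not the real config client API;
the real client issues requests over the message bus):

```python
class ConfigStore:
    """Toy config store keyed by (workspace, type, key).

    Workspace is the first positional argument on every operation,
    mirroring the reworked config client API described above.
    """

    def __init__(self):
        self._data = {}

    def put(self, workspace, config_type, key, value):
        self._data[(workspace, config_type, key)] = value

    def get(self, workspace, config_type, key, default=None):
        # lookups are scoped: one workspace never sees another's values
        return self._data.get((workspace, config_type, key), default)

store = ConfigStore()
store.put("research", "prompt", "system", "You are helpful.")
assert store.get("research", "prompt", "system") == "You are helpful."
assert store.get("other-ws", "prompt", "system") is None
```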

CLI + SDK
---------
- tg-init-workspace and workspace-aware import/export.
- All tg-* commands drop user args; accept --workspace.
- Python API/SDK (flow, socket_client, async_*, explainability,
  library) drop user kwargs from every method signature.

MCP server
----------
- All tool endpoints drop user parameters; socket_manager no longer
  keyed per user.

Flow service
------------
- Closure-based topic cleanup on flow stop: only delete topics
  whose blueprint template was parameterised AND no remaining
  live flow (across all workspaces) still resolves to that topic.
  Four scopes fall out naturally from template analysis:
    * {id} -> per-flow, deleted on stop
    * {blueprint} -> per-blueprint, kept while any flow of the
      same blueprint exists
    * {workspace} -> per-workspace, kept while any flow in the
      workspace exists
    * literal -> global, never deleted (e.g. tg.request.librarian)
  Fixes a bug where stopping a flow silently destroyed the global
  librarian exchange, wedging all library operations until manual
  restart.
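
The scope rules above can be sketched as a small classifier over topic
templates (hypothetical helper; placeholder names are the ones in the
bullet list, and the deletion predicate is a simplification of the
closure-based cleanup):

```python
def topic_scope(template: str) -> str:
    """Classify a blueprint topic template into its cleanup scope."""
    if "{id}" in template:
        return "per-flow"        # deleted when the flow stops
    if "{blueprint}" in template:
        return "per-blueprint"   # kept while any flow of the blueprint runs
    if "{workspace}" in template:
        return "per-workspace"   # kept while any flow in the workspace runs
    return "global"              # literal topic, never deleted

def deletable_on_stop(template: str, live_flows_resolving_here: int) -> bool:
    """Delete only parameterised topics no remaining live flow resolves to."""
    return (topic_scope(template) != "global"
            and live_flows_resolving_here == 0)

# the global librarian topic survives flow stop -- the bug this fixes
assert not deletable_on_stop("tg.request.librarian", 0)
assert deletable_on_stop("tg.flow.{id}.input", 0)
```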

RabbitMQ backend
----------------
- heartbeat=60, blocked_connection_timeout=300. Catches silently
  dead connections (broker restart, orphaned channels, network
  partitions) within ~2 heartbeat windows, so the consumer
  reconnects and re-binds its queue rather than sitting forever
  on a zombie connection.
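
A connection-config fragment showing where these values land, assuming
the backend connects via the pika client (a sketch of the settings, not
the actual backend code):

```python
import pika  # assumption: the RabbitMQ backend uses pika

params = pika.ConnectionParameters(
    host="rabbitmq",
    heartbeat=60,                    # broker/client heartbeat interval, seconds
    blocked_connection_timeout=300,  # bail out if the broker blocks us for 5 min
)
# a connection missing ~2 heartbeat windows (~120s) is declared dead,
# prompting the consumer to reconnect and re-bind its queue
```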

Tests
-----
- Full test refresh: unit, integration, contract, provenance.
- Dropped user-field assertions and constructor kwargs across
  ~100 test files.
- Renamed user-collection isolation tests to workspace-collection.
---
Author: cybermaggedon, 2026-04-21 23:23:01 +01:00 (committed via GitHub)
Commit: d35473f7f7 (parent: 9332089b3d)
Signature: GPG key ID B5690EEEBB952194 (no known key found for this signature in database)
377 changed files with 6868 additions and 5785 deletions


@@ -37,6 +37,9 @@ class TestAgentServiceNonStreaming:
# Setup mock agent manager
mock_agent_instance = AsyncMock()
mock_agent_manager_class.return_value = mock_agent_instance
mock_agent_instance.tools = {}
mock_agent_instance.additional_context = ""
processor.agents["default"] = mock_agent_instance
# Mock react to call think and observe callbacks
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
@@ -50,7 +53,6 @@ class TestAgentServiceNonStreaming:
msg = MagicMock()
msg.value.return_value = AgentRequest(
question="What is 2 + 2?",
user="trustgraph",
streaming=False # Non-streaming mode
)
msg.properties.return_value = {"id": "test-id"}
@@ -58,6 +60,7 @@ class TestAgentServiceNonStreaming:
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
mock_producer = AsyncMock()
@@ -129,6 +132,9 @@ class TestAgentServiceNonStreaming:
# Setup mock agent manager
mock_agent_instance = AsyncMock()
mock_agent_manager_class.return_value = mock_agent_instance
mock_agent_instance.tools = {}
mock_agent_instance.additional_context = ""
processor.agents["default"] = mock_agent_instance
# Mock react to return Final directly
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
@@ -140,7 +146,6 @@ class TestAgentServiceNonStreaming:
msg = MagicMock()
msg.value.return_value = AgentRequest(
question="What is 2 + 2?",
user="trustgraph",
streaming=False # Non-streaming mode
)
msg.properties.return_value = {"id": "test-id"}
@@ -148,6 +153,7 @@ class TestAgentServiceNonStreaming:
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
mock_producer = AsyncMock()


@@ -11,13 +11,12 @@ from trustgraph.schema import AgentRequest, AgentStep
from trustgraph.agent.orchestrator.aggregator import Aggregator
def _make_request(question="Test question", user="testuser",
def _make_request(question="Test question",
collection="default", streaming=False,
session_id="parent-session", task_type="research",
framing="test framing", conversation_id="conv-1"):
return AgentRequest(
question=question,
user=user,
collection=collection,
streaming=streaming,
session_id=session_id,
@@ -127,7 +126,6 @@ class TestBuildSynthesisRequest:
req = agg.build_synthesis_request(
"corr-1",
original_question="Original question",
user="testuser",
collection="default",
)
@@ -148,7 +146,7 @@ class TestBuildSynthesisRequest:
agg.record_completion("corr-1", "goal-b", "answer-b")
req = agg.build_synthesis_request(
"corr-1", "question", "user", "default",
"corr-1", "question", "default",
)
# Last history step should be the synthesis step
@@ -168,7 +166,7 @@ class TestBuildSynthesisRequest:
agg.record_completion("corr-1", "goal-a", "answer-a")
agg.build_synthesis_request(
"corr-1", "question", "user", "default",
"corr-1", "question", "default",
)
# Entry should be removed
@@ -178,7 +176,7 @@ class TestBuildSynthesisRequest:
agg = Aggregator()
with pytest.raises(RuntimeError, match="No results"):
agg.build_synthesis_request(
"unknown", "question", "user", "default",
"unknown", "question", "default",
)


@@ -15,7 +15,6 @@ from trustgraph.agent.orchestrator.aggregator import Aggregator
def _make_request(**kwargs):
defaults = dict(
question="Test question",
user="testuser",
collection="default",
)
defaults.update(kwargs)
@@ -130,7 +129,6 @@ class TestAggregatorIntegration:
synth = agg.build_synthesis_request(
"corr-1",
original_question="Original question",
user="testuser",
collection="default",
)
@@ -160,7 +158,7 @@ class TestAggregatorIntegration:
agg.record_completion("corr-1", "goal", "answer")
synth = agg.build_synthesis_request(
"corr-1", "question", "user", "default",
"corr-1", "question", "default",
)
# correlation_id must be empty so it's not intercepted


@@ -126,7 +126,6 @@ def make_base_request(**kwargs):
state="",
group=[],
history=[],
user="testuser",
collection="default",
streaming=False,
session_id="test-session-123",


@@ -21,7 +21,6 @@ class MockProcessor:
def _make_request(**kwargs):
defaults = dict(
question="Test question",
user="testuser",
collection="default",
)
defaults.update(kwargs)


@@ -167,39 +167,28 @@ class TestToolServiceRequest:
"""Test cases for tool service request format"""
def test_request_format(self):
"""Test that request is properly formatted with user, config, and arguments"""
# Arrange
user = "alice"
"""Test that request is properly formatted with config and arguments"""
config_values = {"style": "pun", "collection": "jokes"}
arguments = {"topic": "programming"}
# Act - simulate request building
request = {
"user": user,
"config": json.dumps(config_values),
"arguments": json.dumps(arguments)
}
# Assert
assert request["user"] == "alice"
assert json.loads(request["config"]) == {"style": "pun", "collection": "jokes"}
assert json.loads(request["arguments"]) == {"topic": "programming"}
def test_request_with_empty_config(self):
"""Test request when no config values are provided"""
# Arrange
user = "bob"
config_values = {}
arguments = {"query": "test"}
# Act
request = {
"user": user,
"config": json.dumps(config_values) if config_values else "{}",
"arguments": json.dumps(arguments) if arguments else "{}"
}
# Assert
assert request["config"] == "{}"
assert json.loads(request["arguments"]) == {"query": "test"}
@@ -386,18 +375,13 @@ class TestJokeServiceLogic:
assert map_topic_to_category("random topic") == "default"
assert map_topic_to_category("") == "default"
def test_joke_response_personalization(self):
"""Test that joke responses include user personalization"""
# Arrange
user = "alice"
def test_joke_response_format(self):
"""Test that joke response is formatted as expected"""
style = "pun"
joke = "Why do programmers prefer dark mode? Because light attracts bugs!"
# Act
response = f"Hey {user}! Here's a {style} for you:\n\n{joke}"
response = f"Here's a {style} for you:\n\n{joke}"
# Assert
assert "Hey alice!" in response
assert "pun" in response
assert joke in response
@@ -439,20 +423,14 @@ class TestDynamicToolServiceBase:
def test_request_parsing(self):
"""Test parsing of incoming request"""
# Arrange
request_data = {
"user": "alice",
"config": '{"style": "pun"}',
"arguments": '{"topic": "programming"}'
}
# Act
user = request_data.get("user", "trustgraph")
config = json.loads(request_data["config"]) if request_data["config"] else {}
arguments = json.loads(request_data["arguments"]) if request_data["arguments"] else {}
# Assert
assert user == "alice"
assert config == {"style": "pun"}
assert arguments == {"topic": "programming"}


@@ -1,6 +1,6 @@
"""
Tests for tool service lifecycle, invoke contract, streaming responses,
multi-tenancy, and error propagation.
and error propagation.
Tests the actual DynamicToolService, ToolService, and ToolServiceClient
classes rather than plain dicts.
@@ -31,7 +31,7 @@ class TestDynamicToolServiceInvokeContract:
svc = DynamicToolService.__new__(DynamicToolService)
with pytest.raises(NotImplementedError):
await svc.invoke("user", {}, {})
await svc.invoke({}, {})
@pytest.mark.asyncio
async def test_on_request_calls_invoke_with_parsed_args(self):
@@ -44,8 +44,8 @@ class TestDynamicToolServiceInvokeContract:
calls = []
async def tracking_invoke(user, config, arguments):
calls.append({"user": user, "config": config, "arguments": arguments})
async def tracking_invoke(config, arguments):
calls.append({"config": config, "arguments": arguments})
return "ok"
svc.invoke = tracking_invoke
@@ -56,7 +56,6 @@ class TestDynamicToolServiceInvokeContract:
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(
user="alice",
config='{"style": "pun"}',
arguments='{"topic": "cats"}',
)
@@ -65,39 +64,9 @@ class TestDynamicToolServiceInvokeContract:
await svc.on_request(msg, MagicMock(), None)
assert len(calls) == 1
assert calls[0]["user"] == "alice"
assert calls[0]["config"] == {"style": "pun"}
assert calls[0]["arguments"] == {"topic": "cats"}
@pytest.mark.asyncio
async def test_on_request_empty_user_defaults_to_trustgraph(self):
"""Empty user field should default to 'trustgraph'."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
received_user = None
async def capture_invoke(user, config, arguments):
nonlocal received_user
received_user = user
return "ok"
svc.invoke = capture_invoke
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="", config="", arguments="")
msg.properties.return_value = {"id": "req-2"}
await svc.on_request(msg, MagicMock(), None)
assert received_user == "trustgraph"
@pytest.mark.asyncio
async def test_on_request_string_response_sent_directly(self):
"""String return from invoke → response field is the string."""
@@ -107,7 +76,7 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc"
svc.producer = AsyncMock()
async def string_invoke(user, config, arguments):
async def string_invoke(config, arguments):
return "hello world"
svc.invoke = string_invoke
@@ -116,7 +85,7 @@ class TestDynamicToolServiceInvokeContract:
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r1"}
await svc.on_request(msg, MagicMock(), None)
@@ -136,7 +105,7 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc"
svc.producer = AsyncMock()
async def dict_invoke(user, config, arguments):
async def dict_invoke(config, arguments):
return {"result": 42}
svc.invoke = dict_invoke
@@ -145,7 +114,7 @@ class TestDynamicToolServiceInvokeContract:
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r2"}
await svc.on_request(msg, MagicMock(), None)
@@ -162,13 +131,13 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc"
svc.producer = AsyncMock()
async def failing_invoke(user, config, arguments):
async def failing_invoke(config, arguments):
raise ValueError("bad input")
svc.invoke = failing_invoke
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r3"}
await svc.on_request(msg, MagicMock(), None)
@@ -188,13 +157,13 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc"
svc.producer = AsyncMock()
async def rate_limited_invoke(user, config, arguments):
async def rate_limited_invoke(config, arguments):
raise TooManyRequests("rate limited")
svc.invoke = rate_limited_invoke
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r4"}
with pytest.raises(TooManyRequests):
@@ -209,7 +178,7 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc"
svc.producer = AsyncMock()
async def ok_invoke(user, config, arguments):
async def ok_invoke(config, arguments):
return "ok"
svc.invoke = ok_invoke
@@ -218,7 +187,7 @@ class TestDynamicToolServiceInvokeContract:
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "unique-42"}
await svc.on_request(msg, MagicMock(), None)
@@ -241,7 +210,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def mock_invoke(name, params):
async def mock_invoke(workspace, name, params):
return "tool result"
svc.invoke_tool = mock_invoke
@@ -260,6 +229,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}')
@@ -280,7 +250,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def mock_invoke(name, params):
async def mock_invoke(workspace, name, params):
return {"data": [1, 2, 3]}
svc.invoke_tool = mock_invoke
@@ -298,6 +268,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
@@ -317,7 +288,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def failing_invoke(name, params):
async def failing_invoke(workspace, name, params):
raise RuntimeError("tool broke")
svc.invoke_tool = failing_invoke
@@ -330,6 +301,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
@@ -350,7 +322,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def rate_limited(name, params):
async def rate_limited(workspace, name, params):
raise TooManyRequests("slow down")
svc.invoke_tool = rate_limited
@@ -362,6 +334,7 @@ class TestToolServiceOnRequest:
flow = MagicMock()
flow.producer = {"response": AsyncMock()}
flow.name = "test-flow"
flow.workspace = "default"
with pytest.raises(TooManyRequests):
await svc.on_request(msg, MagicMock(), flow)
@@ -376,7 +349,8 @@ class TestToolServiceOnRequest:
received = {}
async def capture_invoke(name, params):
async def capture_invoke(workspace, name, params):
received["workspace"] = workspace
received["name"] = name
received["params"] = params
return "ok"
@@ -390,6 +364,7 @@ class TestToolServiceOnRequest:
flow = lambda name: mock_pub
flow.producer = {"response": mock_pub}
flow.name = "f"
flow.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(
@@ -421,7 +396,6 @@ class TestToolServiceClientCall:
))
result = await client.call(
user="alice",
config={"style": "pun"},
arguments={"topic": "cats"},
)
@@ -430,7 +404,6 @@ class TestToolServiceClientCall:
req = client.request.call_args[0][0]
assert isinstance(req, ToolServiceRequest)
assert req.user == "alice"
assert json.loads(req.config) == {"style": "pun"}
assert json.loads(req.arguments) == {"topic": "cats"}
@@ -446,7 +419,7 @@ class TestToolServiceClientCall:
))
with pytest.raises(RuntimeError, match="service down"):
await client.call(user="u", config={}, arguments={})
await client.call(config={}, arguments={})
@pytest.mark.asyncio
async def test_call_empty_config_sends_empty_json(self):
@@ -458,7 +431,7 @@ class TestToolServiceClientCall:
error=None, response="ok",
))
await client.call(user="u", config=None, arguments=None)
await client.call(config=None, arguments=None)
req = client.request.call_args[0][0]
assert req.config == "{}"
@@ -474,7 +447,7 @@ class TestToolServiceClientCall:
error=None, response="ok",
))
await client.call(user="u", config={}, arguments={}, timeout=30)
await client.call(config={}, arguments={}, timeout=30)
_, kwargs = client.request.call_args
assert kwargs["timeout"] == 30
@@ -509,7 +482,7 @@ class TestToolServiceClientStreaming:
received.append(text)
result = await client.call_streaming(
user="u", config={}, arguments={}, callback=callback,
config={}, arguments={}, callback=callback,
)
assert result == "chunk1chunk2"
@@ -534,7 +507,7 @@ class TestToolServiceClientStreaming:
with pytest.raises(RuntimeError, match="stream failed"):
await client.call_streaming(
user="u", config={}, arguments={},
config={}, arguments={},
callback=AsyncMock(),
)
@@ -564,61 +537,9 @@ class TestToolServiceClientStreaming:
received.append(text)
result = await client.call_streaming(
user="u", config={}, arguments={}, callback=callback,
config={}, arguments={}, callback=callback,
)
# Empty response is falsy, so callback shouldn't be called for it
assert result == "data"
assert received == ["data"]
# ---------------------------------------------------------------------------
# Multi-tenancy
# ---------------------------------------------------------------------------
class TestMultiTenancy:
@pytest.mark.asyncio
async def test_user_propagated_to_invoke(self):
"""User from request should reach the invoke method."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test"
svc.producer = AsyncMock()
users_seen = []
async def tracking(user, config, arguments):
users_seen.append(user)
return "ok"
svc.invoke = tracking
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
for u in ["tenant-a", "tenant-b", "tenant-c"]:
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(
user=u, config="{}", arguments="{}",
)
msg.properties.return_value = {"id": f"req-{u}"}
await svc.on_request(msg, MagicMock(), None)
assert users_seen == ["tenant-a", "tenant-b", "tenant-c"]
@pytest.mark.asyncio
async def test_client_sends_user_in_request(self):
"""ToolServiceClient.call should include user in request."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
client.request = AsyncMock(return_value=ToolServiceResponse(
error=None, response="ok",
))
await client.call(user="isolated-tenant", config={}, arguments={})
req = client.request.call_args[0][0]
assert req.user == "isolated-tenant"


@@ -1,17 +1,14 @@
"""
Tests for AsyncProcessor config notify pattern:
- register_config_handler with types filtering
- on_config_notify version comparison and type matching
- fetch_config with short-lived client
- fetch_and_apply_config retry logic
- on_config_notify version comparison, type/workspace matching
- fetch_and_apply_config retry logic over per-workspace fetches
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, Mock
from trustgraph.schema import Term, IRI, LITERAL
# Patch heavy dependencies before importing AsyncProcessor
@pytest.fixture
def processor():
"""Create an AsyncProcessor with mocked dependencies."""
@@ -68,6 +65,13 @@ class TestRegisterConfigHandler:
assert len(processor.config_handlers) == 2
def _notify_msg(version, changes):
"""Build a Mock config-notify message with given version and changes dict."""
msg = Mock()
msg.value.return_value = Mock(version=version, changes=changes)
return msg
class TestOnConfigNotify:
@pytest.mark.asyncio
@@ -77,9 +81,7 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=3, types=["prompt"])
msg = _notify_msg(3, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@@ -91,9 +93,7 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=5, types=["prompt"])
msg = _notify_msg(5, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@@ -105,9 +105,7 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=2, types=["schema"])
msg = _notify_msg(2, {"schema": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@@ -121,40 +119,36 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
# Mock fetch_config
mock_config = {"prompt": {"key": "value"}}
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
return_value=(mock_config, 2)
return_value={"key": "value"},
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
msg = _notify_msg(2, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_called_once_with(mock_config, 2)
handler.assert_called_once_with(
"default", {"prompt": {"key": "value"}}, 2
)
assert processor.config_version == 2
@pytest.mark.asyncio
async def test_handler_without_types_always_called(self, processor):
async def test_handler_without_types_ignored_on_notify(self, processor):
"""Handlers registered without types never fire on notifications."""
processor.config_version = 1
handler = AsyncMock()
processor.register_config_handler(handler) # No types = all
processor.register_config_handler(handler) # No types
mock_config = {"anything": {}}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 2)
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["whatever"])
msg = _notify_msg(2, {"whatever": ["default"]})
await processor.on_config_notify(msg, None, None)
await processor.on_config_notify(msg, None, None)
handler.assert_called_once_with(mock_config, 2)
handler.assert_not_called()
# Version still advances past the notify
assert processor.config_version == 2
@pytest.mark.asyncio
async def test_mixed_handlers_type_filtering(self, processor):
@@ -168,156 +162,149 @@ class TestOnConfigNotify:
processor.register_config_handler(schema_handler, types=["schema"])
processor.register_config_handler(all_handler)
mock_config = {"prompt": {}}
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
return_value=(mock_config, 2)
return_value={},
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
msg = _notify_msg(2, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
prompt_handler.assert_called_once()
prompt_handler.assert_called_once_with(
"default", {"prompt": {}}, 2
)
schema_handler.assert_not_called()
all_handler.assert_called_once()
all_handler.assert_not_called()
@pytest.mark.asyncio
async def test_empty_types_invokes_all(self, processor):
"""Empty types list (startup signal) should invoke all handlers."""
async def test_multi_workspace_notify_invokes_handler_per_ws(
self, processor
):
"""Notify affecting multiple workspaces invokes handler once per workspace."""
processor.config_version = 1
h1 = AsyncMock()
h2 = AsyncMock()
processor.register_config_handler(h1, types=["prompt"])
processor.register_config_handler(h2, types=["schema"])
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
mock_config = {}
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
return_value=(mock_config, 2)
return_value={},
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=[])
msg = _notify_msg(2, {"prompt": ["ws1", "ws2"]})
await processor.on_config_notify(msg, None, None)
h1.assert_called_once()
h2.assert_called_once()
assert handler.call_count == 2
called_workspaces = {c.args[0] for c in handler.call_args_list}
assert called_workspaces == {"ws1", "ws2"}
@pytest.mark.asyncio
async def test_fetch_failure_handled(self, processor):
processor.config_version = 1
handler = AsyncMock()
processor.register_config_handler(handler)
processor.register_config_handler(handler, types=["prompt"])
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
side_effect=RuntimeError("Connection failed")
side_effect=RuntimeError("Connection failed"),
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
msg = _notify_msg(2, {"prompt": ["default"]})
# Should not raise
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
class TestFetchConfig:
@pytest.mark.asyncio
async def test_fetch_returns_config_and_version(self, processor):
mock_resp = Mock()
mock_resp.error = None
mock_resp.config = {"prompt": {"key": "val"}}
mock_resp.version = 42
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
config, version = await processor.fetch_config()
assert config == {"prompt": {"key": "val"}}
assert version == 42
mock_client.stop.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_raises_on_error_response(self, processor):
mock_resp = Mock()
mock_resp.error = Mock(message="not found")
mock_resp.config = {}
mock_resp.version = 0
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
with pytest.raises(RuntimeError, match="Config error"):
await processor.fetch_config()
mock_client.stop.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_stops_client_on_exception(self, processor):
mock_client = AsyncMock()
mock_client.request.side_effect = TimeoutError("timeout")
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
with pytest.raises(TimeoutError):
await processor.fetch_config()
mock_client.stop.assert_called_once()
class TestFetchAndApplyConfig:
@pytest.mark.asyncio
async def test_applies_config_to_all_handlers(self, processor):
h1 = AsyncMock()
h2 = AsyncMock()
processor.register_config_handler(h1, types=["prompt"])
processor.register_config_handler(h2, types=["schema"])
async def test_applies_config_per_workspace(self, processor):
"""Startup fetch invokes handler once per workspace affected."""
h = AsyncMock()
processor.register_config_handler(h, types=["prompt"])
mock_client = AsyncMock()
async def fake_fetch_all(client, config_type):
return {
"ws1": {"k": "v1"},
"ws2": {"k": "v2"},
}, 10
mock_config = {"prompt": {}, "schema": {}}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 10)
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
):
await processor.fetch_and_apply_config()
# On startup, all handlers are invoked regardless of type
h1.assert_called_once_with(mock_config, 10)
h2.assert_called_once_with(mock_config, 10)
assert h.call_count == 2
call_map = {c.args[0]: c.args[1] for c in h.call_args_list}
assert call_map["ws1"] == {"prompt": {"k": "v1"}}
assert call_map["ws2"] == {"prompt": {"k": "v2"}}
assert processor.config_version == 10
@pytest.mark.asyncio
async def test_retries_on_failure(self, processor):
call_count = 0
mock_config = {"prompt": {}}
async def test_handler_without_types_skipped_at_startup(self, processor):
"""Handlers registered without types fetch nothing at startup."""
typed = AsyncMock()
untyped = AsyncMock()
processor.register_config_handler(typed, types=["prompt"])
processor.register_config_handler(untyped)
async def mock_fetch():
mock_client = AsyncMock()
async def fake_fetch_all(client, config_type):
return {"default": {}}, 1
with patch.object(
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
):
await processor.fetch_and_apply_config()
typed.assert_called_once()
untyped.assert_not_called()
@pytest.mark.asyncio
async def test_retries_on_failure(self, processor):
h = AsyncMock()
processor.register_config_handler(h, types=["prompt"])
call_count = 0
async def fake_fetch_all(client, config_type):
nonlocal call_count
call_count += 1
if call_count < 3:
raise RuntimeError("not ready")
return mock_config, 5
return {"default": {"k": "v"}}, 5
with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \
patch('asyncio.sleep', new_callable=AsyncMock):
mock_client = AsyncMock()
with patch.object(
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
), patch('asyncio.sleep', new_callable=AsyncMock):
await processor.fetch_and_apply_config()
assert call_count == 3
assert processor.config_version == 5
h.assert_called_once_with(
"default", {"prompt": {"k": "v"}}, 5
)
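The dispatch pattern these tests exercise can be sketched as follows. This is an illustrative sketch only, not the TrustGraph implementation: it is synchronous where the real processor is async, and the method and callback names are taken from the tests' mocks, so treat them as assumptions.

```python
# Sketch: a config processor that fetches each registered config type
# across all workspaces and invokes handlers as (workspace, config, version).
# Handlers registered without types fetch nothing at startup.
class ConfigProcessor:
    def __init__(self):
        self.handlers = []          # list of (handler, types) pairs
        self.config_version = -1

    def register_config_handler(self, handler, types=None):
        self.handlers.append((handler, types or []))

    def fetch_and_apply_config(self, fetch_type_all_workspaces):
        # fetch_type_all_workspaces(config_type) -> ({workspace: config}, version)
        for handler, types in self.handlers:
            for config_type in types:
                per_workspace, version = fetch_type_all_workspaces(config_type)
                for workspace, config in per_workspace.items():
                    handler(workspace, {config_type: config}, version)
                self.config_version = version
```

A handler therefore sees one call per workspace, which is why the tests above build `call_map` keyed on the first positional argument.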

View file

@ -33,7 +33,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
result = await client.query(
vector=vector,
limit=10,
user="test_user",
collection="test_collection",
timeout=30
)
@ -45,7 +44,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
assert isinstance(call_args, DocumentEmbeddingsRequest)
assert call_args.vector == vector
assert call_args.limit == 10
assert call_args.user == "test_user"
assert call_args.collection == "test_collection"
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
@ -104,7 +102,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client.request.assert_called_once()
call_args = client.request.call_args[0][0]
assert call_args.limit == 20 # Default limit
assert call_args.user == "trustgraph" # Default user
assert call_args.collection == "default" # Default collection
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')

View file

@ -40,10 +40,11 @@ def test_flow_initialization_calls_registered_specs():
spec_two = MagicMock()
processor = MagicMock(specifications=[spec_one, spec_two])
flow = Flow("processor-1", "flow-a", processor, {"answer": 42})
flow = Flow("processor-1", "flow-a", "default", processor, {"answer": 42})
assert flow.id == "processor-1"
assert flow.name == "flow-a"
assert flow.workspace == "default"
assert flow.producer == {}
assert flow.consumer == {}
assert flow.parameter == {}
@ -54,7 +55,7 @@ def test_flow_initialization_calls_registered_specs():
def test_flow_start_and_stop_visit_all_consumers():
consumer_one = AsyncMock()
consumer_two = AsyncMock()
flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {})
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
flow.consumer = {"one": consumer_one, "two": consumer_two}
asyncio.run(flow.start())
@ -67,7 +68,7 @@ def test_flow_start_and_stop_visit_all_consumers():
def test_flow_call_returns_values_in_priority_order():
flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {})
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
flow.producer["shared"] = "producer-value"
flow.consumer["consumer-only"] = "consumer-value"
flow.consumer["shared"] = "consumer-value"

View file

@ -172,10 +172,10 @@ class TestFlowParameterSpecs(IsolatedAsyncioTestCase):
flow_defn = {'config': 'test-config'}
# Act
await processor.start_flow(flow_name, flow_defn)
await processor.start_flow("default", flow_name, flow_defn)
# Assert - Flow should be created with access to processor specifications
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
mock_flow_class.assert_called_once_with('test-processor', flow_name, "default", processor, flow_defn)
# The flow should have access to the processor's specifications
# (The exact mechanism depends on Flow implementation)

View file

@ -78,11 +78,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
flow_name = 'test-flow'
flow_defn = {'config': 'test-config'}
await processor.start_flow(flow_name, flow_defn)
await processor.start_flow("default", flow_name, flow_defn)
assert flow_name in processor.flows
assert ("default", flow_name) in processor.flows
mock_flow_class.assert_called_once_with(
'test-processor', flow_name, processor, flow_defn
'test-processor', flow_name, "default", processor, flow_defn
)
mock_flow.start.assert_called_once()
@ -103,11 +103,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
mock_flow_class.return_value = mock_flow
flow_name = 'test-flow'
await processor.start_flow(flow_name, {'config': 'test-config'})
await processor.start_flow("default", flow_name, {'config': 'test-config'})
await processor.stop_flow(flow_name)
await processor.stop_flow("default", flow_name)
assert flow_name not in processor.flows
assert ("default", flow_name) not in processor.flows
mock_flow.stop.assert_called_once()
@with_async_processor_patches
@ -120,7 +120,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
processor = FlowProcessor(**config)
await processor.stop_flow('non-existent-flow')
await processor.stop_flow("default", 'non-existent-flow')
assert processor.flows == {}
@ -146,11 +146,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
}
}
await processor.on_configure_flows(config_data, version=1)
await processor.on_configure_flows("default", config_data, version=1)
assert 'test-flow' in processor.flows
assert ("default", 'test-flow') in processor.flows
mock_flow_class.assert_called_once_with(
'test-processor', 'test-flow', processor,
'test-processor', 'test-flow', "default", processor,
{'config': 'test-config'}
)
mock_flow.start.assert_called_once()
@ -171,7 +171,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
}
}
await processor.on_configure_flows(config_data, version=1)
await processor.on_configure_flows("default", config_data, version=1)
assert processor.flows == {}
@ -189,7 +189,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
'other-data': 'some-value'
}
await processor.on_configure_flows(config_data, version=1)
await processor.on_configure_flows("default", config_data, version=1)
assert processor.flows == {}
@ -216,7 +216,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
}
}
await processor.on_configure_flows(config_data1, version=1)
await processor.on_configure_flows("default", config_data1, version=1)
config_data2 = {
'processor:test-processor': {
@ -224,12 +224,12 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
}
}
await processor.on_configure_flows(config_data2, version=2)
await processor.on_configure_flows("default", config_data2, version=2)
assert 'flow1' not in processor.flows
assert ("default", 'flow1') not in processor.flows
mock_flow1.stop.assert_called_once()
assert 'flow2' in processor.flows
assert ("default", 'flow2') in processor.flows
mock_flow2.start.assert_called_once()
@with_async_processor_patches
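The registry change behind these test updates can be sketched as follows: flows are now keyed by `(workspace, flow_name)` tuples rather than bare names, so the same flow name can exist independently in different workspaces. The class and method names here are illustrative, chosen to match the test calls above, not the actual `FlowProcessor` implementation.

```python
# Sketch: a flow registry keyed by (workspace, name) tuples.
class FlowRegistry:
    def __init__(self):
        self.flows = {}

    def start_flow(self, workspace, name, defn):
        self.flows[(workspace, name)] = defn

    def stop_flow(self, workspace, name):
        # Stopping an unknown flow is a no-op, as the tests expect.
        self.flows.pop((workspace, name), None)
```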

View file

@ -28,7 +28,6 @@ def sample_text_document():
"""Sample document with moderate length text."""
metadata = Metadata(
id="test-doc-1",
user="test-user",
collection="test-collection"
)
text = "The quick brown fox jumps over the lazy dog. " * 20
@ -43,7 +42,6 @@ def long_text_document():
"""Long document for testing multiple chunks."""
metadata = Metadata(
id="test-doc-long",
user="test-user",
collection="test-collection"
)
# Create a long text that will definitely be chunked
@ -59,7 +57,6 @@ def unicode_text_document():
"""Document with various unicode characters."""
metadata = Metadata(
id="test-doc-unicode",
user="test-user",
collection="test-collection"
)
text = """
@ -84,7 +81,6 @@ def empty_text_document():
"""Empty document for edge case testing."""
metadata = Metadata(
id="test-doc-empty",
user="test-user",
collection="test-collection"
)
return TextDocument(

View file

@ -185,7 +185,6 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata(
id="test-doc-123",
user="test-user",
collection="test-collection"
)
mock_text_doc.text = b"This is test document content"

View file

@ -185,7 +185,6 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata(
id="test-doc-456",
user="test-user",
collection="test-collection"
)
mock_text_doc.text = b"This is test document content for token chunking"

View file

@ -109,7 +109,8 @@ class TestListConfigItems:
url='http://custom.com',
config_type='prompt',
format_type='json',
token=None
token=None,
workspace='default'
)
def test_list_main_uses_defaults(self):
@ -128,7 +129,8 @@ class TestListConfigItems:
url='http://localhost:8088/',
config_type='prompt',
format_type='text',
token=None
token=None,
workspace='default'
)
@ -196,7 +198,8 @@ class TestGetConfigItem:
config_type='prompt',
key='template-1',
format_type='json',
token=None
token=None,
workspace='default'
)
@ -253,7 +256,8 @@ class TestPutConfigItem:
config_type='prompt',
key='new-template',
value='Custom prompt: {input}',
token=None
token=None,
workspace='default'
)
def test_put_main_with_stdin_arg(self):
@ -278,7 +282,8 @@ class TestPutConfigItem:
config_type='prompt',
key='stdin-template',
value=stdin_content,
token=None
token=None,
workspace='default'
)
def test_put_main_mutually_exclusive_args(self):
@ -334,7 +339,8 @@ class TestDeleteConfigItem:
url='http://custom.com',
config_type='prompt',
key='old-template',
token=None
token=None,
workspace='default'
)
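The CLI change these assertions verify amounts to adding a `--workspace` option with a `'default'` default and threading it through to the config client call. A plausible wiring, with option names assumed from the `-w`/`workspace='default'` values in the tests rather than taken from the actual `tg-*` source:

```python
import argparse

# Sketch: every tg-* config command gains a --workspace option that
# defaults to 'default' and is passed to the config client.
parser = argparse.ArgumentParser(prog="tg-show-config")
parser.add_argument("-u", "--url", default="http://localhost:8088/")
parser.add_argument("-t", "--token", default=None)
parser.add_argument("-w", "--workspace", default="default")

args = parser.parse_args([])
```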

View file

@ -48,7 +48,7 @@ def knowledge_loader():
return KnowledgeLoader(
files=["test.ttl"],
flow="test-flow",
user="test-user",
workspace="test-user",
collection="test-collection",
document_id="test-doc-123",
url="http://test.example.com/",
@ -64,7 +64,7 @@ class TestKnowledgeLoader:
loader = KnowledgeLoader(
files=["file1.ttl", "file2.ttl"],
flow="my-flow",
user="user1",
workspace="user1",
collection="col1",
document_id="doc1",
url="http://example.com/",
@ -73,7 +73,7 @@ class TestKnowledgeLoader:
assert loader.files == ["file1.ttl", "file2.ttl"]
assert loader.flow == "my-flow"
assert loader.user == "user1"
assert loader.workspace == "user1"
assert loader.collection == "col1"
assert loader.document_id == "doc1"
assert loader.url == "http://example.com/"
@ -126,7 +126,7 @@ ex:mary ex:knows ex:bob .
loader = KnowledgeLoader(
files=[f.name],
flow="test-flow",
user="test-user",
workspace="test-user",
collection="test-collection",
document_id="test-doc",
url="http://test.example.com/"
@ -151,7 +151,7 @@ ex:mary ex:knows ex:bob .
loader = KnowledgeLoader(
files=[temp_turtle_file],
flow="test-flow",
user="test-user",
workspace="test-user",
collection="test-collection",
document_id="test-doc",
url="http://test.example.com/",
@ -163,7 +163,8 @@ ex:mary ex:knows ex:bob .
# Verify Api was created with correct parameters
mock_api_class.assert_called_once_with(
url="http://test.example.com/",
token="test-token"
token="test-token",
workspace="test-user"
)
# Verify bulk client was obtained
@ -174,7 +175,6 @@ ex:mary ex:knows ex:bob .
call_args = mock_bulk.import_triples.call_args
assert call_args[1]['flow'] == "test-flow"
assert call_args[1]['metadata']['id'] == "test-doc"
assert call_args[1]['metadata']['user'] == "test-user"
assert call_args[1]['metadata']['collection'] == "test-collection"
# Verify import_entity_contexts was called
@ -198,7 +198,7 @@ class TestCLIArgumentParsing:
'tg-load-knowledge',
'-i', 'doc-123',
'-f', 'my-flow',
'-U', 'my-user',
'-w', 'my-user',
'-C', 'my-collection',
'-u', 'http://custom.example.com/',
'-t', 'my-token',
@ -216,7 +216,7 @@ class TestCLIArgumentParsing:
token='my-token',
flow='my-flow',
files=['file1.ttl', 'file2.ttl'],
user='my-user',
workspace='my-user',
collection='my-collection'
)
@ -242,7 +242,7 @@ class TestCLIArgumentParsing:
# Verify defaults were used
call_args = mock_loader_class.call_args[1]
assert call_args['flow'] == 'default'
assert call_args['user'] == 'trustgraph'
assert call_args['workspace'] == 'default'
assert call_args['collection'] == 'default'
assert call_args['url'] == 'http://localhost:8088/'
assert call_args['token'] is None
@ -287,7 +287,7 @@ class TestErrorHandling:
loader = KnowledgeLoader(
files=[temp_turtle_file],
flow="test-flow",
user="test-user",
workspace="test-user",
collection="test-collection",
document_id="test-doc",
url="http://test.example.com/"

View file

@ -145,7 +145,8 @@ class TestSetToolStructuredQuery:
group=None,
state=None,
applicable_states=None,
token=None
token=None,
workspace='default'
)
def test_set_main_structured_query_no_arguments_needed(self):
@ -326,7 +327,8 @@ class TestSetToolRowEmbeddingsQuery:
group=None,
state=None,
applicable_states=None,
token=None
token=None,
workspace='default'
)
def test_valid_types_includes_row_embeddings_query(self):
@ -471,7 +473,7 @@ class TestShowToolsStructuredQuery:
show_main()
mock_show.assert_called_once_with(url='http://custom.com', token=None)
mock_show.assert_called_once_with(url='http://custom.com', token=None, workspace='default')
class TestShowToolsRowEmbeddingsQuery:

View file

@ -73,7 +73,6 @@ class TestSyncDocumentEmbeddingsClient:
# Act
result = client.request(
vector=vector,
user="test_user",
collection="test_collection",
limit=10,
timeout=300
@ -82,7 +81,6 @@ class TestSyncDocumentEmbeddingsClient:
# Assert
assert result == ["chunk1", "chunk2", "chunk3"]
client.call.assert_called_once_with(
user="test_user",
collection="test_collection",
vector=vector,
limit=10,
@ -108,7 +106,6 @@ class TestSyncDocumentEmbeddingsClient:
# Assert
assert result == ["test_chunk"]
client.call.assert_called_once_with(
user="trustgraph",
collection="default",
vector=vector,
limit=10,

View file

@ -31,7 +31,6 @@ def _make_query(
query = Query(
rag=rag,
user="test-user",
collection="test-collection",
verbose=False,
entity_limit=entity_limit,
@ -208,7 +207,6 @@ class TestBatchTripleQueries:
assert calls[0].kwargs["p"] is None
assert calls[0].kwargs["o"] is None
assert calls[0].kwargs["limit"] == 15
assert calls[0].kwargs["user"] == "test-user"
assert calls[0].kwargs["collection"] == "test-collection"
assert calls[0].kwargs["batch_size"] == 20

View file

@ -28,10 +28,12 @@ def mock_flow_config():
"""Mock flow configuration."""
mock_config = Mock()
mock_config.flows = {
"test-flow": {
"interfaces": {
"triples-store": {"flow": "test-triples-queue"},
"graph-embeddings-store": {"flow": "test-ge-queue"}
"test-user": {
"test-flow": {
"interfaces": {
"triples-store": {"flow": "test-triples-queue"},
"graph-embeddings-store": {"flow": "test-ge-queue"}
}
}
}
}
@ -43,7 +45,7 @@ def mock_flow_config():
def mock_request():
"""Mock knowledge load request."""
request = Mock()
request.user = "test-user"
request.workspace = "test-user"
request.id = "test-doc-id"
request.collection = "test-collection"
request.flow = "test-flow"
@ -71,7 +73,6 @@ def sample_triples():
return Triples(
metadata=Metadata(
id="test-doc-id",
user="test-user",
collection="default", # This should be overridden
),
triples=[
@ -90,7 +91,6 @@ def sample_graph_embeddings():
return GraphEmbeddings(
metadata=Metadata(
id="test-doc-id",
user="test-user",
collection="default", # This should be overridden
),
entities=[
@ -146,7 +146,6 @@ class TestKnowledgeManagerLoadCore:
mock_triples_pub.send.assert_called_once()
sent_triples = mock_triples_pub.send.call_args[0][1]
assert sent_triples.metadata.collection == "test-collection"
assert sent_triples.metadata.user == "test-user"
assert sent_triples.metadata.id == "test-doc-id"
@pytest.mark.asyncio
@ -185,7 +184,6 @@ class TestKnowledgeManagerLoadCore:
mock_ge_pub.send.assert_called_once()
sent_ge = mock_ge_pub.send.call_args[0][1]
assert sent_ge.metadata.collection == "test-collection"
assert sent_ge.metadata.user == "test-user"
assert sent_ge.metadata.id == "test-doc-id"
@pytest.mark.asyncio
@ -193,7 +191,7 @@ class TestKnowledgeManagerLoadCore:
"""Test that load_kg_core falls back to 'default' when request.collection is None."""
# Create request with None collection
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id"
mock_request.collection = None # Should fall back to "default"
mock_request.flow = "test-flow"
@ -269,7 +267,7 @@ class TestKnowledgeManagerLoadCore:
"""Test that load_kg_core validates flow configuration before processing."""
# Request with invalid flow
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id"
mock_request.collection = "test-collection"
mock_request.flow = "invalid-flow" # Not in mock_flow_config.flows
@ -297,7 +295,7 @@ class TestKnowledgeManagerLoadCore:
# Test missing ID
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_request.id = None # Missing
mock_request.collection = "test-collection"
mock_request.flow = "test-flow"
@ -323,7 +321,7 @@ class TestKnowledgeManagerOtherMethods:
async def test_get_kg_core_preserves_collection_from_store(self, knowledge_manager, sample_triples):
"""Test that get_kg_core preserves collection field from stored data."""
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id"
mock_respond = AsyncMock()
@ -354,7 +352,7 @@ class TestKnowledgeManagerOtherMethods:
async def test_list_kg_cores(self, knowledge_manager):
"""Test listing knowledge cores."""
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_respond = AsyncMock()
@ -376,7 +374,7 @@ class TestKnowledgeManagerOtherMethods:
async def test_delete_kg_core(self, knowledge_manager):
"""Test deleting knowledge cores."""
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id"
mock_respond = AsyncMock()
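The collection-override behaviour these tests pin down can be sketched as a small helper: stored metadata may carry a stale collection, and the load path rewrites it from the request, falling back to `'default'` when the request omits one. The function name and dict shapes are assumptions made for illustration, based on the assertions above.

```python
# Sketch: request scope wins over whatever collection the stored
# core's metadata carries; a missing request collection falls back
# to 'default'.
def apply_request_scope(metadata, request):
    metadata["collection"] = request.get("collection") or "default"
    metadata["id"] = request["id"]
    return metadata
```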

View file

@ -237,7 +237,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
# Mock message with inline data
content = b"# Document Title\nBody text content."
mock_metadata = Metadata(id="test-doc", user="testuser",
mock_metadata = Metadata(id="test-doc",
collection="default")
mock_document = Document(
metadata=mock_metadata,
@ -294,7 +294,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
# Mock message
content = b"fake pdf"
mock_metadata = Metadata(id="test-doc", user="testuser",
mock_metadata = Metadata(id="test-doc",
collection="default")
mock_document = Document(
metadata=mock_metadata,
@ -345,7 +345,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
]
content = b"fake pdf"
mock_metadata = Metadata(id="test-doc", user="testuser",
mock_metadata = Metadata(id="test-doc",
collection="default")
mock_document = Document(
metadata=mock_metadata,

View file

@ -12,7 +12,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_basic(self):
"""Test basic collection name creation"""
result = make_safe_collection_name(
user="test_user",
workspace="test_user",
collection="test_collection",
prefix="doc"
)
@ -21,7 +21,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_with_special_characters(self):
"""Test collection name creation with special characters that need sanitization"""
result = make_safe_collection_name(
user="user@domain.com",
workspace="user@domain.com",
collection="test-collection.v2",
prefix="entity"
)
@ -30,7 +30,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_with_unicode(self):
"""Test collection name creation with Unicode characters"""
result = make_safe_collection_name(
user="测试用户",
workspace="测试用户",
collection="colección_española",
prefix="doc"
)
@ -39,7 +39,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_with_spaces(self):
"""Test collection name creation with spaces"""
result = make_safe_collection_name(
user="test user",
workspace="test user",
collection="my test collection",
prefix="entity"
)
@ -48,7 +48,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_with_multiple_consecutive_special_chars(self):
"""Test collection name creation with multiple consecutive special characters"""
result = make_safe_collection_name(
user="user@@@domain!!!",
workspace="user@@@domain!!!",
collection="test---collection...v2",
prefix="doc"
)
@ -57,7 +57,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_with_leading_trailing_underscores(self):
"""Test collection name creation with leading/trailing special characters"""
result = make_safe_collection_name(
user="__test_user__",
workspace="__test_user__",
collection="@@test_collection##",
prefix="entity"
)
@ -66,7 +66,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_empty_user(self):
"""Test collection name creation with empty user (should fallback to 'default')"""
result = make_safe_collection_name(
user="",
workspace="",
collection="test_collection",
prefix="doc"
)
@ -75,7 +75,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_empty_collection(self):
"""Test collection name creation with empty collection (should fallback to 'default')"""
result = make_safe_collection_name(
user="test_user",
workspace="test_user",
collection="",
prefix="doc"
)
@ -84,7 +84,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_both_empty(self):
"""Test collection name creation with both user and collection empty"""
result = make_safe_collection_name(
user="",
workspace="",
collection="",
prefix="doc"
)
@ -93,7 +93,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_only_special_characters(self):
"""Test collection name creation with only special characters (should fallback to 'default')"""
result = make_safe_collection_name(
user="@@@!!!",
workspace="@@@!!!",
collection="---###",
prefix="entity"
)
@ -102,7 +102,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_whitespace_only(self):
"""Test collection name creation with whitespace-only strings"""
result = make_safe_collection_name(
user=" \n\t ",
workspace=" \n\t ",
collection=" \r\n ",
prefix="doc"
)
@ -111,7 +111,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_mixed_valid_invalid_chars(self):
"""Test collection name creation with mixed valid and invalid characters"""
result = make_safe_collection_name(
user="user123@test",
workspace="user123@test",
collection="coll_2023.v1",
prefix="entity"
)
@ -147,7 +147,7 @@ class TestMilvusCollectionNaming:
long_collection = "b" * 100
result = make_safe_collection_name(
user=long_user,
workspace=long_user,
collection=long_collection,
prefix="doc"
)
@ -159,7 +159,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_numeric_values(self):
"""Test collection name creation with numeric user/collection values"""
result = make_safe_collection_name(
user="user123",
workspace="user123",
collection="collection456",
prefix="doc"
)
@ -168,7 +168,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_case_sensitivity(self):
"""Test that collection name creation preserves case"""
result = make_safe_collection_name(
user="TestUser",
workspace="TestUser",
collection="TestCollection",
prefix="Doc"
)
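The sanitisation rules this suite describes can be reimplemented plausibly as below. This is not the actual Milvus helper: the join format (`prefix_workspace_collection`) and the ASCII-only character class are assumptions, so the Unicode cases above may behave differently in the real code.

```python
import re

# Sketch: collapse runs of non-alphanumeric characters to single
# underscores, strip leading/trailing underscores, fall back to
# 'default' for empty results, then join the parts.
def make_safe_collection_name(workspace, collection, prefix):
    def sanitize(s):
        s = re.sub(r"[^0-9A-Za-z]+", "_", s).strip("_")
        return s or "default"
    return f"{sanitize(prefix)}_{sanitize(workspace)}_{sanitize(collection)}"
```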

View file

@ -20,9 +20,8 @@ def processor():
)
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1",
user="test", collection="default"):
metadata = Metadata(id=doc_id, user=user, collection=collection)
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", collection="default"):
metadata = Metadata(id=doc_id, collection=collection)
value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id)
msg = MagicMock()
msg.value.return_value = value
@ -127,7 +126,7 @@ class TestDocumentEmbeddingsProcessor:
@pytest.mark.asyncio
async def test_metadata_preserved(self, processor):
"""Output should carry the original metadata."""
msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1")
msg = _make_chunk_message(collection="reports", doc_id="d1")
mock_request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[[0.0]]
@ -144,7 +143,6 @@ class TestDocumentEmbeddingsProcessor:
await processor.on_message(msg, MagicMock(), flow)
result = mock_output.send.call_args[0][0]
assert result.metadata.user == "alice"
assert result.metadata.collection == "reports"
assert result.metadata.id == "d1"

View file

@ -27,8 +27,8 @@ def _make_entity_context(name, context, chunk_id="chunk-1"):
return MagicMock(entity=entity, context=context, chunk_id=chunk_id)
def _make_message(entities, doc_id="doc-1", user="test", collection="default"):
metadata = Metadata(id=doc_id, user=user, collection=collection)
def _make_message(entities, doc_id="doc-1", collection="default"):
metadata = Metadata(id=doc_id, collection=collection)
value = EntityContexts(metadata=metadata, entities=entities)
msg = MagicMock()
msg.value.return_value = value
@ -151,7 +151,7 @@ class TestGraphEmbeddingsBatchProcessing:
_make_entity_context(f"E{i}", f"ctx {i}")
for i in range(5)
]
msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main")
msg = _make_message(entities, doc_id="doc-42", collection="main")
mock_embed = AsyncMock(return_value=[[0.0]] * 5)
mock_output = AsyncMock()
@ -168,7 +168,6 @@ class TestGraphEmbeddingsBatchProcessing:
for call in mock_output.send.call_args_list:
result = call[0][0]
assert result.metadata.id == "doc-42"
assert result.metadata.user == "alice"
assert result.metadata.collection == "main"
@pytest.mark.asyncio

View file

@ -214,11 +214,11 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
}
}
await processor.on_schema_config(config_data, 1)
await processor.on_schema_config("default", config_data, 1)
assert 'customers' in processor.schemas
assert processor.schemas['customers'].name == 'customers'
assert len(processor.schemas['customers'].fields) == 3
assert 'customers' in processor.schemas["default"]
assert processor.schemas["default"]['customers'].name == 'customers'
assert len(processor.schemas["default"]['customers'].fields) == 3
async def test_on_schema_config_handles_missing_type(self):
"""Test that missing schema type is handled gracefully"""
@ -236,9 +236,9 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
'other_type': {}
}
await processor.on_schema_config(config_data, 1)
await processor.on_schema_config("default", config_data, 1)
assert processor.schemas == {}
assert processor.schemas.get("default", {}) == {}
async def test_on_message_drops_unknown_collection(self):
"""Test that messages for unknown collections are dropped"""
@ -285,7 +285,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
processor.known_collections[('default', 'test_collection')] = {}
# No schemas registered
metadata = MagicMock()
@ -322,17 +322,19 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
processor.known_collections[('default', 'test_collection')] = {}
# Set up schema
processor.schemas['customers'] = RowSchema(
name='customers',
description='Customer records',
fields=[
Field(name='id', type='text', primary=True),
Field(name='name', type='text', indexed=True),
]
)
processor.schemas["default"] = {
'customers': RowSchema(
name='customers',
description='Customer records',
fields=[
Field(name='id', type='text', primary=True),
Field(name='name', type='text', indexed=True),
]
)
}
metadata = MagicMock()
metadata.user = 'test_user'
@ -372,6 +374,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
return MagicMock()
mock_flow = MagicMock(side_effect=flow_factory)
mock_flow.workspace = "default"
await processor.on_message(mock_msg, MagicMock(), mock_flow)
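The two-level schema registry the updated assertions assume can be sketched as a nested lookup: schemas are keyed by workspace first, then schema name, and message handling resolves the flow's workspace before finding a schema. A minimal sketch under those assumptions:

```python
# Sketch: per-workspace schema lookup; unknown workspaces or schema
# names resolve to None rather than raising.
def lookup_schema(schemas, workspace, name):
    return schemas.get(workspace, {}).get(name)
```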

View file

@ -34,11 +34,10 @@ def _make_defn(entity, definition):
return {"entity": entity, "definition": definition}
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
user="user-1", collection="col-1", document_id=""):
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
chunk = Chunk(
metadata=Metadata(
id=meta_id, root=root, user=user, collection=collection,
id=meta_id, root=root, collection=collection,
),
chunk=text.encode("utf-8"),
document_id=document_id,
@ -229,8 +228,7 @@ class TestMetadataPreservation:
defs = [_make_defn("X", "def X")]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg(
"text", meta_id="c-1", root="r-1",
user="u-1", collection="coll-1",
"text", meta_id="c-1", root="r-1", collection="coll-1",
)
await proc.on_message(msg, MagicMock(), flow)
@ -238,7 +236,6 @@ class TestMetadataPreservation:
for triples_msg in _sent_triples(triples_pub):
assert triples_msg.metadata.id == "c-1"
assert triples_msg.metadata.root == "r-1"
assert triples_msg.metadata.user == "u-1"
assert triples_msg.metadata.collection == "coll-1"
@pytest.mark.asyncio
@ -247,8 +244,7 @@ class TestMetadataPreservation:
defs = [_make_defn("X", "def X")]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg(
"text", meta_id="c-2", root="r-2",
user="u-2", collection="coll-2",
"text", meta_id="c-2", root="r-2", collection="coll-2",
)
await proc.on_message(msg, MagicMock(), flow)

View file

@ -38,12 +38,11 @@ def _make_rel(subject, predicate, obj, object_entity=True):
}
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
user="user-1", collection="col-1", document_id=""):
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
"""Build a mock message wrapping a Chunk."""
chunk = Chunk(
metadata=Metadata(
id=meta_id, root=root, user=user, collection=collection,
id=meta_id, root=root, collection=collection,
),
chunk=text.encode("utf-8"),
document_id=document_id,
@ -189,8 +188,7 @@ class TestMetadataPreservation:
rels = [_make_rel("X", "rel", "Y")]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg(
"text", meta_id="c-1", root="r-1",
user="u-1", collection="coll-1",
"text", meta_id="c-1", root="r-1", collection="coll-1",
)
await proc.on_message(msg, MagicMock(), flow)
@ -198,7 +196,6 @@ class TestMetadataPreservation:
for triples_msg in _sent_triples(pub):
assert triples_msg.metadata.id == "c-1"
assert triples_msg.metadata.root == "r-1"
assert triples_msg.metadata.user == "u-1"
assert triples_msg.metadata.collection == "coll-1"

View file

@ -17,6 +17,12 @@ _real_config_loader = ConfigReceiver.config_loader
ConfigReceiver.config_loader = Mock()
def _notify(version, changes):
msg = Mock()
msg.value.return_value = Mock(version=version, changes=changes)
return msg
class TestConfigReceiver:
"""Test cases for ConfigReceiver class"""
@ -47,98 +53,70 @@ class TestConfigReceiver:
assert handler2 in config_receiver.flow_handlers
@pytest.mark.asyncio
async def test_on_config_notify_new_version(self):
"""Test on_config_notify triggers fetch for newer version"""
async def test_on_config_notify_new_version_fetches_per_workspace(self):
"""Notify with newer version fetches each affected workspace."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 1
# Mock fetch_and_apply
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with newer version
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=2, types=["flow"])
async def mock_fetch(workspace, retry=False):
fetch_calls.append(workspace)
await config_receiver.on_config_notify(mock_msg, None, None)
config_receiver.fetch_and_apply_workspace = mock_fetch
assert len(fetch_calls) == 1
msg = _notify(2, {"flow": ["ws1", "ws2"]})
await config_receiver.on_config_notify(msg, None, None)
assert set(fetch_calls) == {"ws1", "ws2"}
assert config_receiver.config_version == 2
@pytest.mark.asyncio
async def test_on_config_notify_old_version_ignored(self):
"""Test on_config_notify ignores older versions"""
"""Older-version notifies are ignored."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 5
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with older version
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=3, types=["flow"])
async def mock_fetch(workspace, retry=False):
fetch_calls.append(workspace)
await config_receiver.on_config_notify(mock_msg, None, None)
config_receiver.fetch_and_apply_workspace = mock_fetch
assert len(fetch_calls) == 0
msg = _notify(3, {"flow": ["ws1"]})
await config_receiver.on_config_notify(msg, None, None)
assert fetch_calls == []
@pytest.mark.asyncio
async def test_on_config_notify_irrelevant_types_ignored(self):
"""Test on_config_notify ignores types the gateway doesn't care about"""
"""Notifies without flow changes advance version but skip fetch."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 1
fetch_calls = []
async def mock_fetch(workspace, retry=False):
fetch_calls.append(workspace)
config_receiver.fetch_and_apply_workspace = mock_fetch
msg = _notify(2, {"prompt": ["ws1"]})
await config_receiver.on_config_notify(msg, None, None)
# Version should be updated but no fetch
assert fetch_calls == []
assert config_receiver.config_version == 2
@pytest.mark.asyncio
async def test_on_config_notify_flow_type_triggers_fetch(self):
"""Test on_config_notify fetches for flow-related types"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 1
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
for type_name in ["flow"]:
fetch_calls.clear()
config_receiver.config_version = 1
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=2, types=[type_name])
await config_receiver.on_config_notify(mock_msg, None, None)
assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}"
@pytest.mark.asyncio
async def test_on_config_notify_exception_handling(self):
"""Test on_config_notify handles exceptions gracefully"""
"""on_config_notify swallows exceptions from message decode."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Create notify message that causes an exception
mock_msg = Mock()
mock_msg.value.side_effect = Exception("Test exception")
@ -146,19 +124,18 @@ class TestConfigReceiver:
await config_receiver.on_config_notify(mock_msg, None, None)
@pytest.mark.asyncio
async def test_fetch_and_apply_workspace_starts_new_flows(self):
"""fetch_and_apply_workspace starts newly-configured flows."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Mock _create_config_client to return a mock client
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow1": '{"name": "test_flow_1"}',
"flow2": '{"name": "test_flow_2"}'
"flow2": '{"name": "test_flow_2"}',
}
}
@ -167,36 +144,39 @@ class TestConfigReceiver:
config_receiver._create_config_client = Mock(return_value=mock_client)
start_flow_calls = []
async def mock_start_flow(workspace, id, flow):
start_flow_calls.append((workspace, id, flow))
config_receiver.start_flow = mock_start_flow
await config_receiver.fetch_and_apply_workspace("default")
assert config_receiver.config_version == 5
assert "flow1" in config_receiver.flows
assert "flow2" in config_receiver.flows
assert "flow1" in config_receiver.flows["default"]
assert "flow2" in config_receiver.flows["default"]
assert len(start_flow_calls) == 2
assert all(c[0] == "default" for c in start_flow_calls)
@pytest.mark.asyncio
async def test_fetch_and_apply_workspace_stops_removed_flows(self):
"""fetch_and_apply_workspace stops flows no longer configured."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Pre-populate with existing flows
config_receiver.flows = {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"}
"default": {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"},
}
}
# Config now only has flow1
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow1": '{"name": "test_flow_1"}'
"flow1": '{"name": "test_flow_1"}',
}
}
@ -205,20 +185,22 @@ class TestConfigReceiver:
config_receiver._create_config_client = Mock(return_value=mock_client)
stop_flow_calls = []
async def mock_stop_flow(workspace, id, flow):
stop_flow_calls.append((workspace, id, flow))
config_receiver.stop_flow = mock_stop_flow
await config_receiver.fetch_and_apply_workspace("default")
assert "flow1" in config_receiver.flows
assert "flow2" not in config_receiver.flows
assert "flow1" in config_receiver.flows["default"]
assert "flow2" not in config_receiver.flows["default"]
assert len(stop_flow_calls) == 1
assert stop_flow_calls[0][:2] == ("default", "flow2")
@pytest.mark.asyncio
async def test_fetch_and_apply_workspace_with_no_flows(self):
"""Empty workspace config clears any local flow state."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
@ -231,88 +213,100 @@ class TestConfigReceiver:
mock_client.request.return_value = mock_resp
config_receiver._create_config_client = Mock(return_value=mock_client)
await config_receiver.fetch_and_apply_workspace("default")
assert config_receiver.flows.get("default", {}) == {}
assert config_receiver.config_version == 1
@pytest.mark.asyncio
async def test_start_flow_with_handlers(self):
"""Test start_flow method with multiple handlers"""
"""start_flow fans out to every registered flow handler."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler1 = Mock()
handler1.start_flow = AsyncMock()
handler2 = Mock()
handler2.start_flow = AsyncMock()
config_receiver.add_handler(handler1)
config_receiver.add_handler(handler2)
flow_data = {"name": "test_flow", "steps": []}
await config_receiver.start_flow("flow1", flow_data)
await config_receiver.start_flow("default", "flow1", flow_data)
handler1.start_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
handler2.start_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@pytest.mark.asyncio
async def test_start_flow_with_handler_exception(self):
"""Test start_flow method handles handler exceptions"""
"""Handler exceptions in start_flow do not propagate."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler = Mock()
handler.start_flow = AsyncMock(side_effect=Exception("Handler error"))
config_receiver.add_handler(handler)
flow_data = {"name": "test_flow", "steps": []}
# Should not raise
await config_receiver.start_flow("flow1", flow_data)
await config_receiver.start_flow("default", "flow1", flow_data)
handler.start_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@pytest.mark.asyncio
async def test_stop_flow_with_handlers(self):
"""Test stop_flow method with multiple handlers"""
"""stop_flow fans out to every registered flow handler."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler1 = Mock()
handler1.stop_flow = AsyncMock()
handler2 = Mock()
handler2.stop_flow = AsyncMock()
config_receiver.add_handler(handler1)
config_receiver.add_handler(handler2)
flow_data = {"name": "test_flow", "steps": []}
await config_receiver.stop_flow("flow1", flow_data)
await config_receiver.stop_flow("default", "flow1", flow_data)
handler1.stop_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
handler2.stop_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@pytest.mark.asyncio
async def test_stop_flow_with_handler_exception(self):
"""Test stop_flow method handles handler exceptions"""
"""Handler exceptions in stop_flow do not propagate."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler = Mock()
handler.stop_flow = AsyncMock(side_effect=Exception("Handler error"))
config_receiver.add_handler(handler)
flow_data = {"name": "test_flow", "steps": []}
# Should not raise
await config_receiver.stop_flow("flow1", flow_data)
await config_receiver.stop_flow("default", "flow1", flow_data)
handler.stop_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@patch('asyncio.create_task')
@pytest.mark.asyncio
@ -329,25 +323,25 @@ class TestConfigReceiver:
mock_create_task.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_and_apply_workspace_mixed_flow_operations(self):
"""fetch_and_apply_workspace adds, keeps and removes flows in one pass."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Pre-populate
config_receiver.flows = {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"}
"default": {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"},
}
}
# Config removes flow1, keeps flow2, adds flow3
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow2": '{"name": "test_flow_2"}',
"flow3": '{"name": "test_flow_3"}'
"flow3": '{"name": "test_flow_3"}',
}
}
@ -358,20 +352,22 @@ class TestConfigReceiver:
start_calls = []
stop_calls = []
async def mock_start_flow(workspace, id, flow):
start_calls.append((workspace, id, flow))
async def mock_stop_flow(workspace, id, flow):
stop_calls.append((workspace, id, flow))
config_receiver.start_flow = mock_start_flow
config_receiver.stop_flow = mock_stop_flow
await config_receiver.fetch_and_apply_workspace("default")
assert "flow1" not in config_receiver.flows
assert "flow2" in config_receiver.flows
assert "flow3" in config_receiver.flows
ws_flows = config_receiver.flows["default"]
assert "flow1" not in ws_flows
assert "flow2" in ws_flows
assert "flow3" in ws_flows
assert len(start_calls) == 1
assert start_calls[0][:2] == ("default", "flow3")
assert len(stop_calls) == 1
assert stop_calls[0][:2] == ("default", "flow1")

View file

@ -36,7 +36,6 @@ def _ge_response_dict():
"metadata": {
"id": "doc-1",
"root": "",
"user": "alice",
"collection": "testcoll",
},
"entities": [
@ -59,7 +58,6 @@ def _triples_response_dict():
"metadata": {
"id": "doc-1",
"root": "",
"user": "alice",
"collection": "testcoll",
},
"triples": [
@ -73,9 +71,9 @@ def _triples_response_dict():
}
def _make_request(id_="doc-1", user="alice"):
def _make_request(id_="doc-1", workspace="alice"):
request = Mock()
request.query = {"id": id_, "user": user}
request.query = {"id": id_, "workspace": workspace}
return request
@ -149,12 +147,8 @@ class TestCoreExportWireFormat:
msg_type, payload = items[0]
assert msg_type == "ge"
# Metadata envelope: only id/collection — no stale `m["m"]`.
assert payload["m"] == {"i": "doc-1", "c": "testcoll"}
# Entities: each carries the *singular* `v` and the term envelope
assert len(payload["e"]) == 2
@ -202,11 +196,7 @@ class TestCoreExportWireFormat:
msg_type, payload = items[0]
assert msg_type == "t"
assert payload["m"] == {
"i": "doc-1",
"u": "alice",
"c": "testcoll",
}
assert payload["m"] == {"i": "doc-1", "c": "testcoll"}
assert len(payload["t"]) == 1
@ -240,7 +230,7 @@ class TestCoreImportWireFormat:
payload = msgpack.packb((
"ge",
{
"m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
"m": {"i": "doc-1", "c": "testcoll"},
"e": [
{
"e": {"t": "i", "i": "http://example.org/alice"},
@ -266,7 +256,7 @@ class TestCoreImportWireFormat:
req = captured[0]
assert req["operation"] == "put-kg-core"
assert req["user"] == "alice"
assert req["workspace"] == "alice"
assert req["id"] == "doc-1"
ge = req["graph-embeddings"]
@ -275,7 +265,6 @@ class TestCoreImportWireFormat:
assert "metadata" not in ge["metadata"]
assert ge["metadata"] == {
"id": "doc-1",
"user": "alice",
"collection": "default",
}
@ -302,7 +291,7 @@ class TestCoreImportWireFormat:
payload = msgpack.packb((
"t",
{
"m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
"m": {"i": "doc-1", "c": "testcoll"},
"t": [
{
"s": {"t": "i", "i": "http://example.org/alice"},
@ -407,11 +396,10 @@ class TestCoreImportExportRoundTrip:
original = _ge_response_dict()["graph-embeddings"]
ge = req["graph-embeddings"]
# The import side overrides id from the URL query (intentional),
# so we only round-trip the entity payload itself.
assert ge["metadata"]["id"] == original["metadata"]["id"]
assert ge["metadata"]["user"] == original["metadata"]["user"]
assert len(ge["entities"]) == len(original["entities"])
for got, want in zip(ge["entities"], original["entities"]):
assert got["vector"] == want["vector"]

View file

@ -72,10 +72,10 @@ class TestDispatcherManager:
flow_data = {"name": "test_flow", "steps": []}
await manager.start_flow("flow1", flow_data)
assert "flow1" in manager.flows
assert manager.flows["flow1"] == flow_data
await manager.start_flow("default", "flow1", flow_data)
assert ("default", "flow1") in manager.flows
assert manager.flows[("default", "flow1")] == flow_data
@pytest.mark.asyncio
async def test_stop_flow(self):
@ -86,11 +86,11 @@ class TestDispatcherManager:
# Pre-populate with a flow
flow_data = {"name": "test_flow", "steps": []}
manager.flows["flow1"] = flow_data
await manager.stop_flow("flow1", flow_data)
assert "flow1" not in manager.flows
manager.flows[("default", "flow1")] = flow_data
await manager.stop_flow("default", "flow1", flow_data)
assert ("default", "flow1") not in manager.flows
def test_dispatch_global_service_returns_wrapper(self):
"""Test dispatch_global_service returns DispatcherWrapper"""
@ -275,12 +275,12 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"triples-store": {"flow": "test_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers, \
patch('uuid.uuid4') as mock_uuid:
mock_uuid.return_value = "test-uuid"
@ -290,7 +290,7 @@ class TestDispatcherManager:
mock_dispatcher_class.return_value = mock_dispatcher
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
mock_dispatchers.__contains__.return_value = True
params = {"flow": "test_flow", "kind": "triples"}
result = await manager.process_flow_import("ws", "running", params)
@ -326,12 +326,12 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"triples-store": {"flow": "test_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers:
mock_dispatchers.__contains__.return_value = False
@ -348,12 +348,12 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"triples-store": {"flow": "test_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.export_dispatchers') as mock_dispatchers, \
patch('uuid.uuid4') as mock_uuid:
mock_uuid.return_value = "test-uuid"
@ -404,7 +404,7 @@ class TestDispatcherManager:
params = {"flow": "test_flow", "kind": "agent"}
result = await manager.process_flow_service("data", "responder", params)
manager.invoke_flow_service.assert_called_once_with("data", "responder", "default", "test_flow", "agent")
assert result == "flow_result"
@pytest.mark.asyncio
@ -415,14 +415,14 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Add flow to the flows dictionary
manager.flows["test_flow"] = {"services": {"agent": {}}}
manager.flows[("default", "test_flow")] = {"services": {"agent": {}}}
# Pre-populate with existing dispatcher
mock_dispatcher = Mock()
mock_dispatcher.process = AsyncMock(return_value="cached_result")
manager.dispatchers[("test_flow", "agent")] = mock_dispatcher
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
manager.dispatchers[("default", "test_flow", "agent")] = mock_dispatcher
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
mock_dispatcher.process.assert_called_once_with("data", "responder")
assert result == "cached_result"
@ -435,7 +435,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"agent": {
"request": "agent_request_queue",
@ -443,7 +443,7 @@ class TestDispatcherManager:
}
}
}
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_dispatchers:
mock_dispatcher_class = Mock()
mock_dispatcher = Mock()
@ -452,23 +452,23 @@ class TestDispatcherManager:
mock_dispatcher_class.return_value = mock_dispatcher
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
mock_dispatchers.__contains__.return_value = True
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
# Verify dispatcher was created with correct parameters
mock_dispatcher_class.assert_called_once_with(
backend=mock_backend,
request_queue="agent_request_queue",
response_queue="agent_response_queue",
timeout=120,
consumer="api-gateway-test_flow-agent-request",
subscriber="api-gateway-test_flow-agent-request"
consumer="api-gateway-default-test_flow-agent-request",
subscriber="api-gateway-default-test_flow-agent-request"
)
mock_dispatcher.start.assert_called_once()
mock_dispatcher.process.assert_called_once_with("data", "responder")
# Verify dispatcher was cached
assert manager.dispatchers[("test_flow", "agent")] == mock_dispatcher
assert manager.dispatchers[("default", "test_flow", "agent")] == mock_dispatcher
assert result == "new_result"
@pytest.mark.asyncio
@ -479,26 +479,26 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"text-load": {"flow": "text_load_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \
patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers:
mock_rr_dispatchers.__contains__.return_value = False
mock_sender_dispatchers.__contains__.return_value = True
mock_dispatcher_class = Mock()
mock_dispatcher = Mock()
mock_dispatcher.start = AsyncMock()
mock_dispatcher.process = AsyncMock(return_value="sender_result")
mock_dispatcher_class.return_value = mock_dispatcher
mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "text-load")
# Verify dispatcher was created with correct parameters
mock_dispatcher_class.assert_called_once_with(
backend=mock_backend,
@ -506,9 +506,9 @@ class TestDispatcherManager:
)
mock_dispatcher.start.assert_called_once()
mock_dispatcher.process.assert_called_once_with("data", "responder")
# Verify dispatcher was cached
assert manager.dispatchers[("test_flow", "text-load")] == mock_dispatcher
assert manager.dispatchers[("default", "test_flow", "text-load")] == mock_dispatcher
assert result == "sender_result"
@pytest.mark.asyncio
@ -519,7 +519,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
with pytest.raises(RuntimeError, match="Invalid flow"):
await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent")
await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent")
@pytest.mark.asyncio
async def test_invoke_flow_service_unsupported_kind_by_flow(self):
@ -529,14 +529,14 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow without agent interface
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"text-completion": {"request": "req", "response": "resp"}
}
}
with pytest.raises(RuntimeError, match="This kind not supported by flow"):
await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
@pytest.mark.asyncio
async def test_invoke_flow_service_invalid_kind(self):
@ -546,7 +546,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow with interface but unsupported kind
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"invalid-kind": {"request": "req", "response": "resp"}
}
@ -558,7 +558,7 @@ class TestDispatcherManager:
mock_sender_dispatchers.__contains__.return_value = False
with pytest.raises(RuntimeError, match="Invalid kind"):
await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind")
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "invalid-kind")
@pytest.mark.asyncio
async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self):
@ -608,7 +608,7 @@ class TestDispatcherManager:
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"agent": {
"request": "agent_request_queue",
@ -630,7 +630,7 @@ class TestDispatcherManager:
mock_rr_dispatchers.__contains__.return_value = True
results = await asyncio.gather(*[
manager.invoke_flow_service("data", "responder", "test_flow", "agent")
manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
for _ in range(5)
])
@ -638,5 +638,5 @@ class TestDispatcherManager:
"Dispatcher class instantiated more than once — duplicate consumer bug"
)
assert mock_dispatcher.start.call_count == 1
assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher
assert manager.dispatchers[("default", "test_flow", "agent")] is mock_dispatcher
assert all(r == "result" for r in results)
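The concurrency test above guards against the duplicate-consumer bug called out in its failure message. A simplified sketch of the lock-protected, `(workspace, flow, kind)`-keyed cache pattern it exercises (the real DispatcherManager takes queues, timeouts and more; this only shows the check-and-create discipline):

```python
import asyncio

class DispatcherCache:
    """Cache dispatchers per (workspace, flow, kind) key."""

    def __init__(self, factory):
        self.factory = factory
        self.dispatchers = {}
        self.lock = asyncio.Lock()

    async def get(self, workspace, flow, kind):
        key = (workspace, flow, kind)
        # One lock around check-and-create, so concurrent callers can
        # never instantiate two dispatchers (and thus two consumers)
        # for the same key.
        async with self.lock:
            if key not in self.dispatchers:
                self.dispatchers[key] = self.factory(key)
            return self.dispatchers[key]
```

Five concurrent `get()` calls for the same key should yield one factory invocation and the same cached object, which is exactly what the gather-based test asserts of the real manager.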

View file

@ -186,7 +186,6 @@ class TestEntityContextsImportMessageProcessing:
assert isinstance(sent, EntityContexts)
assert isinstance(sent.metadata, Metadata)
assert sent.metadata.id == "doc-123"
assert sent.metadata.collection == "testcollection"
assert len(sent.entities) == 2

View file

@ -188,7 +188,6 @@ class TestGraphEmbeddingsImportMessageProcessing:
assert isinstance(sent, GraphEmbeddings)
assert isinstance(sent.metadata, Metadata)
assert sent.metadata.id == "doc-123"
assert sent.metadata.collection == "testcollection"
assert len(sent.entities) == 2

View file

@ -235,7 +235,6 @@ class TestRowsImportMessageProcessing:
# Check metadata
assert sent_object.metadata.id == "obj-123"
assert sent_object.metadata.collection == "testcollection"
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')

View file

@ -23,7 +23,6 @@ class TestTextDocumentTranslator:
)
assert msg.metadata.id == "doc-1"
assert msg.metadata.collection == "research"
assert msg.text == payload.encode("utf-8")

View file

@ -29,10 +29,9 @@ class Triple:
self.o = o
class Metadata:
def __init__(self, id, collection, root=""):
self.id = id
self.root = root
self.collection = collection
class Triples:
@ -108,7 +107,6 @@ def sample_triples(sample_triple):
"""Sample Triples batch object"""
metadata = Metadata(
id="test-doc-123",
user="test_user",
collection="test_collection",
)
@ -123,7 +121,6 @@ def sample_chunk():
"""Sample text chunk for processing"""
metadata = Metadata(
id="test-chunk-456",
user="test_user",
collection="test_collection",
)

View file

@ -322,7 +322,6 @@ This is not JSON at all
assert isinstance(sent_triples, Triples)
# Check metadata fields individually since implementation creates new Metadata object
assert sent_triples.metadata.id == sample_metadata.id
assert sent_triples.metadata.collection == sample_metadata.collection
assert len(sent_triples.triples) == 1
assert sent_triples.triples[0].s.iri == "test:subject"
@ -346,7 +345,6 @@ This is not JSON at all
assert isinstance(sent_contexts, EntityContexts)
# Check metadata fields individually since implementation creates new Metadata object
assert sent_contexts.metadata.id == sample_metadata.id
assert sent_contexts.metadata.collection == sample_metadata.collection
assert len(sent_contexts.entities) == 1
assert sent_contexts.entities[0].entity.iri == "test:entity"

View file

@ -311,8 +311,7 @@ class TestObjectExtractionBusinessLogic:
"""Test ExtractedObject creation and properties"""
# Arrange
metadata = Metadata(
id="test-extraction-001",
user="test_user",
id="test-extraction-001",
collection="test_collection",
)
@ -337,7 +336,6 @@ class TestObjectExtractionBusinessLogic:
assert extracted_obj.values[0]["customer_id"] == "CUST001"
assert extracted_obj.confidence == 0.95
assert "John Doe" in extracted_obj.source_span
def test_config_parsing_error_handling(self):
"""Test configuration parsing with invalid JSON"""

View file

@ -371,7 +371,6 @@ class TestTripleConstructionLogic:
metadata = Metadata(
id="test-doc-123",
user="test_user",
collection="test_collection",
)
@ -384,7 +383,6 @@ class TestTripleConstructionLogic:
# Assert
assert isinstance(triples_batch, Triples)
assert triples_batch.metadata.id == "test-doc-123"
assert triples_batch.metadata.collection == "test_collection"
assert len(triples_batch.triples) == 2

View file

@ -33,12 +33,12 @@ def _make_librarian(min_chunk_size=1):
def _make_doc_metadata(
doc_id="doc-1", kind="application/pdf", user="alice", title="Test Doc"
doc_id="doc-1", kind="application/pdf", workspace="alice", title="Test Doc"
):
meta = MagicMock()
meta.id = doc_id
meta.kind = kind
meta.workspace = workspace
meta.title = title
meta.time = 1700000000
meta.comments = ""
@ -47,27 +47,27 @@ def _make_doc_metadata(
def _make_begin_request(
doc_id="doc-1", kind="application/pdf", user="alice",
doc_id="doc-1", kind="application/pdf", workspace="alice",
total_size=10_000_000, chunk_size=0
):
req = MagicMock()
req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, workspace=workspace)
req.total_size = total_size
req.chunk_size = chunk_size
return req
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, user="alice", content=b"data"):
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, workspace="alice", content=b"data"):
req = MagicMock()
req.upload_id = upload_id
req.chunk_index = chunk_index
req.workspace = workspace
req.content = base64.b64encode(content)
return req
def _make_session(
user="alice", total_chunks=5, chunk_size=2_000_000,
workspace="alice", total_chunks=5, chunk_size=2_000_000,
total_size=10_000_000, chunks_received=None, object_id="obj-1",
s3_upload_id="s3-up-1", document_metadata=None, document_id="doc-1",
):
@ -76,11 +76,11 @@ def _make_session(
if document_metadata is None:
document_metadata = json.dumps({
"id": document_id, "kind": "application/pdf",
"user": user, "title": "Test", "time": 1700000000,
"workspace": workspace, "title": "Test", "time": 1700000000,
"comments": "", "tags": [],
})
return {
"user": user,
"workspace": workspace,
"total_chunks": total_chunks,
"chunk_size": chunk_size,
"total_size": total_size,
@ -259,10 +259,10 @@ class TestUploadChunk:
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
session = _make_session(workspace="alice")
lib.table_store.get_upload_session.return_value = session
req = _make_upload_chunk_request(user="bob")
req = _make_upload_chunk_request(workspace="bob")
with pytest.raises(RequestError, match="Not authorized"):
await lib.upload_chunk(req)
@ -353,7 +353,7 @@ class TestCompleteUpload:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
resp = await lib.complete_upload(req)
@ -375,7 +375,7 @@ class TestCompleteUpload:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
await lib.complete_upload(req)
@ -394,7 +394,7 @@ class TestCompleteUpload:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
with pytest.raises(RequestError, match="Missing chunks"):
await lib.complete_upload(req)
@ -406,7 +406,7 @@ class TestCompleteUpload:
req = MagicMock()
req.upload_id = "up-gone"
req.user = "alice"
req.workspace = "alice"
with pytest.raises(RequestError, match="not found"):
await lib.complete_upload(req)
@ -414,12 +414,12 @@ class TestCompleteUpload:
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
session = _make_session(workspace="alice")
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "bob"
req.workspace = "bob"
with pytest.raises(RequestError, match="Not authorized"):
await lib.complete_upload(req)
@ -439,7 +439,7 @@ class TestAbortUpload:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
resp = await lib.abort_upload(req)
@ -456,7 +456,7 @@ class TestAbortUpload:
req = MagicMock()
req.upload_id = "up-gone"
req.user = "alice"
req.workspace = "alice"
with pytest.raises(RequestError, match="not found"):
await lib.abort_upload(req)
@ -464,12 +464,12 @@ class TestAbortUpload:
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
session = _make_session(workspace="alice")
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "bob"
req.workspace = "bob"
with pytest.raises(RequestError, match="Not authorized"):
await lib.abort_upload(req)
@ -492,7 +492,7 @@ class TestGetUploadStatus:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
resp = await lib.get_upload_status(req)
@ -510,7 +510,7 @@ class TestGetUploadStatus:
req = MagicMock()
req.upload_id = "up-expired"
req.user = "alice"
req.workspace = "alice"
resp = await lib.get_upload_status(req)
@ -527,7 +527,7 @@ class TestGetUploadStatus:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
resp = await lib.get_upload_status(req)
@ -539,12 +539,12 @@ class TestGetUploadStatus:
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
session = _make_session(workspace="alice")
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "bob"
req.workspace = "bob"
with pytest.raises(RequestError, match="Not authorized"):
await lib.get_upload_status(req)
@ -564,7 +564,7 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000)
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 2000
@ -587,7 +587,7 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500)
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 2000
@ -608,7 +608,7 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100)
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 2000
@ -630,7 +630,7 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x")
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 0 # Should use default 1MB
@ -649,7 +649,7 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=raw)
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 1000
@ -666,7 +666,7 @@ class TestStreamDocument:
lib.blob_store.get_size = AsyncMock(return_value=5000)
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 512
@ -698,7 +698,7 @@ class TestListUploads:
]
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
resp = await lib.list_uploads(req)
@ -713,7 +713,7 @@ class TestListUploads:
lib.table_store.list_upload_sessions.return_value = []
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
resp = await lib.list_uploads(req)
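The upload-session tests above all exercise the same workspace-ownership contract: a missing session reports "not found", and a request whose workspace does not match the stored session's workspace is rejected as "Not authorized". A minimal sketch of that check, with hypothetical names (`RequestError` and the session shape are stand-ins, not the actual librarian code):

```python
class RequestError(Exception):
    """Raised for client-visible request failures."""

class UploadSession:
    def __init__(self, upload_id, workspace):
        self.upload_id = upload_id
        self.workspace = workspace

def check_upload_access(session, req_workspace):
    # Session missing (or expired and purged): report not found.
    if session is None:
        raise RequestError("Upload session not found")
    # Workspace mismatch: the requester does not own this upload.
    if session.workspace != req_workspace:
        raise RequestError("Not authorized")
    return session
```

The same check backs `complete_upload`, `abort_upload`, and `get_upload_status`, which is why each test class repeats the wrong-workspace case.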


@ -239,7 +239,7 @@ def _make_processor(tools=None):
agent = MagicMock()
agent.tools = tools or {}
agent.additional_context = ""
processor.agent = agent
processor.agents = {"default": agent}
processor.aggregator = MagicMock()
return processor
@ -254,6 +254,7 @@ def _make_flow():
return producers[name]
flow = MagicMock(side_effect=factory)
flow.workspace = "default"
return flow
@ -299,7 +300,7 @@ class TestAgentReactDagStructure:
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agent = processor.agent
service.agents = processor.agents
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
@ -344,7 +345,6 @@ class TestAgentReactDagStructure:
request1 = AgentRequest(
question="What is 6x7?",
user="testuser",
collection="default",
streaming=False,
session_id=session_id,
@ -433,7 +433,7 @@ class TestAgentPlanDagStructure:
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agent = processor.agent
service.agents = processor.agents
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
@ -480,7 +480,6 @@ class TestAgentPlanDagStructure:
# Iteration 1: planning
request1 = AgentRequest(
question="Test?",
user="testuser",
collection="default",
streaming=False,
session_id=session_id,
@ -537,7 +536,7 @@ class TestAgentSupervisorDagStructure:
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agent = processor.agent
service.agents = processor.agents
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
@ -563,7 +562,6 @@ class TestAgentSupervisorDagStructure:
request = AgentRequest(
question="Research quantum computing",
user="testuser",
collection="default",
streaming=False,
session_id=str(uuid.uuid4()),
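The agent DAG tests above reflect the scoping change: the flow object now carries the workspace (set by the infrastructure, not the client), so `AgentRequest` no longer names a user and isolation comes from `flow.workspace` plus the request's collection. A rough sketch under those assumptions — the dataclasses below are illustrative, not the real trustgraph schema:

```python
from dataclasses import dataclass

@dataclass
class Flow:
    # Set by the infrastructure at flow start, never by the client.
    workspace: str = "default"

@dataclass
class AgentRequest:
    question: str
    collection: str = "default"
    streaming: bool = False
    session_id: str = ""

def routing_scope(flow, request):
    # Isolation key: trusted workspace plus client-chosen collection.
    return (flow.workspace, request.collection)
```

Because the schema has no `user` field at all, a client cannot smuggle a different tenant identity through the message.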


@ -31,7 +31,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
def mock_query_request(self):
"""Create a mock query request for testing"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10
@ -69,7 +68,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_single_vector(self, processor):
"""Test querying document embeddings with a single vector"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -83,7 +81,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called with correct parameters including user/collection
processor.vecstore.search.assert_called_once_with(
@ -101,7 +99,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_longer_vector(self, processor):
"""Test querying document embeddings with a longer vector"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=3
@ -115,7 +112,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once_with(
@ -133,7 +130,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_with_limit(self, processor):
"""Test querying document embeddings respects limit parameter"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=2
@ -148,7 +144,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called with the specified limit
processor.vecstore.search.assert_called_once_with(
@ -162,13 +158,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_empty_vectors(self, processor):
"""Test querying document embeddings with empty vectors list"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[],
limit=5
)
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify no search was called
processor.vecstore.search.assert_not_called()
@ -180,7 +175,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_empty_search_results(self, processor):
"""Test querying document embeddings with empty search results"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -189,7 +183,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock empty search results
processor.vecstore.search.return_value = []
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called
processor.vecstore.search.assert_called_once_with(
@ -203,7 +197,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_unicode_documents(self, processor):
"""Test querying document embeddings with Unicode document content"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -217,7 +210,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify Unicode content is preserved in ChunkMatch objects
assert len(result) == 3
@ -230,7 +223,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_large_documents(self, processor):
"""Test querying document embeddings with large document content"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -244,7 +236,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify large content is preserved in ChunkMatch objects
assert len(result) == 2
@ -256,7 +248,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_special_characters(self, processor):
"""Test querying document embeddings with special characters in documents"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -270,7 +261,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify special characters are preserved in ChunkMatch objects
assert len(result) == 3
@ -283,13 +274,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_zero_limit(self, processor):
"""Test querying document embeddings with zero limit"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=0
)
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify no search was called (optimization for zero limit)
processor.vecstore.search.assert_not_called()
@ -301,13 +291,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_negative_limit(self, processor):
"""Test querying document embeddings with negative limit"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=-1
)
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify no search was called (optimization for negative limit)
processor.vecstore.search.assert_not_called()
@ -319,7 +308,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_exception_handling(self, processor):
"""Test exception handling during query processing"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -330,13 +318,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Should raise the exception
with pytest.raises(Exception, match="Milvus connection failed"):
await processor.query_document_embeddings(query)
await processor.query_document_embeddings('test_user', query)
@pytest.mark.asyncio
async def test_query_document_embeddings_different_vector_dimensions(self, processor):
"""Test querying document embeddings with different vector dimensions"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector
limit=5
@ -349,7 +336,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called with the vector
processor.vecstore.search.assert_called_once()
@ -364,7 +351,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_multiple_results(self, processor):
"""Test querying document embeddings with multiple results"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
@ -378,7 +364,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify results are ChunkMatch objects
assert len(result) == 3


@ -117,7 +117,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify index was accessed correctly (with dimension suffix)
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
@ -160,7 +160,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = [mock_results1, mock_results2]
chunks = await processor.query_document_embeddings(mock_query_message)
chunks = await processor.query_document_embeddings('default', mock_query_message)
# Verify both queries were made
assert mock_index.query.call_count == 2
@ -191,7 +191,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify limit is passed to query
mock_index.query.assert_called_once()
@ -213,7 +213,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -231,7 +231,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -259,7 +259,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify different indexes used for different dimensions
assert processor.pinecone.Index.call_count == 2
@ -287,7 +287,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify no queries were made and empty result returned
processor.pinecone.Index.assert_not_called()
@ -310,7 +310,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_results.matches = []
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify empty results
assert chunks == []
@ -334,7 +334,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify Unicode content is properly handled
assert len(chunks) == 2
@ -361,7 +361,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify large content is properly handled
assert len(chunks) == 1
@ -389,7 +389,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify all content types are properly handled
assert len(chunks) == 5
@ -413,7 +413,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = Exception("Query failed")
with pytest.raises(Exception, match="Query failed"):
await processor.query_document_embeddings(message)
await processor.query_document_embeddings('test_user', message)
@pytest.mark.asyncio
async def test_query_document_embeddings_index_access_failure(self, processor):
@ -427,7 +427,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
processor.pinecone.Index.side_effect = Exception("Index access failed")
with pytest.raises(Exception, match="Index access failed"):
await processor.query_document_embeddings(message)
await processor.query_document_embeddings('test_user', message)
@pytest.mark.asyncio
async def test_query_document_embeddings_vector_accumulation(self, processor):
@ -462,7 +462,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = [mock_results1, mock_results2, mock_results3]
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify all queries were made
assert mock_index.query.call_count == 3
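The expected index names in the Pinecone tests encode a type prefix ("d" for document embeddings, "t" for triple/graph embeddings), the user, the collection, and the vector dimension as a suffix, so vectors of different dimensions land in different indexes. A small helper capturing that convention exactly as the tests state it (the helper itself is illustrative):

```python
def pinecone_index_name(prefix, user, collection, dimension):
    # e.g. ("d", "test_user", "test_collection", 3)
    #   -> "d-test_user-test_collection-3"
    return f"{prefix}-{user}-{collection}-{dimension}"
```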


@ -101,7 +101,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'test_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('test_user', mock_message)
# Assert
# Verify query was called with correct parameters (with dimension suffix)
@ -158,7 +158,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('multi_user', mock_message)
# Assert
# Verify query was called once
@ -212,7 +212,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'limit_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('limit_user', mock_message)
# Assert
# Verify query was called with exact limit (no multiplication)
@ -252,7 +252,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'empty_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('empty_user', mock_message)
# Assert
assert result == []
@ -291,7 +291,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('dim_user', mock_message)
# Assert
# Verify query was called once with correct collection
@ -342,7 +342,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'utf8_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('utf8_user', mock_message)
# Assert
assert len(result) == 2
@ -380,7 +380,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
await processor.query_document_embeddings(mock_message)
await processor.query_document_embeddings('error_user', mock_message)
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -413,7 +413,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'zero_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('zero_user', mock_message)
# Assert
# Should still query (with limit 0)
@ -460,7 +460,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'large_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('large_user', mock_message)
# Assert
# Should query with full limit
@ -512,7 +512,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Act & Assert
# This should raise a KeyError when trying to access payload['chunk_id']
with pytest.raises(KeyError):
await processor.query_document_embeddings(mock_message)
await processor.query_document_embeddings('payload_user', mock_message)
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')


@ -31,7 +31,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
def mock_query_request(self):
"""Create a mock query request for testing"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10
@ -117,7 +116,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_single_vector(self, processor):
"""Test querying graph embeddings with a single vector"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -131,7 +129,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called with correct parameters including user/collection
processor.vecstore.search.assert_called_once_with(
@ -154,7 +152,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_multiple_results(self, processor):
"""Test querying graph embeddings returns multiple results"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
@ -168,7 +165,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once_with(
@ -186,7 +183,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_with_limit(self, processor):
"""Test querying graph embeddings respects limit parameter"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=2
@ -201,7 +197,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called with 2*limit for better deduplication
processor.vecstore.search.assert_called_once_with(
@ -215,7 +211,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_preserves_order(self, processor):
"""Test that query results preserve order from the vector store"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
@ -229,7 +224,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify results are in the same order as returned by the store
assert len(result) == 3
@ -241,7 +236,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_results_limited(self, processor):
"""Test that results are properly limited when store returns more than requested"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=2
@ -255,7 +249,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called with the full vector
processor.vecstore.search.assert_called_once_with(
@ -269,13 +263,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_empty_vectors(self, processor):
"""Test querying graph embeddings with empty vectors list"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[],
limit=5
)
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify no search was called
processor.vecstore.search.assert_not_called()
@ -287,7 +280,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_empty_search_results(self, processor):
"""Test querying graph embeddings with empty search results"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -296,7 +288,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Mock empty search results
processor.vecstore.search.return_value = []
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called
processor.vecstore.search.assert_called_once_with(
@ -310,7 +302,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_mixed_uri_literal_results(self, processor):
"""Test querying graph embeddings with mixed URI and literal results"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -325,7 +316,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify all results are properly typed
assert len(result) == 4
@ -348,7 +339,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_exception_handling(self, processor):
"""Test exception handling during query processing"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -359,7 +349,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Should raise the exception
with pytest.raises(Exception, match="Milvus connection failed"):
await processor.query_graph_embeddings(query)
await processor.query_graph_embeddings('test_user', query)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
@ -430,13 +420,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_zero_limit(self, processor):
"""Test querying graph embeddings with zero limit"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=0
)
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify no search was called (optimization for zero limit)
processor.vecstore.search.assert_not_called()
@ -448,7 +437,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_longer_vector(self, processor):
"""Test querying graph embeddings with a longer vector"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
limit=5
@ -461,7 +449,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once()
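The graph-embeddings tests above differ from the document-embeddings ones in one detail: graph queries search with 2x the requested limit and then deduplicate, since several stored vectors can resolve to the same entity, while document-chunk queries use the exact limit. A sketch of that pattern (function names are illustrative):

```python
def search_limit(kind, limit):
    # Graph entity queries over-fetch for better deduplication;
    # document chunk queries use the limit as given.
    return limit * 2 if kind == "graph" else limit

def dedupe(entities, limit):
    seen, unique = set(), []
    for entity in entities:
        if entity not in seen:
            seen.add(entity)
            unique.append(entity)
        if len(unique) == limit:
            break
    return unique
```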


@ -145,7 +145,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify index was accessed correctly (with dimension suffix)
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
@ -185,7 +185,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(mock_query_message)
entities = await processor.query_graph_embeddings('default', mock_query_message)
# Verify query was made once
assert mock_index.query.call_count == 1
@ -216,7 +216,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify limit is respected
assert len(entities) == 2
@ -233,7 +233,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -251,7 +251,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -276,7 +276,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify correct index used for 2D vector
processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
@ -300,7 +300,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify no queries were made and empty result returned
processor.pinecone.Index.assert_not_called()
@ -323,7 +323,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_results.matches = []
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify empty results
assert entities == []
@ -352,7 +352,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Should get exactly 3 unique entities (respecting limit)
assert len(entities) == 3
@ -380,7 +380,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Should only return 2 entities (respecting limit)
mock_index.query.assert_called_once()
@ -400,7 +400,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.side_effect = Exception("Query failed")
with pytest.raises(Exception, match="Query failed"):
await processor.query_graph_embeddings(message)
await processor.query_graph_embeddings('test_user', message)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""


@ -173,7 +173,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'test_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('test_user', mock_message)
# Assert
# Verify query was called with correct parameters (with dimension suffix)
@ -230,7 +230,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('multi_user', mock_message)
# Assert
# Verify query was called once
@ -283,7 +283,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'limit_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('limit_user', mock_message)
# Assert
# Verify query was called with limit * 2
@ -323,7 +323,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'empty_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('empty_user', mock_message)
# Assert
assert result == []
@ -364,7 +364,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('dim_user', mock_message)
# Assert
# Verify query was called once
@ -415,7 +415,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'uri_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('uri_user', mock_message)
# Assert
assert len(result) == 3
@ -460,7 +460,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
await processor.query_graph_embeddings(mock_message)
await processor.query_graph_embeddings('error_user', mock_message)
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -493,7 +493,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'zero_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('zero_user', mock_message)
# Assert
# Should still query (with limit 0)
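The Qdrant limit tests above assert that the store is queried with `limit * 2` and that results are deduplicated down to the requested limit. A hedged sketch of that over-fetch-then-deduplicate step, with illustrative names:

```python
# Hedged sketch of the over-fetch-then-deduplicate behaviour the limit
# tests assert: query with limit * 2, then drop duplicate entities and
# truncate to the requested limit. Not the actual service code.
def dedup_to_limit(matches: list[str], limit: int) -> list[str]:
    seen: set[str] = set()
    out: list[str] = []
    for m in matches:
        if m in seen:
            continue
        seen.add(m)
        out.append(m)
        if len(out) == limit:
            break
    return out

# Duplicates collapse and the limit is respected:
print(dedup_to_limit(["a", "a", "b", "c", "c", "d"], 3))
```

Over-fetching by a factor of two gives the deduplication step headroom, so a result set padded with duplicate entity hits can still fill the caller's limit.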


@ -9,12 +9,12 @@ from trustgraph.query.triples.memgraph.service import Processor
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
class TestMemgraphQueryUserCollectionIsolation:
class TestMemgraphQueryWorkspaceCollectionIsolation:
"""Test cases for Memgraph query service with user/collection isolation"""
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_spo_query_with_user_collection(self, mock_graph_db):
async def test_spo_query_with_workspace_collection(self, mock_graph_db):
"""Test SPO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -22,7 +22,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
@ -32,13 +31,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SPO query for literal includes workspace/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT 1000"
)
@ -48,14 +47,14 @@ class TestMemgraphQueryUserCollectionIsolation:
src="http://example.com/s",
rel="http://example.com/p",
value="test_object",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_sp_query_with_user_collection(self, mock_graph_db):
async def test_sp_query_with_workspace_collection(self, mock_graph_db):
"""Test SP query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -63,7 +62,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
@ -73,13 +71,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SP query for literals includes workspace/collection
expected_literal_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT 1000"
)
@ -88,14 +86,14 @@ class TestMemgraphQueryUserCollectionIsolation:
expected_literal_query,
src="http://example.com/s",
rel="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_so_query_with_user_collection(self, mock_graph_db):
async def test_so_query_with_workspace_collection(self, mock_graph_db):
"""Test SO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -103,7 +101,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -113,13 +110,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SO query for nodes includes workspace/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT 1000"
)
@ -128,14 +125,14 @@ class TestMemgraphQueryUserCollectionIsolation:
expected_query,
src="http://example.com/s",
uri="http://example.com/o",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_s_only_query_with_user_collection(self, mock_graph_db):
async def test_s_only_query_with_workspace_collection(self, mock_graph_db):
"""Test S-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -143,7 +140,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -153,13 +149,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify S query includes workspace/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT 1000"
)
@ -167,14 +163,14 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
src="http://example.com/s",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_po_query_with_user_collection(self, mock_graph_db):
async def test_po_query_with_workspace_collection(self, mock_graph_db):
"""Test PO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -182,7 +178,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Term(type=IRI, iri="http://example.com/p"),
@ -192,13 +187,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify PO query for literals includes workspace/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT 1000"
)
@ -207,14 +202,14 @@ class TestMemgraphQueryUserCollectionIsolation:
expected_query,
uri="http://example.com/p",
value="literal",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_p_only_query_with_user_collection(self, mock_graph_db):
async def test_p_only_query_with_workspace_collection(self, mock_graph_db):
"""Test P-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -222,7 +217,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Term(type=IRI, iri="http://example.com/p"),
@ -232,13 +226,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify P query includes workspace/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT 1000"
)
@ -246,14 +240,14 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
uri="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_o_only_query_with_user_collection(self, mock_graph_db):
async def test_o_only_query_with_workspace_collection(self, mock_graph_db):
"""Test O-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -261,7 +255,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
@ -271,13 +264,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify O query for literals includes workspace/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT 1000"
)
@ -285,14 +278,14 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
value="test_value",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
async def test_wildcard_query_with_workspace_collection(self, mock_graph_db):
"""Test wildcard query (all None) includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -300,7 +293,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
@ -310,36 +302,36 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify wildcard query for literals includes workspace/collection
expected_literal_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT 1000"
)
mock_driver.execute_query.assert_any_call(
expected_literal_query,
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
# Verify wildcard query for nodes includes workspace/collection
expected_node_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT 1000"
)
mock_driver.execute_query.assert_any_call(
expected_node_query,
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@ -363,7 +355,7 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples('default', query)
# Verify defaults were used
calls = mock_driver.execute_query.call_args_list
@ -383,7 +375,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -410,7 +401,7 @@ class TestMemgraphQueryUserCollectionIsolation:
([mock_record2], MagicMock(), MagicMock()) # Node query
]
result = await processor.query_triples(query)
result = await processor.query_triples("test_user", query)
# Verify results are proper Triple objects
assert len(result) == 2
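Across all seven Memgraph query patterns the expected Cypher changes uniformly: every node and relationship property map swaps `user: $user` for `workspace: $workspace`, with `collection` untouched. The SPO case, rebuilt as a standalone constant for reference (copied from the test's expected string):

```python
# The workspace-scoped SPO query asserted by the Memgraph test above,
# reproduced as a standalone constant.
SPO_LITERAL_QUERY = (
    "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
    "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
    "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
    "RETURN $src as src "
    "LIMIT 1000"
)

# Every pattern element is scoped, and no $user parameter remains:
print("$user" not in SPO_LITERAL_QUERY and SPO_LITERAL_QUERY.count("$workspace") == 3)
```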


@ -9,12 +9,12 @@ from trustgraph.query.triples.neo4j.service import Processor
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
class TestNeo4jQueryUserCollectionIsolation:
class TestNeo4jQueryWorkspaceCollectionIsolation:
"""Test cases for Neo4j query service with user/collection isolation"""
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_spo_query_with_user_collection(self, mock_graph_db):
async def test_spo_query_with_workspace_collection(self, mock_graph_db):
"""Test SPO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -22,7 +22,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
@ -32,13 +31,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SPO query for literal includes workspace/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT 10"
)
@ -48,14 +47,14 @@ class TestNeo4jQueryUserCollectionIsolation:
src="http://example.com/s",
rel="http://example.com/p",
value="test_object",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_sp_query_with_user_collection(self, mock_graph_db):
async def test_sp_query_with_workspace_collection(self, mock_graph_db):
"""Test SP query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -63,7 +62,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
@ -73,13 +71,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SP query for literals includes workspace/collection
expected_literal_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT 10"
)
@ -88,16 +86,16 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_literal_query,
src="http://example.com/s",
rel="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
# Verify SP query for nodes includes workspace/collection
expected_node_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN dest.uri as dest "
"LIMIT 10"
)
@ -106,14 +104,14 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_node_query,
src="http://example.com/s",
rel="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_so_query_with_user_collection(self, mock_graph_db):
async def test_so_query_with_workspace_collection(self, mock_graph_db):
"""Test SO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -121,7 +119,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -131,13 +128,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SO query for nodes includes workspace/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT 10"
)
@ -146,14 +143,14 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_query,
src="http://example.com/s",
uri="http://example.com/o",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_s_only_query_with_user_collection(self, mock_graph_db):
async def test_s_only_query_with_workspace_collection(self, mock_graph_db):
"""Test S-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -161,7 +158,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -171,13 +167,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify S query includes workspace/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT 10"
)
@ -185,14 +181,14 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
src="http://example.com/s",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_po_query_with_user_collection(self, mock_graph_db):
async def test_po_query_with_workspace_collection(self, mock_graph_db):
"""Test PO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -200,7 +196,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Term(type=IRI, iri="http://example.com/p"),
@ -210,13 +205,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify PO query for literals includes workspace/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT 10"
)
@ -225,14 +220,14 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_query,
uri="http://example.com/p",
value="literal",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_p_only_query_with_user_collection(self, mock_graph_db):
async def test_p_only_query_with_workspace_collection(self, mock_graph_db):
"""Test P-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -240,7 +235,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Term(type=IRI, iri="http://example.com/p"),
@ -250,13 +244,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify P query includes workspace/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT 10"
)
@ -264,14 +258,14 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
uri="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_o_only_query_with_user_collection(self, mock_graph_db):
async def test_o_only_query_with_workspace_collection(self, mock_graph_db):
"""Test O-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -279,7 +273,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
@ -289,13 +282,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify O query for literals includes workspace/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT 10"
)
@ -303,14 +296,14 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
value="test_value",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
async def test_wildcard_query_with_workspace_collection(self, mock_graph_db):
"""Test wildcard query (all None) includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -318,7 +311,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
@ -328,36 +320,36 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify wildcard query for literals includes workspace/collection
expected_literal_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
expected_literal_query,
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
# Verify wildcard query for nodes includes workspace/collection
expected_node_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
expected_node_query,
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@ -381,7 +373,7 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples('default', query)
# Verify defaults were used
calls = mock_driver.execute_query.call_args_list
@ -401,7 +393,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -428,7 +419,7 @@ class TestNeo4jQueryUserCollectionIsolation:
([mock_record2], MagicMock(), MagicMock()) # Node query
]
result = await processor.query_triples(query)
result = await processor.query_triples("test_user", query)
# Verify results are proper Triple objects
assert len(result) == 2
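The hunks above move the tenancy axis out of the request message and into the first positional argument of `query_triples`, and swap the Cypher property filter from `user` to `workspace`. A minimal sketch of the query construction these tests assert — the helper names are assumptions, not the project's API:

```python
# Sketch only: rebuilds the wildcard literal-query string asserted above.
def build_wildcard_literal_query() -> str:
    scope = "{workspace: $workspace, collection: $collection}"
    return (
        f"MATCH (src:Node {scope})-"
        f"[rel:Rel {scope}]->"
        f"(dest:Literal {scope}) "
        "RETURN src.uri as src, rel.uri as rel, dest.value as dest "
        "LIMIT 10"
    )

def query_params(workspace: str, collection: str) -> dict:
    # workspace is supplied by the trusted flow layer,
    # not read from a client-supplied message field
    return {"workspace": workspace, "collection": collection}
```

The key point is that every node and relationship pattern carries the same `{workspace, collection}` scope, so a query can never cross tenant boundaries regardless of the SPO pattern.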


@@ -91,11 +91,10 @@ class TestRowsGraphQLQueryLogic:
         """Test parsing of schema configuration"""
         processor = MagicMock()
         processor.schemas = {}
+        processor.schema_builders = {}
+        processor.graphql_schemas = {}
         processor.config_key = "schema"
-        processor.schema_builder = MagicMock()
-        processor.schema_builder.clear = MagicMock()
-        processor.schema_builder.add_schema = MagicMock()
-        processor.schema_builder.build = MagicMock(return_value=MagicMock())
         processor.query_cassandra = MagicMock()
         processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
         # Create test config
@@ -129,11 +128,11 @@ class TestRowsGraphQLQueryLogic:
         }
         # Process config
-        await processor.on_schema_config(schema_config, version=1)
+        await processor.on_schema_config("default", schema_config, version=1)
         # Verify schema was loaded
-        assert "customer" in processor.schemas
-        schema = processor.schemas["customer"]
+        assert "customer" in processor.schemas["default"]
+        schema = processor.schemas["default"]["customer"]
         assert schema.name == "customer"
         assert len(schema.fields) == 3
@@ -147,39 +146,40 @@ class TestRowsGraphQLQueryLogic:
         status_field = next(f for f in schema.fields if f.name == "status")
         assert status_field.enum_values == ["active", "inactive"]
-        # Verify schema builder was called
-        processor.schema_builder.add_schema.assert_called_once()
-        processor.schema_builder.build.assert_called_once()
+        # Verify per-workspace schema builder was created and graphql schema built
+        assert "default" in processor.schema_builders
+        assert "default" in processor.graphql_schemas
     @pytest.mark.asyncio
     async def test_graphql_context_handling(self):
         """Test GraphQL execution context setup"""
         processor = MagicMock()
-        processor.graphql_schema = AsyncMock()
+        graphql_schema = AsyncMock()
+        processor.graphql_schemas = {"default": graphql_schema}
         processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
         # Mock schema execution
         mock_result = MagicMock()
         mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
         mock_result.errors = None
-        processor.graphql_schema.execute.return_value = mock_result
+        graphql_schema.execute.return_value = mock_result
         result = await processor.execute_graphql_query(
+            workspace="default",
             query='{ customers { id name } }',
             variables={},
             operation_name=None,
-            user="test_user",
             collection="test_collection"
         )
         # Verify schema.execute was called with correct context
-        processor.graphql_schema.execute.assert_called_once()
-        call_args = processor.graphql_schema.execute.call_args
+        graphql_schema.execute.assert_called_once()
+        call_args = graphql_schema.execute.call_args
         # Verify context was passed
         context = call_args[1]['context_value']
         assert context["processor"] == processor
-        assert context["user"] == "test_user"
+        assert context["workspace"] == "default"
         assert context["collection"] == "test_collection"
         # Verify result structure
@@ -190,7 +190,8 @@ class TestRowsGraphQLQueryLogic:
     async def test_error_handling_graphql_errors(self):
         """Test GraphQL error handling and conversion"""
         processor = MagicMock()
-        processor.graphql_schema = AsyncMock()
+        graphql_schema = AsyncMock()
+        processor.graphql_schemas = {"default": graphql_schema}
         processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
         # Create a simple object to simulate GraphQL error
@@ -212,13 +213,13 @@ class TestRowsGraphQLQueryLogic:
         mock_result = MagicMock()
         mock_result.data = None
         mock_result.errors = [mock_error]
-        processor.graphql_schema.execute.return_value = mock_result
+        graphql_schema.execute.return_value = mock_result
         result = await processor.execute_graphql_query(
+            workspace="default",
             query='{ customers { invalid_field } }',
             variables={},
             operation_name=None,
-            user="test_user",
             collection="test_collection"
         )
@@ -248,7 +249,6 @@ class TestRowsGraphQLQueryLogic:
         # Create mock message
         mock_msg = MagicMock()
         mock_request = RowsQueryRequest(
-            user="test_user",
             collection="test_collection",
             query='{ customers { id name } }',
             variables={},
@@ -259,6 +259,7 @@ class TestRowsGraphQLQueryLogic:
         # Mock flow
         mock_flow = MagicMock()
+        mock_flow.workspace = "default"
         mock_response_flow = AsyncMock()
         mock_flow.return_value = mock_response_flow
@@ -267,10 +268,10 @@ class TestRowsGraphQLQueryLogic:
         # Verify query was executed
         processor.execute_graphql_query.assert_called_once_with(
+            workspace="default",
             query='{ customers { id name } }',
             variables={},
             operation_name=None,
-            user="test_user",
             collection="test_collection"
         )
@@ -297,7 +298,6 @@ class TestRowsGraphQLQueryLogic:
         # Create mock message
         mock_msg = MagicMock()
         mock_request = RowsQueryRequest(
-            user="test_user",
             collection="test_collection",
             query='{ invalid_query }',
             variables={},
@@ -357,7 +357,7 @@ class TestUnifiedTableQueries:
         # Query with filter on indexed field
         results = await processor.query_cassandra(
-            user="test_user",
+            workspace="test_workspace",
             collection="test_collection",
             schema_name="products",
             row_schema=schema,
@@ -374,7 +374,7 @@ class TestUnifiedTableQueries:
         query = call_args[0][1]
         params = call_args[0][2]
-        assert "SELECT data, source FROM test_user.rows" in query
+        assert "SELECT data, source FROM test_workspace.rows" in query
         assert "collection = %s" in query
         assert "schema_name = %s" in query
         assert "index_name = %s" in query
@@ -421,7 +421,7 @@ class TestUnifiedTableQueries:
         # Query with filter on non-indexed field
         results = await processor.query_cassandra(
-            user="test_user",
+            workspace="test_workspace",
             collection="test_collection",
             schema_name="products",
             row_schema=schema,
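The GraphQL hunks above replace a single `processor.graphql_schema` with dictionaries keyed by workspace (`schemas`, `schema_builders`, `graphql_schemas`). A sketch of that bookkeeping, under the assumption that the builder exposes `add_schema`/`build` as the old mocks suggested:

```python
from collections import defaultdict

class WorkspaceSchemaRegistry:
    """Sketch of the per-workspace state the tests above imply; the
    builder protocol and rebuild-on-config behaviour are assumptions."""

    def __init__(self, builder_factory):
        self.schemas = defaultdict(dict)   # workspace -> schema name -> schema
        self.schema_builders = {}          # workspace -> builder
        self.graphql_schemas = {}          # workspace -> built GraphQL schema
        self._builder_factory = builder_factory

    def on_schema_config(self, workspace, name, schema):
        # Register the schema under its workspace and rebuild only that
        # workspace's GraphQL schema, leaving other tenants untouched.
        self.schemas[workspace][name] = schema
        builder = self.schema_builders.setdefault(
            workspace, self._builder_factory())
        builder.add_schema(schema)
        self.graphql_schemas[workspace] = builder.build()
```

Keying the built schema by workspace is what lets `execute_graphql_query(workspace=..., ...)` resolve the right schema from `flow.workspace` rather than from a client-supplied field.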


@@ -95,7 +95,6 @@ class TestCassandraQueryProcessor:
         # Create query request with all SPO values
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=LITERAL, value='test_subject'),
             p=Term(type=LITERAL, value='test_predicate'),
@@ -103,7 +102,7 @@ class TestCassandraQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify KnowledgeGraph was created with correct parameters
         mock_kg_class.assert_called_once_with(
@@ -170,7 +169,6 @@ class TestCassandraQueryProcessor:
         processor = Processor(taskgroup=MagicMock())
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=LITERAL, value='test_subject'),
             p=Term(type=LITERAL, value='test_predicate'),
@@ -178,7 +176,7 @@ class TestCassandraQueryProcessor:
             limit=50
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
         assert len(result) == 1
@@ -207,7 +205,6 @@ class TestCassandraQueryProcessor:
         processor = Processor(taskgroup=MagicMock())
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=LITERAL, value='test_subject'),
             p=None,
@@ -215,7 +212,7 @@ class TestCassandraQueryProcessor:
             limit=25
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
         assert len(result) == 1
@@ -244,7 +241,6 @@ class TestCassandraQueryProcessor:
         processor = Processor(taskgroup=MagicMock())
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=None,
             p=Term(type=LITERAL, value='test_predicate'),
@@ -252,7 +248,7 @@ class TestCassandraQueryProcessor:
             limit=10
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
         assert len(result) == 1
@@ -281,7 +277,6 @@ class TestCassandraQueryProcessor:
         processor = Processor(taskgroup=MagicMock())
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=None,
             p=None,
@@ -289,7 +284,7 @@ class TestCassandraQueryProcessor:
             limit=75
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
         assert len(result) == 1
@@ -319,7 +314,6 @@ class TestCassandraQueryProcessor:
         processor = Processor(taskgroup=MagicMock())
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=None,
             p=None,
@@ -327,7 +321,7 @@ class TestCassandraQueryProcessor:
             limit=1000
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
         assert len(result) == 1
@@ -425,7 +419,6 @@ class TestCassandraQueryProcessor:
         )
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=LITERAL, value='test_subject'),
             p=Term(type=LITERAL, value='test_predicate'),
@@ -433,7 +426,7 @@ class TestCassandraQueryProcessor:
             limit=100
         )
-        await processor.query_triples(query)
+        await processor.query_triples('test_user', query)
         # Verify KnowledgeGraph was created with authentication
         mock_kg_class.assert_called_once_with(
@@ -463,7 +456,6 @@ class TestCassandraQueryProcessor:
         processor = Processor(taskgroup=MagicMock())
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=LITERAL, value='test_subject'),
             p=Term(type=LITERAL, value='test_predicate'),
@@ -472,11 +464,11 @@ class TestCassandraQueryProcessor:
         )
         # First query should create TrustGraph
-        await processor.query_triples(query)
+        await processor.query_triples('test_user', query)
         assert mock_kg_class.call_count == 1
         # Second query with same table should reuse TrustGraph
-        await processor.query_triples(query)
+        await processor.query_triples('test_user', query)
         assert mock_kg_class.call_count == 1  # Should not increase
     @pytest.mark.asyncio
@@ -504,7 +496,6 @@ class TestCassandraQueryProcessor:
         # First query
         query1 = TriplesQueryRequest(
-            user='user1',
             collection='collection1',
             s=Term(type=LITERAL, value='test_subject'),
             p=None,
@@ -512,12 +503,11 @@ class TestCassandraQueryProcessor:
             limit=100
         )
-        await processor.query_triples(query1)
+        await processor.query_triples('user1', query1)
         assert processor.table == 'user1'
         # Second query with different table
         query2 = TriplesQueryRequest(
-            user='user2',
             collection='collection2',
             s=Term(type=LITERAL, value='test_subject'),
             p=None,
@@ -525,7 +515,7 @@ class TestCassandraQueryProcessor:
             limit=100
         )
-        await processor.query_triples(query2)
+        await processor.query_triples('user2', query2)
         assert processor.table == 'user2'
         # Verify TrustGraph was created twice
@@ -544,7 +534,6 @@ class TestCassandraQueryProcessor:
         processor = Processor(taskgroup=MagicMock())
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=LITERAL, value='test_subject'),
             p=Term(type=LITERAL, value='test_predicate'),
@@ -553,7 +542,7 @@ class TestCassandraQueryProcessor:
         )
         with pytest.raises(Exception, match="Query failed"):
-            await processor.query_triples(query)
+            await processor.query_triples('test_user', query)
     @pytest.mark.asyncio
     @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
@@ -582,7 +571,6 @@ class TestCassandraQueryProcessor:
         processor = Processor(taskgroup=MagicMock())
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=LITERAL, value='test_subject'),
             p=Term(type=LITERAL, value='test_predicate'),
@@ -590,7 +578,7 @@ class TestCassandraQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         assert len(result) == 2
         assert result[0].o.value == 'object1'
@@ -621,7 +609,6 @@ class TestCassandraQueryPerformanceOptimizations:
         # PO query pattern (predicate + object, find subjects)
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=None,
             p=Term(type=LITERAL, value='test_predicate'),
@@ -629,7 +616,7 @@ class TestCassandraQueryPerformanceOptimizations:
             limit=50
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify get_po was called (should use optimized po_table)
         mock_tg_instance.get_po.assert_called_once_with(
@@ -662,7 +649,6 @@ class TestCassandraQueryPerformanceOptimizations:
         # OS query pattern (object + subject, find predicates)
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=LITERAL, value='test_subject'),
             p=None,
@@ -670,7 +656,7 @@ class TestCassandraQueryPerformanceOptimizations:
             limit=25
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify get_os was called (should use optimized subject_table with clustering)
         mock_tg_instance.get_os.assert_called_once_with(
@@ -721,7 +707,6 @@ class TestCassandraQueryPerformanceOptimizations:
         mock_tg_instance.reset_mock()
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=LITERAL, value=s) if s else None,
             p=Term(type=LITERAL, value=p) if p else None,
@@ -729,7 +714,7 @@ class TestCassandraQueryPerformanceOptimizations:
             limit=10
         )
-        await processor.query_triples(query)
+        await processor.query_triples('test_user', query)
         # Verify the correct method was called
         method = getattr(mock_tg_instance, expected_method)
@@ -780,7 +765,6 @@ class TestCassandraQueryPerformanceOptimizations:
         # This is the query pattern that was slow with ALLOW FILTERING
         query = TriplesQueryRequest(
-            user='large_dataset_user',
             collection='massive_collection',
             s=None,
             p=Term(type=IRI, iri='http://www.w3.org/1999/02/22-rdf-syntax-ns#type'),
@@ -788,7 +772,7 @@ class TestCassandraQueryPerformanceOptimizations:
             limit=1000
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('large_dataset_user', query)
         # Verify optimized get_po was used (no ALLOW FILTERING needed!)
         mock_tg_instance.get_po.assert_called_once_with(
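The Cassandra tests above assert a different accessor per SPO pattern (`get_sp`, `get_s`, `get_p`, `get_o`, `get_po`, `get_os`, `get_all`). The dispatch can be sketched as a lookup on which terms are bound; this is an illustration derived from the asserted method names, and the fully-bound S-P-O case (which goes through the `KnowledgeGraph` path) is omitted:

```python
def choose_accessor(s, p, o) -> str:
    # Dispatch sketch: which TrustGraph accessor serves each pattern of
    # bound (non-None) subject/predicate/object terms.
    table = {
        (True, True, False): "get_sp",
        (True, False, False): "get_s",
        (False, True, False): "get_p",
        (False, False, True): "get_o",
        (False, True, True): "get_po",   # optimized po_table, no ALLOW FILTERING
        (True, False, True): "get_os",   # optimized subject_table with clustering
        (False, False, False): "get_all",
    }
    return table[(s is not None, p is not None, o is not None)]
```

Note that the workspace argument does not appear here at all: it selects the keyspace/table before dispatch, while the accessors are scoped by collection.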


@@ -123,7 +123,6 @@ class TestFalkorDBQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=IRI, iri="http://example.com/subject"),
             p=Term(type=IRI, iri="http://example.com/predicate"),
@@ -131,7 +130,7 @@ class TestFalkorDBQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_graph.query.call_count == 2
@@ -164,7 +163,6 @@ class TestFalkorDBQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=IRI, iri="http://example.com/subject"),
             p=Term(type=IRI, iri="http://example.com/predicate"),
@@ -172,7 +170,7 @@ class TestFalkorDBQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_graph.query.call_count == 2
@@ -209,7 +207,6 @@ class TestFalkorDBQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=IRI, iri="http://example.com/subject"),
             p=None,
@@ -217,7 +214,7 @@ class TestFalkorDBQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_graph.query.call_count == 2
@@ -254,7 +251,6 @@ class TestFalkorDBQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=IRI, iri="http://example.com/subject"),
             p=None,
@@ -262,7 +258,7 @@ class TestFalkorDBQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_graph.query.call_count == 2
@@ -299,7 +295,6 @@ class TestFalkorDBQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=None,
             p=Term(type=IRI, iri="http://example.com/predicate"),
@@ -307,7 +302,7 @@ class TestFalkorDBQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_graph.query.call_count == 2
@@ -344,7 +339,6 @@ class TestFalkorDBQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=None,
             p=Term(type=IRI, iri="http://example.com/predicate"),
@@ -352,7 +346,7 @@ class TestFalkorDBQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_graph.query.call_count == 2
@@ -389,7 +383,6 @@ class TestFalkorDBQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=None,
             p=None,
@@ -397,7 +390,7 @@ class TestFalkorDBQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_graph.query.call_count == 2
@@ -434,7 +427,6 @@ class TestFalkorDBQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=None,
             p=None,
@@ -442,7 +434,7 @@ class TestFalkorDBQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_graph.query.call_count == 2
@@ -474,7 +466,6 @@ class TestFalkorDBQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=IRI, iri="http://example.com/subject"),
             p=None,
@@ -484,7 +475,7 @@ class TestFalkorDBQueryProcessor:
         # Should raise the exception
         with pytest.raises(Exception, match="Database connection failed"):
-            await processor.query_triples(query)
+            await processor.query_triples('test_user', query)
     def test_add_args_method(self):
         """Test that add_args properly configures argument parser"""


@@ -122,7 +122,6 @@ class TestMemgraphQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=IRI, iri="http://example.com/subject"),
             p=Term(type=IRI, iri="http://example.com/predicate"),
@@ -130,7 +129,7 @@ class TestMemgraphQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_driver.execute_query.call_count == 2
@@ -164,7 +163,6 @@ class TestMemgraphQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=IRI, iri="http://example.com/subject"),
             p=Term(type=IRI, iri="http://example.com/predicate"),
@@ -172,7 +170,7 @@ class TestMemgraphQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_driver.execute_query.call_count == 2
@@ -210,7 +208,6 @@ class TestMemgraphQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=IRI, iri="http://example.com/subject"),
             p=None,
@@ -218,7 +215,7 @@ class TestMemgraphQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_driver.execute_query.call_count == 2
@@ -256,7 +253,6 @@ class TestMemgraphQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=IRI, iri="http://example.com/subject"),
             p=None,
@@ -264,7 +260,7 @@ class TestMemgraphQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_driver.execute_query.call_count == 2
@@ -302,7 +298,6 @@ class TestMemgraphQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=None,
             p=Term(type=IRI, iri="http://example.com/predicate"),
@@ -310,7 +305,7 @@ class TestMemgraphQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_driver.execute_query.call_count == 2
@@ -348,7 +343,6 @@ class TestMemgraphQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=None,
             p=Term(type=IRI, iri="http://example.com/predicate"),
@@ -356,7 +350,7 @@ class TestMemgraphQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_driver.execute_query.call_count == 2
@@ -394,7 +388,6 @@ class TestMemgraphQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=None,
             p=None,
@@ -402,7 +395,7 @@ class TestMemgraphQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_driver.execute_query.call_count == 2
@@ -440,7 +433,6 @@ class TestMemgraphQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=None,
             p=None,
@@ -448,7 +440,7 @@ class TestMemgraphQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_driver.execute_query.call_count == 2
@@ -478,7 +470,6 @@ class TestMemgraphQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=IRI, iri="http://example.com/subject"),
             p=None,
@@ -488,7 +479,7 @@ class TestMemgraphQueryProcessor:
         # Should raise the exception
         with pytest.raises(Exception, match="Database connection failed"):
-            await processor.query_triples(query)
+            await processor.query_triples('test_user', query)
     def test_add_args_method(self):
         """Test that add_args properly configures argument parser"""


@@ -122,7 +122,6 @@ class TestNeo4jQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=IRI, iri="http://example.com/subject"),
             p=Term(type=IRI, iri="http://example.com/predicate"),
@@ -130,7 +129,7 @@ class TestNeo4jQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_driver.execute_query.call_count == 2
@@ -164,7 +163,6 @@ class TestNeo4jQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=IRI, iri="http://example.com/subject"),
             p=Term(type=IRI, iri="http://example.com/predicate"),
@@ -172,7 +170,7 @@ class TestNeo4jQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_driver.execute_query.call_count == 2
@@ -210,7 +208,6 @@ class TestNeo4jQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=None,
             p=None,
@@ -218,7 +215,7 @@ class TestNeo4jQueryProcessor:
             limit=100
         )
-        result = await processor.query_triples(query)
+        result = await processor.query_triples('test_user', query)
         # Verify both literal and URI queries were executed
         assert mock_driver.execute_query.call_count == 2
@@ -248,7 +245,6 @@ class TestNeo4jQueryProcessor:
         # Create query request
         query = TriplesQueryRequest(
-            user='test_user',
             collection='test_collection',
             s=Term(type=IRI, iri="http://example.com/subject"),
             p=None,
@@ -258,7 +254,7 @@ class TestNeo4jQueryProcessor:
         # Should raise the exception
         with pytest.raises(Exception, match="Database connection failed"):
-            await processor.query_triples(query)
+            await processor.query_triples('test_user', query)
     def test_add_args_method(self):
         """Test that add_args properly configures argument parser"""


@@ -30,7 +30,7 @@ class TestDocumentMetadataTranslator:
             "title": "Test Document",
             "comments": "No comments",
             "metadata": [],
-            "user": "alice",
+            "workspace": "alice",
             "tags": ["finance", "q4"],
             "parent-id": "doc-100",
             "document-type": "page",
@@ -40,14 +40,14 @@ class TestDocumentMetadataTranslator:
         assert obj.time == 1710000000
         assert obj.kind == "application/pdf"
         assert obj.title == "Test Document"
-        assert obj.user == "alice"
+        assert obj.workspace == "alice"
         assert obj.tags == ["finance", "q4"]
         assert obj.parent_id == "doc-100"
         assert obj.document_type == "page"
         wire = self.tx.encode(obj)
         assert wire["id"] == "doc-123"
-        assert wire["user"] == "alice"
+        assert wire["workspace"] == "alice"
         assert wire["parent-id"] == "doc-100"
         assert wire["document-type"] == "page"
@@ -80,10 +80,10 @@ class TestDocumentMetadataTranslator:
     def test_falsy_fields_omitted_from_wire(self):
         """Empty string fields should be omitted from wire format."""
-        obj = DocumentMetadata(id="", time=0, user="")
+        obj = DocumentMetadata(id="", time=0, workspace="")
         wire = self.tx.encode(obj)
         assert "id" not in wire
-        assert "user" not in wire
+        assert "workspace" not in wire
 # ---------------------------------------------------------------------------
@@ -101,7 +101,7 @@ class TestProcessingMetadataTranslator:
             "document-id": "doc-123",
             "time": 1710000000,
             "flow": "default",
-            "user": "alice",
+            "workspace": "alice",
             "collection": "my-collection",
             "tags": ["tag1"],
         }
@@ -109,20 +109,20 @@ class TestProcessingMetadataTranslator:
         assert obj.id == "proc-1"
         assert obj.document_id == "doc-123"
         assert obj.flow == "default"
-        assert obj.user == "alice"
+        assert obj.workspace == "alice"
         assert obj.collection == "my-collection"
         assert obj.tags == ["tag1"]
         wire = self.tx.encode(obj)
         assert wire["id"] == "proc-1"
         assert wire["document-id"] == "doc-123"
-        assert wire["user"] == "alice"
+        assert wire["workspace"] == "alice"
         assert wire["collection"] == "my-collection"
     def test_missing_fields_use_defaults(self):
         obj = self.tx.decode({})
         assert obj.id is None
-        assert obj.user is None
+        assert obj.workspace is None
         assert obj.collection is None
     def test_tags_none_omitted(self):
@@ -135,10 +135,10 @@ class TestProcessingMetadataTranslator:
         wire = self.tx.encode(obj)
         assert wire["tags"] == []
-    def test_user_and_collection_preserved(self):
+    def test_workspace_and_collection_preserved(self):
         """Core pipeline routing fields must survive round-trip."""
-        data = {"user": "bob", "collection": "research"}
+        data = {"workspace": "bob", "collection": "research"}
         obj = self.tx.decode(data)
         wire = self.tx.encode(obj)
-        assert wire["user"] == "bob"
+        assert wire["workspace"] == "bob"
         assert wire["collection"] == "research"


@@ -61,7 +61,6 @@ class TestDocEmbeddingsNullProtection:
         proc.collection_exists = MagicMock(return_value=True)
         msg = MagicMock()
-        msg.metadata.user = "user1"
         msg.metadata.collection = "col1"
         emb = MagicMock()
@@ -69,7 +68,7 @@ class TestDocEmbeddingsNullProtection:
         emb.vector = []  # Empty vector
         msg.chunks = [emb]
-        await proc.store_document_embeddings(msg)
+        await proc.store_document_embeddings("user1", msg)
         # No upsert should be called
         proc.qdrant.upsert.assert_not_called()
@@ -83,7 +82,6 @@ class TestDocEmbeddingsNullProtection:
         proc.collection_exists = MagicMock(return_value=True)
         msg = MagicMock()
-        msg.metadata.user = "user1"
         msg.metadata.collection = "col1"
         emb = MagicMock()
@@ -91,7 +89,7 @@ class TestDocEmbeddingsNullProtection:
         emb.vector = None  # None vector
         msg.chunks = [emb]
-        await proc.store_document_embeddings(msg)
+        await proc.store_document_embeddings("user1", msg)
         proc.qdrant.upsert.assert_not_called()
     @pytest.mark.asyncio
@@ -103,7 +101,6 @@ class TestDocEmbeddingsNullProtection:
         proc.collection_exists = MagicMock(return_value=True)
         msg = MagicMock()
-        msg.metadata.user = "user1"
         msg.metadata.collection = "col1"
         emb = MagicMock()
@@ -111,7 +108,7 @@ class TestDocEmbeddingsNullProtection:
         emb.vector = [0.1, 0.2, 0.3]
         msg.chunks = [emb]
-        await proc.store_document_embeddings(msg)
+        await proc.store_document_embeddings("user1", msg)
         proc.qdrant.upsert.assert_not_called()
     @pytest.mark.asyncio
@@ -124,7 +121,6 @@ class TestDocEmbeddingsNullProtection:
         proc.collection_exists = MagicMock(return_value=True)
         msg = MagicMock()
-        msg.metadata.user = "user1"
         msg.metadata.collection = "col1"
         emb = MagicMock()
@@ -132,7 +128,7 @@ class TestDocEmbeddingsNullProtection:
         emb.vector = [0.1, 0.2, 0.3]
         msg.chunks = [emb]
-        await proc.store_document_embeddings(msg)
+        await proc.store_document_embeddings("user1", msg)
         proc.qdrant.upsert.assert_called_once()
     @pytest.mark.asyncio
@@ -146,7 +142,6 @@ class TestDocEmbeddingsNullProtection:
         proc.collection_exists = MagicMock(return_value=True)
         msg = MagicMock()
-        msg.metadata.user = "alice"
         msg.metadata.collection = "docs"
         emb = MagicMock()
@@ -154,7 +149,7 @@ class TestDocEmbeddingsNullProtection:
         emb.vector = [0.0] * 384  # 384-dim vector
         msg.chunks = [emb]
-        await proc.store_document_embeddings(msg)
+        await proc.store_document_embeddings("alice", msg)
         call_args = proc.qdrant.upsert.call_args
         assert "d_alice_docs_384" in call_args[1]["collection_name"]
@@ -175,7 +170,6 @@ class TestGraphEmbeddingsNullProtection:
         proc.collection_exists = MagicMock(return_value=True)
         msg = MagicMock()
-        msg.metadata.user = "user1"
         msg.metadata.collection = "col1"
         entity = MagicMock()
@@ -183,7 +177,7 @@ class TestGraphEmbeddingsNullProtection:
         entity.vector = [0.1, 0.2, 0.3]
         msg.entities = [entity]
-        await proc.store_graph_embeddings(msg)
+        await proc.store_graph_embeddings("user1", msg)
         proc.qdrant.upsert.assert_not_called()
     @pytest.mark.asyncio
@@ -195,7 +189,6 @@ class TestGraphEmbeddingsNullProtection:
         proc.collection_exists = MagicMock(return_value=True)
         msg = MagicMock()
-        msg.metadata.user = "user1"
         msg.metadata.collection = "col1"
         entity = MagicMock()
@@ -203,7 +196,7 @@ class TestGraphEmbeddingsNullProtection:
         entity.vector = [0.1, 0.2, 0.3]
         msg.entities = [entity]
-        await proc.store_graph_embeddings(msg)
+        await proc.store_graph_embeddings("user1", msg)
         proc.qdrant.upsert.assert_not_called()
     @pytest.mark.asyncio
@@ -215,7 +208,6 @@ class TestGraphEmbeddingsNullProtection:
         proc.collection_exists = MagicMock(return_value=True)
         msg = MagicMock()
-        msg.metadata.user = "user1"
         msg.metadata.collection = "col1"
         entity = MagicMock()
@@ -223,7 +215,7 @@ class TestGraphEmbeddingsNullProtection:
         entity.vector = []  # Empty vector
         msg.entities = [entity]
-        await proc.store_graph_embeddings(msg)
+        await proc.store_graph_embeddings("user1", msg)
         proc.qdrant.upsert.assert_not_called()
     @pytest.mark.asyncio
@@ -236,7 +228,6 @@ class TestGraphEmbeddingsNullProtection:
         proc.collection_exists = MagicMock(return_value=True)
         msg = MagicMock()
-        msg.metadata.user = "user1"
         msg.metadata.collection = "col1"
         entity = MagicMock()
@@ -245,7 +236,7 @@ class TestGraphEmbeddingsNullProtection:
         entity.chunk_id = "c1"
         msg.entities = [entity]
-        await proc.store_graph_embeddings(msg)
+        await proc.store_graph_embeddings("user1", msg)
         proc.qdrant.upsert.assert_called_once()
     @pytest.mark.asyncio
@@ -258,7 +249,6 @@ class TestGraphEmbeddingsNullProtection:
         proc.collection_exists = MagicMock(return_value=True)
         msg = MagicMock()
-        msg.metadata.user = "alice"
         msg.metadata.collection = "graphs"
         entity = MagicMock()
@@ -267,7 +257,7 @@ class TestGraphEmbeddingsNullProtection:
         entity.chunk_id = ""
         msg.entities = [entity]
-        await proc.store_graph_embeddings(msg)
+        await proc.store_graph_embeddings("alice", msg)
         # Collection should be created with correct dimension
         proc.qdrant.create_collection.assert_called_once()
@@ -290,11 +280,10 @@ class TestCollectionValidation:
         proc.collection_exists = MagicMock(return_value=False)
         msg = MagicMock()
-        msg.metadata.user = "user1"
         msg.metadata.collection = "deleted-col"
         msg.chunks = [MagicMock()]
-        await proc.store_document_embeddings(msg)
+        await proc.store_document_embeddings("user1", msg)
         proc.qdrant.upsert.assert_not_called()
     @pytest.mark.asyncio
@@ -306,9 +295,8 @@ class TestCollectionValidation:
         proc.collection_exists = MagicMock(return_value=False)
         msg = MagicMock()
-        msg.metadata.user = "user1"
         msg.metadata.collection = "deleted-col"
         msg.entities = [MagicMock()]
-        await proc.store_graph_embeddings(msg)
+        await proc.store_graph_embeddings("user1", msg)
         proc.qdrant.upsert.assert_not_called()

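The assertions above check Qdrant collection names such as `d_alice_docs_384`. As an illustrative sketch only (the helper name is hypothetical; the scheme is inferred from the assertions, not taken from the implementation), the name appears to combine a `d_` document-embeddings prefix with the workspace, the collection, and the vector dimension:

```python
def doc_collection_name(workspace: str, collection: str, dim: int) -> str:
    # Hypothetical helper: reconstructs the naming scheme the tests
    # assert against ("d_" prefix, then workspace, collection, dimension).
    return f"d_{workspace}_{collection}_{dim}"

print(doc_collection_name("alice", "docs", 384))  # d_alice_docs_384
```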

@ -92,14 +92,13 @@ class TestQuery:
# Initialize Query with defaults
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False
)
# Verify initialization
assert query.rag == mock_rag
assert query.user == "test_user"
assert query.collection == "test_collection"
assert query.verbose is False
assert query.doc_limit == 20 # Default value
@ -112,7 +111,7 @@ class TestQuery:
# Initialize Query with custom doc_limit
query = Query(
rag=mock_rag,
user="custom_user",
workspace="test_workspace",
collection="custom_collection",
verbose=True,
doc_limit=50
@ -120,7 +119,6 @@ class TestQuery:
# Verify initialization
assert query.rag == mock_rag
assert query.user == "custom_user"
assert query.collection == "custom_collection"
assert query.verbose is True
assert query.doc_limit == 50
@ -137,7 +135,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False
)
@ -162,7 +160,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False
)
@ -184,7 +182,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False
)
@ -223,7 +221,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False,
doc_limit=15
@ -240,7 +238,6 @@ class TestQuery:
mock_doc_embeddings_client.query.assert_called_once_with(
vector=[0.1, 0.2, 0.3],
limit=15,
user="test_user",
collection="test_collection"
)
@ -286,7 +283,6 @@ class TestQuery:
result = await document_rag.query(
query="test query",
user="test_user",
collection="test_collection",
doc_limit=10
)
@ -304,7 +300,6 @@ class TestQuery:
mock_doc_embeddings_client.query.assert_called_once_with(
vector=[0.1, 0.2, 0.3],
limit=10,
user="test_user",
collection="test_collection"
)
@ -350,7 +345,6 @@ class TestQuery:
mock_doc_embeddings_client.query.assert_called_once_with(
vector=[[0.1, 0.2]],
limit=20, # Default doc_limit
user="trustgraph", # Default user
collection="default" # Default collection
)
@ -380,7 +374,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=True,
doc_limit=5
@ -453,7 +447,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False
)
@ -509,7 +503,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=True
)
@ -558,7 +552,6 @@ class TestQuery:
result = await document_rag.query(
query=query_text,
user="research_user",
collection="ml_knowledge",
doc_limit=25
)
@ -619,7 +612,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False,
doc_limit=10


@ -1,6 +1,6 @@
"""
Unit test for DocumentRAG service parameter passing fix.
Tests that user and collection parameters from the message are correctly
Tests that the collection parameter from the message is correctly
passed to the DocumentRag.query() method.
"""
@ -16,13 +16,13 @@ class TestDocumentRagService:
@patch('trustgraph.retrieval.document_rag.rag.DocumentRag')
@pytest.mark.asyncio
async def test_user_and_collection_parameters_passed_to_query(self, mock_document_rag_class):
async def test_collection_parameter_passed_to_query(self, mock_document_rag_class):
"""
Test that user and collection from message are passed to DocumentRag.query().
This is a regression test for the bug where user/collection parameters
were ignored, causing wrong collection names like 'd_trustgraph_default_384'
instead of 'd_my_user_test_coll_1_384'.
Test that collection from message is passed to DocumentRag.query().
This is a regression test for the bug where the collection parameter
was ignored, causing wrong collection names like 'd_trustgraph_default_384'
instead of one that reflects the requested collection.
"""
# Setup processor
processor = Processor(
@ -30,17 +30,16 @@ class TestDocumentRagService:
id="test-processor",
doc_limit=10
)
# Setup mock DocumentRag instance
mock_rag_instance = AsyncMock()
mock_document_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None})
# Setup message with custom user/collection
# Setup message with custom collection
msg = MagicMock()
msg.value.return_value = DocumentRagQuery(
query="test query",
user="my_user", # Custom user (not default "trustgraph")
collection="test_coll_1", # Custom collection (not default "default")
doc_limit=5
)
@ -64,7 +63,7 @@ class TestDocumentRagService:
# Verify: DocumentRag.query was called with correct parameters
mock_rag_instance.query.assert_called_once_with(
"test query",
user="my_user", # Must be from message, not hardcoded default
workspace=ANY, # Workspace comes from flow.workspace (mock)
collection="test_coll_1", # Must be from message, not hardcoded default
doc_limit=5,
explain_callback=ANY, # Explainability callback is always passed
@ -103,7 +102,6 @@ class TestDocumentRagService:
msg = MagicMock()
msg.value.return_value = DocumentRagQuery(
query="What is a cat?",
user="trustgraph",
collection="default",
doc_limit=10,
streaming=False # Non-streaming mode

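The signature change exercised above — `DocumentRag.query` receiving `workspace=ANY` from the flow rather than a `user` field from the message — can be sketched as follows. This is a minimal illustration with invented class and function names, not the real handler:

```python
import asyncio

class FakeFlow:
    # In the real system, flow.workspace is set by the infrastructure
    # at flow start time; it is not client-supplied.
    workspace = "default"

class FakeRag:
    async def query(self, text, *, workspace, collection, doc_limit):
        return f"[{workspace}/{collection}] answer to: {text}"

async def handle(flow, rag, request):
    # The request message still carries collection and doc_limit, but
    # the workspace comes from the trusted flow object.
    return await rag.query(
        request["query"],
        workspace=flow.workspace,
        collection=request["collection"],
        doc_limit=request.get("doc_limit", 20),
    )

result = asyncio.run(handle(FakeFlow(), FakeRag(),
                            {"query": "test query", "collection": "test_coll_1"}))
print(result)  # [default/test_coll_1] answer to: test query
```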

@ -78,14 +78,12 @@ class TestQuery:
# Initialize Query with defaults
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
# Verify initialization
assert query.rag == mock_rag
assert query.user == "test_user"
assert query.collection == "test_collection"
assert query.verbose is False
assert query.entity_limit == 50 # Default value
@ -101,7 +99,6 @@ class TestQuery:
# Initialize Query with custom parameters
query = Query(
rag=mock_rag,
user="custom_user",
collection="custom_collection",
verbose=True,
entity_limit=100,
@ -112,7 +109,6 @@ class TestQuery:
# Verify initialization
assert query.rag == mock_rag
assert query.user == "custom_user"
assert query.collection == "custom_collection"
assert query.verbose is True
assert query.entity_limit == 100
@ -133,7 +129,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -156,7 +151,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=True
)
@ -177,7 +171,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -201,7 +194,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -244,7 +236,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
entity_limit=25
@ -269,7 +260,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -277,7 +267,7 @@ class TestQuery:
result = await query.maybe_label("entity1")
assert result == "Entity One Label"
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
mock_cache.get.assert_called_once_with("test_collection:entity1")
@pytest.mark.asyncio
async def test_maybe_label_with_label_lookup(self):
@ -295,7 +285,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -307,13 +296,12 @@ class TestQuery:
p="http://www.w3.org/2000/01/rdf-schema#label",
o=None,
limit=1,
user="test_user",
collection="test_collection",
g=""
)
assert result == "Human Readable Label"
cache_key = "test_user:test_collection:http://example.com/entity"
cache_key = "test_collection:http://example.com/entity"
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
@pytest.mark.asyncio
@ -330,7 +318,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -342,13 +329,12 @@ class TestQuery:
p="http://www.w3.org/2000/01/rdf-schema#label",
o=None,
limit=1,
user="test_user",
collection="test_collection",
g=""
)
assert result == "unlabeled_entity"
cache_key = "test_user:test_collection:unlabeled_entity"
cache_key = "test_collection:unlabeled_entity"
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
@pytest.mark.asyncio
@ -375,7 +361,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
triple_limit=10
@ -388,15 +373,15 @@ class TestQuery:
mock_triples_client.query_stream.assert_any_call(
s="entity1", p=None, o=None, limit=10,
user="test_user", collection="test_collection", batch_size=20, g=""
collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p="entity1", o=None, limit=10,
user="test_user", collection="test_collection", batch_size=20, g=""
collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p=None, o="entity1", limit=10,
user="test_user", collection="test_collection", batch_size=20, g=""
collection="test_collection", batch_size=20, g=""
)
expected_subgraph = {
@ -415,7 +400,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -435,7 +419,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
max_subgraph_size=2
@ -455,7 +438,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
max_path_length=1
@ -493,7 +475,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
max_subgraph_size=100
@ -601,7 +582,6 @@ class TestQuery:
try:
response = await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
entity_limit=25,
triple_limit=15,

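The label-cache assertions above show the key shrinking from `user:collection:entity` to `collection:entity`. A hedged sketch of the new key construction (the helper name is invented for illustration):

```python
def label_cache_key(collection: str, entity: str) -> str:
    # The user component is gone; collection alone scopes the cache
    # entry, since workspace isolation now happens at the flow layer.
    return f"{collection}:{entity}"

print(label_cache_key("test_collection", "entity1"))  # test_collection:entity1
```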

@ -120,7 +120,6 @@ class TestGraphRagServiceExplainTriples:
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="What is quantum computing?",
user="trustgraph",
collection="default",
streaming=False,
)


@ -52,7 +52,6 @@ class TestGraphRagService:
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="What is a cat?",
user="trustgraph",
collection="default",
entity_limit=50,
triple_limit=30,
@ -123,7 +122,6 @@ class TestGraphRagService:
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="What is a cat?",
user="trustgraph",
collection="default",
entity_limit=50,
triple_limit=30,
@ -190,7 +188,6 @@ class TestGraphRagService:
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="Test query",
user="trustgraph",
collection="default",
streaming=False
)


@ -286,11 +286,11 @@ class TestNLPQueryProcessor:
}
# Act
await processor.on_schema_config(config, "v1")
await processor.on_schema_config("default", config, "v1")
# Assert
assert "test_schema" in processor.schemas
schema = processor.schemas["test_schema"]
assert "test_schema" in processor.schemas["default"]
schema = processor.schemas["default"]["test_schema"]
assert schema.name == "test_schema"
assert schema.description == "Test schema"
assert len(schema.fields) == 2
@ -308,10 +308,10 @@ class TestNLPQueryProcessor:
}
# Act
await processor.on_schema_config(config, "v1")
await processor.on_schema_config("default", config, "v1")
# Assert - bad schema should be ignored
assert "bad_schema" not in processor.schemas
assert "bad_schema" not in processor.schemas.get("default", {})
def test_processor_initialization(self, mock_pulsar_client):
"""Test processor initialization with correct specifications"""

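The assertions above imply `processor.schemas` is now a two-level dict keyed first by workspace, then by schema name. A minimal sketch of that shape (assumed from the test assertions, not the actual processor code):

```python
schemas: dict = {}

def on_schema_config(workspace: str, config: dict) -> None:
    # Register each schema under its workspace; schemas registered for
    # one workspace never collide with another workspace's schemas.
    schemas.setdefault(workspace, {}).update(config)

on_schema_config("default", {"test_schema": {"description": "Test schema"}})
assert "test_schema" in schemas["default"]
assert "test_schema" not in schemas.get("other", {})
```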

@ -101,7 +101,7 @@ def service(mock_schemas):
taskgroup=MagicMock(),
id="test-processor"
)
service.schemas = mock_schemas
service.schemas = {"default": dict(mock_schemas)}
return service
@ -109,6 +109,7 @@ def service(mock_schemas):
def mock_flow():
"""Create mock flow with prompt service"""
flow = MagicMock()
flow.workspace = "default"
prompt_request_flow = AsyncMock()
flow.return_value.request = prompt_request_flow
return flow, prompt_request_flow


@ -44,7 +44,6 @@ class TestStructuredQueryProcessor:
# Arrange
request = StructuredQueryRequest(
question="Show me all customers from New York",
user="trustgraph",
collection="default"
)
@ -110,7 +109,6 @@ class TestStructuredQueryProcessor:
assert isinstance(objects_call_args, RowsQueryRequest)
assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }'
assert objects_call_args.variables == {"state": "NY"}
assert objects_call_args.user == "trustgraph"
assert objects_call_args.collection == "default"
# Verify response


@ -17,7 +17,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test document embeddings
@ -80,7 +79,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for a single chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -89,7 +87,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify insert was called once for the single chunk with its vector
processor.vecstore.insert.assert_called_once_with(
@ -99,14 +97,14 @@ class TestMilvusDocEmbeddingsStorageProcessor:
@pytest.mark.asyncio
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
"""Test storing document embeddings for multiple chunks"""
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('test_workspace', mock_message)
# Verify insert was called once per chunk with user/collection parameters
# Verify insert was called once per chunk with workspace/collection parameters
expected_calls = [
# Chunk 1 - single vector
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'),
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_workspace', 'test_collection'),
# Chunk 2 - single vector
([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'),
([0.7, 0.8, 0.9], "This is the second document chunk", 'test_workspace', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 2
@ -122,7 +120,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -131,7 +128,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no insert was called for empty chunk
processor.vecstore.insert.assert_not_called()
@ -141,7 +138,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with None chunk_id"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -150,7 +146,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Note: Implementation passes through None chunk_ids (only skips empty string "")
processor.vecstore.insert.assert_called_once_with(
@ -162,7 +158,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with mix of valid and empty chunks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
valid_chunk = ChunkEmbeddings(
@ -179,7 +174,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [valid_chunk, empty_chunk, another_valid]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify valid chunks were inserted, empty string chunk was skipped
expected_calls = [
@ -200,11 +195,10 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunks list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.chunks = []
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no insert was called
processor.vecstore.insert.assert_not_called()
@ -214,7 +208,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for chunk with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -223,7 +216,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no insert was called (no vectors to insert)
processor.vecstore.insert.assert_not_called()
@ -233,7 +226,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Each chunk has a single vector of different dimensions
@ -251,7 +243,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk1, chunk2, chunk3]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify all vectors were inserted regardless of dimension with user/collection parameters
expected_calls = [
@ -273,7 +265,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with Unicode content in chunk_id"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -282,7 +273,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify Unicode chunk_id was stored correctly with user/collection parameters
processor.vecstore.insert.assert_called_once_with(
@ -294,7 +285,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with long chunk_id"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a long chunk_id
@ -305,7 +295,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify long chunk_id was inserted with user/collection parameters
processor.vecstore.insert.assert_called_once_with(
@ -317,7 +307,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with whitespace-only chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -326,7 +315,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify whitespace content was inserted (not filtered out) with user/collection parameters
processor.vecstore.insert.assert_called_once_with(
@ -343,25 +332,24 @@ class TestMilvusDocEmbeddingsStorageProcessor:
('test@domain.com', 'test-collection.v1'),
]
for user, collection in test_cases:
for workspace, collection in test_cases:
processor.vecstore.reset_mock() # Reset mock for each test case
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = user
message.metadata.collection = collection
chunk = ChunkEmbeddings(
chunk_id="Test content",
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify insert was called with the correct user/collection
await processor.store_document_embeddings(workspace, message)
# Verify insert was called with the correct workspace/collection
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Test content", user, collection
[0.1, 0.2, 0.3], "Test content", workspace, collection
)
@pytest.mark.asyncio
@ -370,7 +358,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
# Store embeddings for user1/collection1
message1 = MagicMock()
message1.metadata = MagicMock()
message1.metadata.user = 'user1'
message1.metadata.collection = 'collection1'
chunk1 = ChunkEmbeddings(
chunk_id="User1 content",
@ -381,7 +368,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
# Store embeddings for user2/collection2
message2 = MagicMock()
message2.metadata = MagicMock()
message2.metadata.user = 'user2'
message2.metadata.collection = 'collection2'
chunk2 = ChunkEmbeddings(
chunk_id="User2 content",
@ -389,8 +375,8 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message2.chunks = [chunk2]
await processor.store_document_embeddings(message1)
await processor.store_document_embeddings(message2)
await processor.store_document_embeddings('user1', message1)
await processor.store_document_embeddings('user2', message2)
# Verify both calls were made with correct parameters
expected_calls = [
@ -411,18 +397,17 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with special characters in user/collection names"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'user@domain.com' # Email-like user
message.metadata.collection = 'test-collection.v1' # Collection with special chars
chunk = ChunkEmbeddings(
chunk_id="Special chars test",
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify the exact user/collection strings are passed (sanitization happens in DocVectors)
await processor.store_document_embeddings('user@domain.com', message)
# Verify the exact workspace/collection strings are passed (sanitization happens in DocVectors)
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Special chars test", 'user@domain.com', 'test-collection.v1'
)


@ -21,7 +21,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test document embeddings
@ -120,7 +119,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for a single chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -135,7 +133,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify index name and operations (with dimension suffix)
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
@ -185,7 +183,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test that writing to non-existent index creates it lazily"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -200,7 +197,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify index was created with correct dimension
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
@ -217,7 +214,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -229,7 +225,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called for empty chunk
mock_index.upsert.assert_not_called()
@ -239,7 +235,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with None chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -251,7 +246,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called for None chunk
mock_index.upsert.assert_not_called()
@ -261,7 +256,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with chunk that decodes to empty string"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -273,7 +267,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called for empty decoded chunk
mock_index.upsert.assert_not_called()
@ -283,7 +277,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Each chunk has a single vector of different dimensions
@ -325,14 +318,13 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunks list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.chunks = []
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no operations were performed
processor.pinecone.Index.assert_not_called()
@ -343,7 +335,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for chunk with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -355,7 +346,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called (no vectors to insert)
mock_index.upsert.assert_not_called()
@ -365,7 +356,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test that lazy creation happens when index doesn't exist"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -380,7 +370,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify index was created
processor.pinecone.create_index.assert_called_once()
@ -390,7 +380,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test that lazy creation works correctly"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -405,7 +394,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify index was created and used
processor.pinecone.create_index.assert_called_once()
@ -416,7 +405,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with Unicode content"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -430,7 +418,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify Unicode content was properly decoded and stored
call_args = mock_index.upsert.call_args
@ -442,7 +430,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with large document chunks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a large document chunk
@ -458,7 +445,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify large content was stored
call_args = mock_index.upsert.call_args

@ -84,7 +84,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with chunks and vectors
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_chunk = MagicMock()
@ -94,7 +93,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('test_user', mock_message)
# Assert
# Verify collection existence was checked (with dimension suffix)
@ -138,7 +137,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with multiple chunks
mock_message = MagicMock()
mock_message.metadata.user = 'multi_user'
mock_message.metadata.collection = 'multi_collection'
mock_chunk1 = MagicMock()
@ -152,7 +150,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk1, mock_chunk2]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('multi_user', mock_message)
# Assert
# Should be called twice (once per chunk)
@ -198,7 +196,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with multiple chunks, each having a single vector
mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection'
mock_chunk1 = MagicMock()
@ -216,7 +213,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('vector_user', mock_message)
# Assert
# Should be called 3 times (once per chunk)
@ -255,7 +252,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with empty chunk_id
mock_message = MagicMock()
mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection'
mock_chunk_empty = MagicMock()
@ -265,7 +261,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk_empty]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('empty_user', mock_message)
# Assert
# Should not call upsert for empty chunk_ids
@ -298,7 +294,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'new_user'
mock_message.metadata.collection = 'new_collection'
mock_chunk = MagicMock()
@ -308,7 +303,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('new_user', mock_message)
# Assert - collection should be lazily created
expected_collection = 'd_new_user_new_collection_5' # 5 dimensions
@ -350,7 +345,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'error_user'
mock_message.metadata.collection = 'error_collection'
mock_chunk = MagicMock()
@ -361,7 +355,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Act & Assert - should propagate the creation error
with pytest.raises(Exception, match="Connection error"):
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('error_user', mock_message)
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@ -388,7 +382,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create first mock message
mock_message1 = MagicMock()
mock_message1.metadata.user = 'cache_user'
mock_message1.metadata.collection = 'cache_collection'
mock_chunk1 = MagicMock()
@ -398,7 +391,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message1.chunks = [mock_chunk1]
# First call
await processor.store_document_embeddings(mock_message1)
await processor.store_document_embeddings('cache_user', mock_message1)
# Reset mock to track second call
mock_qdrant_instance.reset_mock()
@ -406,7 +399,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create second mock message with same dimensions
mock_message2 = MagicMock()
mock_message2.metadata.user = 'cache_user'
mock_message2.metadata.collection = 'cache_collection'
mock_chunk2 = MagicMock()
@ -416,7 +408,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message2.chunks = [mock_chunk2]
# Act - Second call with same collection
await processor.store_document_embeddings(mock_message2)
await processor.store_document_embeddings('cache_user', mock_message2)
# Assert
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
@ -452,7 +444,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with chunks of different dimensions
mock_message = MagicMock()
mock_message.metadata.user = 'dim_user'
mock_message.metadata.collection = 'dim_collection'
mock_chunk1 = MagicMock()
@ -466,7 +457,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk1, mock_chunk2]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('dim_user', mock_message)
# Assert
# Should check existence of DIFFERENT collections for each dimension
@ -526,7 +517,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with URI-style chunk_id
mock_message = MagicMock()
mock_message.metadata.user = 'uri_user'
mock_message.metadata.collection = 'uri_collection'
mock_chunk = MagicMock()
@ -536,7 +526,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('uri_user', mock_message)
# Assert
# Verify the chunk_id was stored correctly
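The API change these tests exercise — the tenancy key moving out of client-supplied message metadata and into a trusted first positional argument (populated from `flow.workspace`) — can be sketched roughly as follows. The class and field names here are illustrative stand-ins, not the real TrustGraph types:

```python
import asyncio
from dataclasses import dataclass, field

@dataclass
class Metadata:
    # collection still travels in message metadata; user no longer does
    collection: str = "default"

@dataclass
class Message:
    metadata: Metadata = field(default_factory=Metadata)
    chunks: list = field(default_factory=list)

class Storage:
    """Hypothetical stand-in for a storage processor after the migration."""
    def __init__(self):
        self.written = []

    async def store_document_embeddings(self, workspace, message):
        # workspace arrives as a trusted positional argument supplied by the
        # flow infrastructure, rather than being read from message.metadata
        for chunk in message.chunks:
            self.written.append((workspace, message.metadata.collection, chunk))

storage = Storage()
msg = Message(chunks=["c1", "c2"])
asyncio.run(storage.store_document_embeddings("test_workspace", msg))
```

This is why the test diffs consistently delete the `message.metadata.user = ...` setup line and add the workspace value as a new first argument to the store call.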

@ -17,7 +17,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test entities with embeddings
@ -80,7 +79,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for a single entity"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -89,7 +87,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify insert was called once with the full vector
processor.vecstore.insert.assert_called_once()
@ -102,14 +100,14 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
@pytest.mark.asyncio
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
"""Test storing graph embeddings for multiple entities"""
await processor.store_graph_embeddings(mock_message)
await processor.store_graph_embeddings('test_workspace', mock_message)
# Verify insert was called once per entity with user/collection parameters
# Verify insert was called once per entity with workspace/collection parameters
expected_calls = [
# Entity 1 - single vector
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'),
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_workspace', 'test_collection'),
# Entity 2 - single vector
([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'),
([0.7, 0.8, 0.9], 'literal entity', 'test_workspace', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 2
@ -125,7 +123,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -134,7 +131,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called for empty entity
processor.vecstore.insert.assert_not_called()
@ -144,7 +141,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with None entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -153,7 +149,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called for None entity
processor.vecstore.insert.assert_not_called()
@ -163,7 +159,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with mix of valid and invalid entities"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
valid_entity = EntityEmbeddings(
@ -183,7 +178,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [valid_entity, empty_entity, none_entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify only valid entity was inserted with user/collection/chunk_id parameters
processor.vecstore.insert.assert_called_once_with(
@ -196,11 +191,10 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entities list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.entities = []
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called
processor.vecstore.insert.assert_not_called()
@ -210,7 +204,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for entity with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -219,7 +212,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called (no vectors to insert)
processor.vecstore.insert.assert_not_called()
@ -229,7 +222,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Each entity has a single vector of different dimensions
@ -247,7 +239,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [entity1, entity2, entity3]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify all vectors were inserted regardless of dimension
expected_calls = [
@ -267,7 +259,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for both URI and literal entities"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
uri_entity = EntityEmbeddings(
@ -280,7 +271,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [uri_entity, literal_entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify both entities were inserted
expected_calls = [

@ -21,7 +21,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test entity embeddings (each entity has a single vector)
@ -124,7 +123,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for a single entity"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -139,7 +137,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1']):
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify index name and operations (with dimension suffix)
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
@ -189,7 +187,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test that writing to non-existent index creates it lazily"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -204,7 +201,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify index was created with correct dimension
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
@ -221,7 +218,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -233,7 +229,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no upsert was called for empty entity
mock_index.upsert.assert_not_called()
@ -243,7 +239,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with None entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -255,7 +250,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no upsert was called for None entity
mock_index.upsert.assert_not_called()
@ -265,7 +260,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Each entity has a single vector of different dimensions
@ -288,7 +282,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify different indexes were used for different dimensions
index_calls = processor.pinecone.Index.call_args_list
@ -307,14 +301,13 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entities list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.entities = []
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no operations were performed
processor.pinecone.Index.assert_not_called()
@ -325,7 +318,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for entity with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -337,7 +329,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no upsert was called (no vectors to insert)
mock_index.upsert.assert_not_called()
@ -347,7 +339,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test that lazy creation happens when index doesn't exist"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -362,7 +353,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify index was created
processor.pinecone.create_index.assert_called_once()
@ -372,7 +363,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test that lazy creation works correctly"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -387,7 +377,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify index was created and used
processor.pinecone.create_index.assert_called_once()
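The expected names asserted in these tests ('d_new_user_new_collection_5' for Qdrant document collections, 't-test_user-test_collection-3' for Pinecone graph indexes) imply a simple naming scheme combining the tenancy key, collection, and vector dimension. A hypothetical reconstruction from the test expectations alone — the helper names and exact formats are assumptions, not the actual implementation:

```python
def qdrant_doc_collection(workspace: str, collection: str, dim: int) -> str:
    # e.g. 'd_new_user_new_collection_5' for a 5-dimensional vector,
    # as asserted in the Qdrant lazy-creation test above
    return f"d_{workspace}_{collection}_{dim}"

def pinecone_graph_index(workspace: str, collection: str, dim: int) -> str:
    # e.g. 't-test_user-test_collection-3'; hyphens rather than
    # underscores, matching the Pinecone index-name assertions above
    return f"t-{workspace}-{collection}-{dim}"
```

Encoding the dimension in the name is what lets the lazy-creation tests route chunks of different dimensions to different collections/indexes.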

@ -64,7 +64,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with entities and vectors
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_entity = MagicMock()
@ -75,7 +74,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity]
# Act
await processor.store_graph_embeddings(mock_message)
await processor.store_graph_embeddings('test_user', mock_message)
# Assert
# Verify collection existence was checked (with dimension suffix)
@ -118,7 +117,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with multiple entities
mock_message = MagicMock()
mock_message.metadata.user = 'multi_user'
mock_message.metadata.collection = 'multi_collection'
mock_entity1 = MagicMock()
@ -134,7 +132,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity1, mock_entity2]
# Act
await processor.store_graph_embeddings(mock_message)
await processor.store_graph_embeddings('multi_user', mock_message)
# Assert
# Should be called twice (once per entity)
@ -179,7 +177,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with three entities
mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection'
mock_entity1 = MagicMock()
@ -200,7 +197,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity1, mock_entity2, mock_entity3]
# Act
await processor.store_graph_embeddings(mock_message)
await processor.store_graph_embeddings('vector_user', mock_message)
# Assert
# Should be called 3 times (once per entity)
@ -238,7 +235,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with empty entity value
mock_message = MagicMock()
mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection'
mock_entity_empty = MagicMock()
@ -253,7 +249,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity_empty, mock_entity_none]
# Act
await processor.store_graph_embeddings(mock_message)
await processor.store_graph_embeddings('empty_user', mock_message)
# Assert
# Should not call upsert for empty entities

@ -1,5 +1,5 @@
"""
Tests for Memgraph user/collection isolation in storage service
Tests for Memgraph workspace/collection isolation in storage service.
"""
import pytest
@ -8,47 +8,45 @@ from unittest.mock import MagicMock, patch
from trustgraph.storage.triples.memgraph.write import Processor
class TestMemgraphUserCollectionIsolation:
"""Test cases for Memgraph storage service with user/collection isolation"""
class TestMemgraphWorkspaceCollectionIsolation:
"""Test cases for Memgraph storage service with workspace/collection isolation"""
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_storage_creates_indexes_with_user_collection(self, mock_graph_db):
"""Test that storage creates both legacy and user/collection indexes"""
def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db):
"""Test that storage creates both legacy and workspace/collection indexes"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(taskgroup=MagicMock())
# Verify all indexes were attempted (4 legacy + 4 user/collection = 8 total)
# 4 legacy + 4 workspace/collection = 8 total
assert mock_session.run.call_count == 8
# Check some specific index creation calls
expected_calls = [
"CREATE INDEX ON :Node",
"CREATE INDEX ON :Node(uri)",
"CREATE INDEX ON :Literal",
"CREATE INDEX ON :Literal(value)",
"CREATE INDEX ON :Node(user)",
"CREATE INDEX ON :Node(workspace)",
"CREATE INDEX ON :Node(collection)",
"CREATE INDEX ON :Literal(user)",
"CREATE INDEX ON :Literal(workspace)",
"CREATE INDEX ON :Literal(collection)"
]
for expected_call in expected_calls:
mock_session.run.assert_any_call(expected_call)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_user_collection(self, mock_graph_db):
"""Test that store_triples includes user/collection in all operations"""
async def test_store_triples_with_workspace_collection(self, mock_graph_db):
"""Test that store_triples includes workspace/collection in all operations"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
@ -58,45 +56,39 @@ class TestMemgraphUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
# Create mock triple with URI object
from trustgraph.schema import IRI
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "http://example.com/object"
triple.o.is_uri = True
triple.s.type = IRI
triple.s.iri = "http://example.com/subject"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate"
triple.o.type = IRI
triple.o.iri = "http://example.com/object"
# Create mock message with metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Verify user/collection parameters were passed to all operations
# Should have: create_node (subject), create_node (object), relate_node = 3 calls
# create_node (subject), create_node (object), relate_node = 3 calls
assert mock_driver.execute_query.call_count == 3
# Check that user and collection were included in all calls
for call in mock_driver.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
assert 'user' in call_kwargs
assert 'collection' in call_kwargs
assert call_kwargs['user'] == "test_user"
assert call_kwargs['collection'] == "test_collection"
for c in mock_driver.execute_query.call_args_list:
kwargs = c.kwargs
assert kwargs['workspace'] == "test_workspace"
assert kwargs['collection'] == "test_collection"
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_default_user_collection(self, mock_graph_db):
"""Test that defaults are used when user/collection not provided in metadata"""
async def test_store_triples_with_default_collection(self, mock_graph_db):
"""Test that default collection is used when not provided in metadata"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
@ -106,157 +98,151 @@ class TestMemgraphUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
# Create mock triple
from trustgraph.schema import IRI, LITERAL
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.s.type = IRI
triple.s.iri = "http://example.com/subject"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate"
triple.o.type = LITERAL
triple.o.value = "literal_value"
triple.o.is_uri = False
# Create mock message without user/collection metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = None
mock_message.metadata.collection = None
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("default", mock_message)
# Verify defaults were used
for call in mock_driver.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
assert call_kwargs['user'] == "default"
assert call_kwargs['collection'] == "default"
for c in mock_driver.execute_query.call_args_list:
kwargs = c.kwargs
assert kwargs['workspace'] == "default"
assert kwargs['collection'] == "default"
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_create_node_includes_user_collection(self, mock_graph_db):
"""Test that create_node includes user/collection properties"""
def test_create_node_includes_workspace_collection(self, mock_graph_db):
"""Test that create_node includes workspace/collection properties"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
processor.create_node("http://example.com/node", "test_user", "test_collection")
processor.create_node("http://example.com/node", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/node",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="memgraph"
)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_create_literal_includes_user_collection(self, mock_graph_db):
"""Test that create_literal includes user/collection properties"""
def test_create_literal_includes_workspace_collection(self, mock_graph_db):
"""Test that create_literal includes workspace/collection properties"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
processor.create_literal("test_value", "test_user", "test_collection")
processor.create_literal("test_value", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="test_value",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="memgraph"
)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_relate_node_includes_user_collection(self, mock_graph_db):
"""Test that relate_node includes user/collection properties"""
def test_relate_node_includes_workspace_collection(self, mock_graph_db):
"""Test that relate_node includes workspace/collection properties"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
processor.relate_node(
"http://example.com/subject",
"http://example.com/predicate",
"http://example.com/object",
"test_user",
"test_workspace",
"test_collection"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject",
dest="http://example.com/object",
uri="http://example.com/predicate",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="memgraph"
)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_relate_literal_includes_workspace_collection(self, mock_graph_db):
"""Test that relate_literal includes workspace/collection properties"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
processor.relate_literal(
"http://example.com/subject",
"http://example.com/predicate",
"literal_value",
"test_workspace",
"test_collection"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject",
dest="literal_value",
uri="http://example.com/predicate",
workspace="test_workspace",
collection="test_collection",
database_="memgraph"
)
@ -264,20 +250,15 @@ class TestMemgraphUserCollectionIsolation:
def test_add_args_includes_memgraph_parameters(self):
"""Test that add_args properly configures Memgraph-specific parameters"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added with Memgraph defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert args.graph_host == 'bolt://memgraph:7687'
assert hasattr(args, 'username')
@ -288,19 +269,18 @@ class TestMemgraphUserCollectionIsolation:
assert args.database == 'memgraph'
class TestMemgraphWorkspaceCollectionRegression:
"""Regression tests to ensure workspace/collection isolation prevents data leakage"""
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio
async def test_regression_no_cross_workspace_data_access(self, mock_graph_db):
"""Regression test: Ensure workspaces cannot access each other's data"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
@ -310,60 +290,55 @@ class TestMemgraphUserCollectionRegression:
processor = Processor(taskgroup=MagicMock())
# Store data for workspace1
from trustgraph.schema import IRI, LITERAL
triple = MagicMock()
triple.s.type = IRI
triple.s.iri = "http://example.com/subject"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate"
triple.o.type = LITERAL
triple.o.value = "ws1_data"
message_ws1 = MagicMock()
message_ws1.triples = [triple]
message_ws1.metadata.collection = "collection1"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples("workspace1", message_ws1)
# Verify that all storage operations included workspace1/collection1 parameters
for c in mock_driver.execute_query.call_args_list:
kwargs = c.kwargs
if 'workspace' in kwargs:
assert kwargs['workspace'] == "workspace1"
assert kwargs['collection'] == "collection1"
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio
async def test_regression_same_uri_different_workspaces(self, mock_graph_db):
"""Regression test: Same URI can exist in different workspaces without conflict"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
# Same URI in different workspaces should create separate nodes
processor.create_node("http://example.com/same-uri", "workspace1", "collection1")
processor.create_node("http://example.com/same-uri", "workspace2", "collection2")
calls = mock_driver.execute_query.call_args_list[-2:]
k1 = calls[0].kwargs
k2 = calls[1].kwargs
assert k1['workspace'] == "workspace1" and k1['collection'] == "collection1"
assert k2['workspace'] == "workspace2" and k2['collection'] == "collection2"
assert k1['uri'] == k2['uri'] == "http://example.com/same-uri"
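The regression tests above rest on one identity rule: a node is unique per (uri, workspace, collection), so the same URI in two workspaces yields two distinct nodes. A minimal, database-free sketch of that keying idea (plain Python, no Memgraph; `merge_node` and the dict store are illustrative, not the real processor API):

```python
def merge_node(store, uri, workspace, collection):
    """Idempotently create a node keyed by (uri, workspace, collection),
    mirroring what the MERGE queries encode at the graph level."""
    key = (uri, workspace, collection)
    store.setdefault(key, {"uri": uri, "workspace": workspace, "collection": collection})
    return store[key]

store = {}
merge_node(store, "http://example.com/same-uri", "workspace1", "collection1")
merge_node(store, "http://example.com/same-uri", "workspace2", "collection2")
merge_node(store, "http://example.com/same-uri", "workspace1", "collection1")  # idempotent

assert len(store) == 2  # same URI, two workspaces -> two separate nodes
```

The same key shape is why dropping `workspace` from the MERGE pattern would silently collapse tenants back into one graph space.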


@ -1,5 +1,5 @@
"""
Tests for Neo4j workspace/collection isolation in triples storage and query.
"""
import pytest
@ -11,468 +11,406 @@ from trustgraph.schema import Triples, Triple, Term, Metadata, IRI, LITERAL
from trustgraph.schema import TriplesQueryRequest
class TestNeo4jWorkspaceCollectionIsolation:
"""Test cases for Neo4j workspace/collection isolation functionality"""
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db):
"""Test that storage service creates compound indexes for workspace/collection"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = StorageProcessor(taskgroup=taskgroup_mock)
# Verify both legacy and new compound indexes are created
expected_indexes = [
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
"CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
"CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)",
"CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)",
"CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)",
"CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)",
"CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)"
]
# Check that all expected indexes were created
for expected_query in expected_indexes:
mock_session.run.assert_any_call(expected_query)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_workspace_collection(self, mock_graph_db):
"""Test that triples are stored with workspace/collection properties"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = StorageProcessor(taskgroup=taskgroup_mock)
metadata = Metadata(id="test-id", collection="test_collection")
triple = Triple(
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="literal_value")
)
message = Triples(metadata=metadata, triples=[triple])
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples("test_workspace", message)
expected_calls = [
call(
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/subject",
workspace="test_workspace",
collection="test_collection",
database_='neo4j'
),
call(
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="literal_value",
workspace="test_workspace",
collection="test_collection",
database_='neo4j'
),
call(
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject",
dest="literal_value",
uri="http://example.com/predicate",
workspace="test_workspace",
collection="test_collection",
database_='neo4j'
)
]
for expected_call in expected_calls:
mock_driver.execute_query.assert_any_call(*expected_call.args, **expected_call.kwargs)
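The three expected queries above follow a per-triple expansion: merge the subject node, merge the object node or literal, then merge the relationship between them, all scoped by the same workspace/collection pair. A hedged, storage-free sketch of that ordering (illustrative only; the real processor issues Cypher via the Neo4j driver):

```python
def expand_triple(s_iri, p_iri, o_value, workspace, collection, o_is_literal=True):
    # Illustrative only: mirrors the order of MERGE queries the test
    # expects, without touching any database.
    ops = [("merge_node", s_iri, workspace, collection)]
    if o_is_literal:
        ops.append(("merge_literal", o_value, workspace, collection))
    else:
        ops.append(("merge_node", o_value, workspace, collection))
    ops.append(("relate", s_iri, p_iri, o_value, workspace, collection))
    return ops

ops = expand_triple("http://example.com/subject", "http://example.com/predicate",
                    "literal_value", "test_workspace", "test_collection")
assert [op[0] for op in ops] == ["merge_node", "merge_literal", "relate"]
```

Every op carries the workspace/collection pair, which is what the `assert_any_call` checks pin down.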
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_default_collection(self, mock_graph_db):
"""Test that default collection is used when not provided"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = StorageProcessor(taskgroup=taskgroup_mock)
# Create test message without collection
metadata = Metadata(id="test-id")
triple = Triple(
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=IRI, iri="http://example.com/object")
)
message = Triples(metadata=metadata, triples=[triple])
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples("default", message)
mock_driver.execute_query.assert_any_call(
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/subject",
workspace="default",
collection="default",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_filters_by_workspace_collection(self, mock_graph_db):
"""Test that query service filters results by workspace/collection"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock())
# Create test query
query = TriplesQueryRequest(
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=None
)
# Mock query results
mock_records = [
MagicMock(data=lambda: {"dest": "http://example.com/object1"}),
MagicMock(data=lambda: {"dest": "literal_value"})
]
mock_driver.execute_query.return_value = (mock_records, MagicMock(), MagicMock())
await processor.query_triples("test_workspace", query)
expected_literal_query = (
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest"
)
# Check that queries were executed with workspace/collection parameters
calls = mock_driver.execute_query.call_args_list
assert any(
expected_literal_query in str(c) and
"workspace='test_workspace'" in str(c) and
"collection='test_collection'" in str(c)
for c in calls
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_with_default_collection(self, mock_graph_db):
"""Test that query service uses default collection when not provided"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock())
# Create test query without collection
query = TriplesQueryRequest(s=None, p=None, o=None)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples("default", query)
calls = mock_driver.execute_query.call_args_list
assert any(
"workspace='default'" in str(c) and "collection='default'" in str(c)
for c in calls
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_data_isolation_between_workspaces(self, mock_graph_db):
"""Test that data from different workspaces is properly isolated"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = StorageProcessor(taskgroup=taskgroup_mock)
# Create messages for different workspaces
message_ws1 = Triples(
metadata=Metadata(collection="coll1"),
triples=[
Triple(
s=Term(type=IRI, iri="http://example.com/ws1/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="ws1_data")
)
]
)
message_ws2 = Triples(
metadata=Metadata(collection="coll2"),
triples=[
Triple(
s=Term(type=IRI, iri="http://example.com/ws2/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="ws2_data")
)
]
)
# Mock execute_query
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
# Store data for both workspaces
await processor.store_triples("workspace1", message_ws1)
await processor.store_triples("workspace2", message_ws2)
mock_driver.execute_query.assert_any_call(
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="ws1_data",
workspace="workspace1",
collection="coll1",
database_='neo4j'
)
mock_driver.execute_query.assert_any_call(
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="ws2_data",
workspace="workspace2",
collection="coll2",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_wildcard_query_respects_workspace_collection(self, mock_graph_db):
"""Test that wildcard queries still filter by workspace/collection"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock())
# Create wildcard query (all nulls) with collection
query = TriplesQueryRequest(
collection="test_collection",
s=None, p=None, o=None,
)
# Mock results
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples("test_workspace", query)
wildcard_query = (
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest"
)
calls = mock_driver.execute_query.call_args_list
assert any(
wildcard_query in str(c) and
"workspace='test_workspace'" in str(c) and
"collection='test_collection'" in str(c)
for c in calls
)
def test_add_args_includes_neo4j_parameters(self):
"""Test that add_args includes Neo4j-specific parameters"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'):
StorageProcessor.add_args(parser)
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert hasattr(args, 'username')
assert hasattr(args, 'password')
assert hasattr(args, 'database')
# Check defaults
assert args.graph_host == 'bolt://neo4j:7687'
assert args.username == 'neo4j'
assert args.password == 'password'
assert args.database == 'neo4j'
class TestNeo4jWorkspaceCollectionRegression:
"""Regression tests to ensure workspace/collection isolation prevents data leaks"""
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_regression_no_cross_workspace_data_access(self, mock_graph_db):
"""
Regression test: Ensure workspace1 cannot access workspace2's data.
Guards against a bug where all data shared the same Neo4j graph
space, causing data contamination between workspaces.
"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock())
# Workspace1 queries for all triples
query_ws1 = TriplesQueryRequest(
collection="collection1",
s=None, p=None, o=None
)
# Mock that the database has data but none matching workspace1/collection1
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
# Verify empty results (workspace1 cannot see other workspaces' data)
result = await processor.query_triples("workspace1", query_ws1)
assert len(result) == 0
# Verify the query included workspace/collection filters
calls = mock_driver.execute_query.call_args_list
for c in calls:
query_str = str(c)
if "MATCH" in query_str:
assert "workspace: $workspace" in query_str or "workspace='workspace1'" in query_str
assert "collection: $collection" in query_str or "collection='collection1'" in query_str
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_regression_same_uri_different_workspaces(self, mock_graph_db):
"""
Regression test: Same URI in different workspace contexts should create separate nodes.
Ensures http://example.com/entity in workspace1 is completely
separate from the same URI in workspace2.
"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = StorageProcessor(taskgroup=taskgroup_mock)
# Same URI in different workspaces
shared_uri = "http://example.com/shared_entity"
message_ws1 = Triples(
metadata=Metadata(collection="coll1"),
triples=[
Triple(
s=Term(type=IRI, iri=shared_uri),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="ws1_value")
)
]
)
message_ws2 = Triples(
metadata=Metadata(collection="coll2"),
triples=[
Triple(
s=Term(type=IRI, iri=shared_uri),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="ws2_value")
)
]
)
# Mock execute_query
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples("workspace1", message_ws1)
await processor.store_triples("workspace2", message_ws2)
ws1_node_call = call(
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=shared_uri,
workspace="workspace1",
collection="coll1",
database_='neo4j'
)
ws2_node_call = call(
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=shared_uri,
workspace="workspace2",
collection="coll2",
database_='neo4j'
)
mock_driver.execute_query.assert_has_calls([ws1_node_call, ws2_node_call], any_order=True)


@ -1,3 +1,12 @@
def _flow_mock(workspace):
"""Build a mock flow object that is callable and exposes .workspace."""
from unittest.mock import MagicMock
f = MagicMock()
f.workspace = workspace
return f
"""
Unit tests for trustgraph.storage.row_embeddings.qdrant.write
Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant.
@ -92,13 +101,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config)
collection_name = processor.get_collection_name(
workspace="test_workspace",
collection="test_collection",
schema_name="customer_data",
dimension=384
)
assert collection_name == "rows_test_workspace_test_collection_customer_data_384"
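The expected value above implies a name composed as rows_&lt;workspace&gt;_&lt;collection&gt;_&lt;schema&gt;_&lt;dimension&gt;. A hedged sketch of that composition, inferred from the assertion (the real `Processor.get_collection_name` may differ in details such as sanitisation):

```python
def get_collection_name(workspace, collection, schema_name, dimension):
    # Assumed composition, inferred from the test's expected value;
    # not the actual implementation.
    return f"rows_{workspace}_{collection}_{schema_name}_{dimension}"

name = get_collection_name("test_workspace", "test_collection", "customer_data", 384)
assert name == "rows_test_workspace_test_collection_customer_data_384"
```

Embedding the dimension in the name is what lets differently-sized embedding models coexist per workspace/collection.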
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_creates_new(self, mock_qdrant_client):
@ -185,11 +194,10 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.known_collections[('test_workspace', 'test_collection')] = {}
# Create embeddings message
metadata = MagicMock()
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
@ -210,14 +218,14 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Verify upsert was called
mock_qdrant_instance.upsert.assert_called_once()
# Verify upsert parameters
upsert_call_args = mock_qdrant_instance.upsert.call_args
assert upsert_call_args[1]['collection_name'] == 'rows_test_workspace_test_collection_customers_3'
point = upsert_call_args[1]['points'][0]
assert point.vector == [0.1, 0.2, 0.3]
@ -243,10 +251,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.known_collections[('test_workspace', 'test_collection')] = {}
metadata = MagicMock()
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
@ -267,7 +274,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Should be called once for the single embedding
assert mock_qdrant_instance.upsert.call_count == 1
@ -287,10 +294,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.known_collections[('test_workspace', 'test_collection')] = {}
metadata = MagicMock()
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
@ -311,7 +317,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Should not call upsert for empty vectors
mock_qdrant_instance.upsert.assert_not_called()
@ -334,7 +340,6 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
# No collections registered
metadata = MagicMock()
metadata.collection = 'unknown_collection'
metadata.id = 'doc-123'
@ -354,7 +359,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Should not call upsert for unknown collection
mock_qdrant_instance.upsert.assert_not_called()
@ -368,11 +373,11 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
# Mock collections list
mock_coll1 = MagicMock()
mock_coll1.name = 'rows_test_workspace_test_collection_schema1_384'
mock_coll2 = MagicMock()
mock_coll2.name = 'rows_test_workspace_test_collection_schema2_384'
mock_coll3 = MagicMock()
mock_coll3.name = 'rows_other_user_other_collection_schema_384'
mock_coll3.name = 'rows_other_workspace_other_collection_schema_384'
mock_collections = MagicMock()
mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3]
@ -386,15 +391,15 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.created_collections.add('rows_test_user_test_collection_schema1_384')
processor.created_collections.add('rows_test_workspace_test_collection_schema1_384')
await processor.delete_collection('test_user', 'test_collection')
await processor.delete_collection('test_workspace', 'test_collection')
# Should delete only the matching collections
assert mock_qdrant_instance.delete_collection.call_count == 2
# Verify the cached collection was removed
assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections
assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_delete_collection_schema(self, mock_qdrant_client):
@ -404,9 +409,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_qdrant_instance = MagicMock()
mock_coll1 = MagicMock()
mock_coll1.name = 'rows_test_user_test_collection_customers_384'
mock_coll1.name = 'rows_test_workspace_test_collection_customers_384'
mock_coll2 = MagicMock()
mock_coll2.name = 'rows_test_user_test_collection_orders_384'
mock_coll2.name = 'rows_test_workspace_test_collection_orders_384'
mock_collections = MagicMock()
mock_collections.collections = [mock_coll1, mock_coll2]
@ -422,13 +427,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config)
await processor.delete_collection_schema(
'test_user', 'test_collection', 'customers'
'test_workspace', 'test_collection', 'customers'
)
# Should only delete the customers schema collection
mock_qdrant_instance.delete_collection.assert_called_once()
call_args = mock_qdrant_instance.delete_collection.call_args[0]
assert call_args[0] == 'rows_test_user_test_collection_customers_384'
assert call_args[0] == 'rows_test_workspace_test_collection_customers_384'
if __name__ == '__main__':

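The assertions above imply row collections are named per workspace as `rows_{workspace}_{collection}_{schema}_{dim}`, so `delete_collection` can sweep every schema under one `(workspace, collection)` pair without touching other workspaces. A minimal sketch of that matching logic, assuming the naming pattern inferred from the tests; the helper names are illustrative, not the real `Processor` API:

```python
# Illustrative only: collection-name pattern inferred from the test
# expectations; these helpers are not the real Processor methods.

def row_collection_name(workspace, collection, schema, dim=384):
    # One Qdrant collection per (workspace, collection, schema, dimension)
    return f"rows_{workspace}_{collection}_{schema}_{dim}"

def matching_collections(names, workspace, collection):
    # delete_collection removes every schema collection under the given
    # (workspace, collection) pair, leaving other workspaces untouched
    prefix = f"rows_{workspace}_{collection}_"
    return [n for n in names if n.startswith(prefix)]

names = [
    "rows_test_workspace_test_collection_schema1_384",
    "rows_test_workspace_test_collection_schema2_384",
    "rows_other_workspace_other_collection_schema_384",
]
# Only the two test_workspace collections match
to_delete = matching_collections(names, "test_workspace", "test_collection")
```

This is why the test expects `delete_collection.call_count == 2`: two of the three mocked collections share the workspace/collection prefix.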

@@ -17,6 +17,17 @@ from trustgraph.storage.rows.cassandra.write import Processor
 from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
+
+class _MockFlowDefault:
+    """Mock Flow with default workspace for testing."""
+    workspace = "default"
+    name = "default"
+    id = "test-processor"
+
+mock_flow_default = _MockFlowDefault()
+
 class TestRowsCassandraStorageLogic:
     """Test business logic for unified table implementation"""
@@ -145,11 +156,11 @@ class TestRowsCassandraStorageLogic:
         }
         # Process configuration
-        await processor.on_schema_config(config, version=1)
+        await processor.on_schema_config("default", config, version=1)
         # Verify schema was loaded
-        assert "customer_records" in processor.schemas
-        schema = processor.schemas["customer_records"]
+        assert "customer_records" in processor.schemas["default"]
+        schema = processor.schemas["default"]["customer_records"]
         assert schema.name == "customer_records"
         assert len(schema.fields) == 3
@@ -165,16 +176,18 @@ class TestRowsCassandraStorageLogic:
         """Test that row processing stores data as map<text, text>"""
         processor = MagicMock()
         processor.schemas = {
-            "test_schema": RowSchema(
-                name="test_schema",
-                description="Test",
-                fields=[
-                    Field(name="id", type="string", size=50, primary=True),
-                    Field(name="value", type="string", size=100)
-                ]
-            )
+            "default": {
+                "test_schema": RowSchema(
+                    name="test_schema",
+                    description="Test",
+                    fields=[
+                        Field(name="id", type="string", size=50, primary=True),
+                        Field(name="value", type="string", size=100)
+                    ]
+                )
+            }
         }
-        processor.tables_initialized = {"test_user"}
+        processor.tables_initialized = {"default"}
         processor.registered_partitions = set()
         processor.session = MagicMock()
         processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@@ -191,7 +204,6 @@ class TestRowsCassandraStorageLogic:
         test_obj = ExtractedObject(
             metadata=Metadata(
                 id="test-001",
-                user="test_user",
                 collection="test_collection",
             ),
             schema_name="test_schema",
@@ -205,7 +217,7 @@ class TestRowsCassandraStorageLogic:
         msg.value.return_value = test_obj
         # Process object
-        await processor.on_object(msg, None, None)
+        await processor.on_object(msg, None, mock_flow_default)
         # Verify insert was executed
         mock_async_execute.assert_called()
@@ -214,7 +226,7 @@ class TestRowsCassandraStorageLogic:
         values = insert_call[0][2]
         # Verify using unified rows table
-        assert "INSERT INTO test_user.rows" in insert_cql
+        assert "INSERT INTO default.rows" in insert_cql
         # Values should be: (collection, schema_name, index_name, index_value, data, source)
         assert values[0] == "test_collection"  # collection
@@ -230,16 +242,18 @@ class TestRowsCassandraStorageLogic:
         """Test that row is written once per indexed field"""
         processor = MagicMock()
         processor.schemas = {
-            "multi_index_schema": RowSchema(
-                name="multi_index_schema",
-                fields=[
-                    Field(name="id", type="string", primary=True),
-                    Field(name="category", type="string", indexed=True),
-                    Field(name="status", type="string", indexed=True)
-                ]
-            )
+            "default": {
+                "multi_index_schema": RowSchema(
+                    name="multi_index_schema",
+                    fields=[
+                        Field(name="id", type="string", primary=True),
+                        Field(name="category", type="string", indexed=True),
+                        Field(name="status", type="string", indexed=True)
+                    ]
+                )
+            }
         }
-        processor.tables_initialized = {"test_user"}
+        processor.tables_initialized = {"default"}
         processor.registered_partitions = set()
         processor.session = MagicMock()
         processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@@ -255,7 +269,6 @@ class TestRowsCassandraStorageLogic:
         test_obj = ExtractedObject(
             metadata=Metadata(
                 id="test-001",
-                user="test_user",
                 collection="test_collection",
             ),
             schema_name="multi_index_schema",
@@ -267,7 +280,7 @@ class TestRowsCassandraStorageLogic:
         msg = MagicMock()
         msg.value.return_value = test_obj
-        await processor.on_object(msg, None, None)
+        await processor.on_object(msg, None, mock_flow_default)
         # Should have 3 inserts (one per indexed field: id, category, status)
         assert mock_async_execute.call_count == 3
@@ -290,15 +303,17 @@ class TestRowsCassandraStorageBatchLogic:
         """Test processing of batch ExtractedObjects"""
         processor = MagicMock()
         processor.schemas = {
-            "batch_schema": RowSchema(
-                name="batch_schema",
-                fields=[
-                    Field(name="id", type="string", primary=True),
-                    Field(name="name", type="string")
-                ]
-            )
+            "default": {
+                "batch_schema": RowSchema(
+                    name="batch_schema",
+                    fields=[
+                        Field(name="id", type="string", primary=True),
+                        Field(name="name", type="string")
+                    ]
+                )
+            }
         }
-        processor.tables_initialized = {"test_user"}
+        processor.tables_initialized = {"default"}
         processor.registered_partitions = set()
         processor.session = MagicMock()
         processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@@ -315,7 +330,6 @@ class TestRowsCassandraStorageBatchLogic:
         batch_obj = ExtractedObject(
             metadata=Metadata(
                 id="batch-001",
-                user="test_user",
                 collection="batch_collection",
             ),
             schema_name="batch_schema",
@@ -331,7 +345,7 @@ class TestRowsCassandraStorageBatchLogic:
         msg = MagicMock()
         msg.value.return_value = batch_obj
-        await processor.on_object(msg, None, None)
+        await processor.on_object(msg, None, mock_flow_default)
         # Should have 3 inserts (one per row, one index per row since only primary key)
         assert mock_async_execute.call_count == 3
@@ -349,12 +363,14 @@ class TestRowsCassandraStorageBatchLogic:
         """Test processing of empty batch ExtractedObjects"""
         processor = MagicMock()
         processor.schemas = {
-            "empty_schema": RowSchema(
-                name="empty_schema",
-                fields=[Field(name="id", type="string", primary=True)]
-            )
+            "default": {
+                "empty_schema": RowSchema(
+                    name="empty_schema",
+                    fields=[Field(name="id", type="string", primary=True)]
+                )
+            }
         }
-        processor.tables_initialized = {"test_user"}
+        processor.tables_initialized = {"default"}
         processor.registered_partitions = set()
         processor.session = MagicMock()
         processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@@ -369,7 +385,6 @@ class TestRowsCassandraStorageBatchLogic:
         empty_batch_obj = ExtractedObject(
             metadata=Metadata(
                 id="empty-001",
-                user="test_user",
                 collection="empty_collection",
             ),
             schema_name="empty_schema",
@@ -381,7 +396,7 @@ class TestRowsCassandraStorageBatchLogic:
         msg = MagicMock()
         msg.value.return_value = empty_batch_obj
-        await processor.on_object(msg, None, None)
+        await processor.on_object(msg, None, mock_flow_default)
         # Verify no insert calls for empty batch
         processor.session.execute.assert_not_called()
@@ -446,19 +461,21 @@ class TestPartitionRegistration:
         processor.registered_partitions = set()
         processor.session = MagicMock()
         processor.schemas = {
-            "test_schema": RowSchema(
-                name="test_schema",
-                fields=[
-                    Field(name="id", type="string", primary=True),
-                    Field(name="category", type="string", indexed=True)
-                ]
-            )
+            "default": {
+                "test_schema": RowSchema(
+                    name="test_schema",
+                    fields=[
+                        Field(name="id", type="string", primary=True),
+                        Field(name="category", type="string", indexed=True)
+                    ]
+                )
+            }
         }
         processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
         processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
         processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
-        processor.register_partitions("test_user", "test_collection", "test_schema")
+        processor.register_partitions("test_user", "test_collection", "test_schema", "default")
         # Should have 2 inserts (one per index: id, category)
         assert processor.session.execute.call_count == 2
@@ -473,7 +490,7 @@ class TestPartitionRegistration:
         processor.session = MagicMock()
         processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
-        processor.register_partitions("test_user", "test_collection", "test_schema")
+        processor.register_partitions("test_user", "test_collection", "test_schema", "default")
         # Should not execute any CQL since already registered
         processor.session.execute.assert_not_called()

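The restructured tests above nest the schema registry under a workspace key: `processor.schemas[workspace][schema_name]` rather than a flat `processor.schemas[schema_name]`. A minimal sketch of that lookup shape, assuming the nested-dict layout the tests exercise; the class and method names here are illustrative, not the real `Processor` internals:

```python
from dataclasses import dataclass, field

# Illustrative shape only: schemas keyed by workspace first, then by
# schema name, mirroring the nested dicts asserted in the tests above.
@dataclass
class SchemaRegistry:
    schemas: dict = field(default_factory=dict)  # workspace -> name -> schema

    def on_schema_config(self, workspace, config):
        # Register every schema in the config under its workspace
        self.schemas.setdefault(workspace, {}).update(config)

    def lookup(self, workspace, name):
        # A schema defined in one workspace is invisible to the others
        return self.schemas.get(workspace, {}).get(name)

reg = SchemaRegistry()
reg.on_schema_config("default", {"customer_records": object()})
```

The isolation property the tests rely on falls out of the nesting: a lookup for the same schema name in a different workspace returns nothing.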

@@ -102,11 +102,10 @@ class TestCassandraStorageProcessor:
         # Create mock message
         mock_message = MagicMock()
-        mock_message.metadata.user = 'user1'
         mock_message.metadata.collection = 'collection1'
         mock_message.triples = []
-        await processor.store_triples(mock_message)
+        await processor.store_triples('user1', mock_message)
         # Verify KnowledgeGraph was called with auth parameters
         mock_kg_class.assert_called_once_with(
@@ -129,11 +128,10 @@ class TestCassandraStorageProcessor:
         # Create mock message
         mock_message = MagicMock()
-        mock_message.metadata.user = 'user2'
         mock_message.metadata.collection = 'collection2'
         mock_message.triples = []
-        await processor.store_triples(mock_message)
+        await processor.store_triples('user2', mock_message)
         # Verify KnowledgeGraph was called without auth parameters
         mock_kg_class.assert_called_once_with(
@@ -154,16 +152,15 @@ class TestCassandraStorageProcessor:
         # Create mock message
         mock_message = MagicMock()
-        mock_message.metadata.user = 'user1'
         mock_message.metadata.collection = 'collection1'
         mock_message.triples = []
         # First call should create TrustGraph
-        await processor.store_triples(mock_message)
+        await processor.store_triples('user1', mock_message)
         assert mock_kg_class.call_count == 1
         # Second call with same table should reuse TrustGraph
-        await processor.store_triples(mock_message)
+        await processor.store_triples('user1', mock_message)
         assert mock_kg_class.call_count == 1  # Should not increase
     @pytest.mark.asyncio
@@ -205,11 +202,10 @@ class TestCassandraStorageProcessor:
         # Create mock message
         mock_message = MagicMock()
-        mock_message.metadata.user = 'user1'
         mock_message.metadata.collection = 'collection1'
         mock_message.triples = [triple1, triple2]
-        await processor.store_triples(mock_message)
+        await processor.store_triples('user1', mock_message)
         # Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
         assert mock_tg_instance.insert.call_count == 2
@@ -234,11 +230,10 @@ class TestCassandraStorageProcessor:
         # Create mock message with empty triples
         mock_message = MagicMock()
-        mock_message.metadata.user = 'user1'
         mock_message.metadata.collection = 'collection1'
         mock_message.triples = []
-        await processor.store_triples(mock_message)
+        await processor.store_triples('user1', mock_message)
         # Verify no triples were inserted
         mock_tg_instance.insert.assert_not_called()
@@ -255,12 +250,11 @@ class TestCassandraStorageProcessor:
         # Create mock message
         mock_message = MagicMock()
-        mock_message.metadata.user = 'user1'
         mock_message.metadata.collection = 'collection1'
         mock_message.triples = []
         with pytest.raises(Exception, match="Connection failed"):
-            await processor.store_triples(mock_message)
+            await processor.store_triples('user1', mock_message)
         # Verify sleep was called before re-raising
         mock_sleep.assert_called_once_with(1)
@@ -361,21 +355,19 @@ class TestCassandraStorageProcessor:
         # First message with table1
         mock_message1 = MagicMock()
-        mock_message1.metadata.user = 'user1'
         mock_message1.metadata.collection = 'collection1'
         mock_message1.triples = []
-        await processor.store_triples(mock_message1)
+        await processor.store_triples('user1', mock_message1)
         assert processor.table == 'user1'
         assert processor.tg == mock_tg_instance1
         # Second message with different table
         mock_message2 = MagicMock()
-        mock_message2.metadata.user = 'user2'
         mock_message2.metadata.collection = 'collection2'
         mock_message2.triples = []
-        await processor.store_triples(mock_message2)
+        await processor.store_triples('user2', mock_message2)
         assert processor.table == 'user2'
         assert processor.tg == mock_tg_instance2
@@ -407,11 +399,10 @@ class TestCassandraStorageProcessor:
         triple.g = None
         mock_message = MagicMock()
-        mock_message.metadata.user = 'test_user'
         mock_message.metadata.collection = 'test_collection'
         mock_message.triples = [triple]
-        await processor.store_triples(mock_message)
+        await processor.store_triples('test_workspace', mock_message)
         # Verify the triple was inserted with special characters preserved
         mock_tg_instance.insert.assert_called_once_with(
@@ -440,12 +431,11 @@ class TestCassandraStorageProcessor:
         mock_kg_class.side_effect = Exception("Connection failed")
         mock_message = MagicMock()
-        mock_message.metadata.user = 'new_user'
         mock_message.metadata.collection = 'new_collection'
         mock_message.triples = []
         with pytest.raises(Exception, match="Connection failed"):
-            await processor.store_triples(mock_message)
+            await processor.store_triples('new_user', mock_message)
         # Table should remain unchanged since self.table = table happens after try/except
         assert processor.table == ('old_user', 'old_collection')
@@ -468,11 +458,10 @@ class TestCassandraPerformanceOptimizations:
         processor = Processor(taskgroup=taskgroup_mock)
         mock_message = MagicMock()
-        mock_message.metadata.user = 'user1'
         mock_message.metadata.collection = 'collection1'
         mock_message.triples = []
-        await processor.store_triples(mock_message)
+        await processor.store_triples('user1', mock_message)
         # Verify KnowledgeGraph instance uses legacy mode
         assert mock_tg_instance is not None
@@ -489,11 +478,10 @@ class TestCassandraPerformanceOptimizations:
         processor = Processor(taskgroup=taskgroup_mock)
         mock_message = MagicMock()
-        mock_message.metadata.user = 'user1'
         mock_message.metadata.collection = 'collection1'
         mock_message.triples = []
-        await processor.store_triples(mock_message)
+        await processor.store_triples('user1', mock_message)
         # Verify KnowledgeGraph instance is in optimized mode
         assert mock_tg_instance is not None
@@ -523,11 +511,10 @@ class TestCassandraPerformanceOptimizations:
         triple.g = None
         mock_message = MagicMock()
-        mock_message.metadata.user = 'user1'
         mock_message.metadata.collection = 'collection1'
         mock_message.triples = [triple]
-        await processor.store_triples(mock_message)
+        await processor.store_triples('user1', mock_message)
         # Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
         mock_tg_instance.insert.assert_called_once_with(

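Across the Cassandra triples tests, `store_triples` no longer reads `metadata.user` from the message; the workspace arrives as the first positional argument, and the `processor.table` assertions suggest one graph handle cached per workspace, re-created only when the workspace changes. A sketch of that call shape and caching behaviour under those assumptions; the class and its connect logic are illustrative, not the real processor:

```python
# Illustrative sketch: workspace replaces metadata.user as the keyspace
# axis, with one graph handle cached per workspace. Not the real
# Processor; the connect bookkeeping is hypothetical.
class TriplesStore:
    def __init__(self):
        self.table = None        # workspace the current handle serves
        self.connect_count = 0

    def _connect(self, workspace):
        self.connect_count += 1
        return f"graph:{workspace}"

    def store_triples(self, workspace, message):
        # Reconnect only when the workspace changes
        if self.table != workspace:
            self.tg = self._connect(workspace)
            self.table = workspace
        return [(workspace, t) for t in message["triples"]]

store = TriplesStore()
store.store_triples("user1", {"triples": []})
store.store_triples("user1", {"triples": []})   # same workspace: reuse
store.store_triples("user2", {"triples": []})   # new workspace: reconnect
```

This mirrors the test pair that asserts `mock_kg_class.call_count == 1` after two calls for the same workspace, then a fresh instance for a second workspace.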

@@ -17,7 +17,6 @@ class TestFalkorDBStorageProcessor:
         """Create a mock message for testing"""
         message = MagicMock()
         message.metadata = MagicMock()
-        message.metadata.user = 'test_user'
         message.metadata.collection = 'test_collection'
         # Create a test triple
@@ -89,13 +88,13 @@ class TestFalkorDBStorageProcessor:
         processor.io.query.return_value = mock_result
-        processor.create_node(test_uri, 'test_user', 'test_collection')
+        processor.create_node(test_uri, 'test_workspace', 'test_collection')
         processor.io.query.assert_called_once_with(
-            "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
+            "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
            params={
                "uri": test_uri,
-                "user": 'test_user',
+                "workspace": 'test_workspace',
                "collection": 'test_collection',
            },
        )
@@ -109,13 +108,13 @@ class TestFalkorDBStorageProcessor:
         processor.io.query.return_value = mock_result
-        processor.create_literal(test_value, 'test_user', 'test_collection')
+        processor.create_literal(test_value, 'test_workspace', 'test_collection')
         processor.io.query.assert_called_once_with(
-            "MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
+            "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
            params={
                "value": test_value,
-                "user": 'test_user',
+                "workspace": 'test_workspace',
                "collection": 'test_collection',
            },
        )
@@ -132,17 +131,17 @@ class TestFalkorDBStorageProcessor:
         processor.io.query.return_value = mock_result
-        processor.relate_node(src_uri, pred_uri, dest_uri, 'test_user', 'test_collection')
+        processor.relate_node(src_uri, pred_uri, dest_uri, 'test_workspace', 'test_collection')
         processor.io.query.assert_called_once_with(
-            "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
-            "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
-            "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
+            "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
+            "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
+            "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
            params={
                "src": src_uri,
                "dest": dest_uri,
                "uri": pred_uri,
-                "user": 'test_user',
+                "workspace": 'test_workspace',
                "collection": 'test_collection',
            },
        )
@@ -159,17 +158,17 @@ class TestFalkorDBStorageProcessor:
         processor.io.query.return_value = mock_result
-        processor.relate_literal(src_uri, pred_uri, literal_value, 'test_user', 'test_collection')
+        processor.relate_literal(src_uri, pred_uri, literal_value, 'test_workspace', 'test_collection')
         processor.io.query.assert_called_once_with(
-            "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
-            "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
-            "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
+            "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
+            "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
+            "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
            params={
                "src": src_uri,
                "dest": literal_value,
                "uri": pred_uri,
-                "user": 'test_user',
+                "workspace": 'test_workspace',
                "collection": 'test_collection',
            },
        )
@@ -179,7 +178,6 @@ class TestFalkorDBStorageProcessor:
         """Test storing triple with URI object"""
         message = MagicMock()
         message.metadata = MagicMock()
-        message.metadata.user = 'test_user'
         message.metadata.collection = 'test_collection'
         triple = Triple(
@@ -200,21 +198,21 @@ class TestFalkorDBStorageProcessor:
         with patch.object(processor, 'collection_exists', return_value=True):
-            await processor.store_triples(message)
+            await processor.store_triples('test_workspace', message)
         # Verify queries were called in the correct order
         expected_calls = [
            # Create subject node
-            (("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
-             {"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}),
+            (("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
+             {"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection"}}),
            # Create object node
-            (("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
-             {"params": {"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection"}}),
+            (("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
+             {"params": {"uri": "http://example.com/object", "workspace": "test_workspace", "collection": "test_collection"}}),
            # Create relationship
-            (("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
-              "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
-              "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",),
-             {"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}),
+            (("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
+              "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
+              "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",),
+             {"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}),
        ]
         assert processor.io.query.call_count == 3
@@ -237,21 +235,21 @@ class TestFalkorDBStorageProcessor:
         with patch.object(processor, 'collection_exists', return_value=True):
-            await processor.store_triples(mock_message)
+            await processor.store_triples('test_workspace', mock_message)
         # Verify queries were called in the correct order
         expected_calls = [
            # Create subject node
-            (("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
-             {"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}),
+            (("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
+             {"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection"}}),
            # Create literal object
-            (("MERGE (n:Literal {value: $value, user: $user, collection: $collection})",),
-             {"params": {"value": "literal object", "user": "test_user", "collection": "test_collection"}}),
+            (("MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",),
+             {"params": {"value": "literal object", "workspace": "test_workspace", "collection": "test_collection"}}),
            # Create relationship
-            (("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
-              "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
-              "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",),
-             {"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}),
+            (("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
+              "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
+              "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",),
+             {"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}),
        ]
         assert processor.io.query.call_count == 3
@@ -265,7 +263,6 @@ class TestFalkorDBStorageProcessor:
         """Test storing multiple triples"""
         message = MagicMock()
         message.metadata = MagicMock()
-        message.metadata.user = 'test_user'
         message.metadata.collection = 'test_collection'
         triple1 = Triple(
@@ -291,7 +288,7 @@ class TestFalkorDBStorageProcessor:
         with patch.object(processor, 'collection_exists', return_value=True):
-            await processor.store_triples(message)
+            await processor.store_triples('test_workspace', message)
         # Verify total number of queries (3 per triple)
         assert processor.io.query.call_count == 6
@@ -313,7 +310,6 @@ class TestFalkorDBStorageProcessor:
         """Test storing empty triples list"""
         message = MagicMock()
         message.metadata = MagicMock()
-        message.metadata.user = 'test_user'
         message.metadata.collection = 'test_collection'
         message.triples = []
@@ -323,7 +319,7 @@ class TestFalkorDBStorageProcessor:
         with patch.object(processor, 'collection_exists', return_value=True):
-            await processor.store_triples(message)
+            await processor.store_triples('test_workspace', message)
         # Verify no queries were made
         processor.io.query.assert_not_called()
@@ -333,7 +329,6 @@ class TestFalkorDBStorageProcessor:
         """Test storing triples with mixed URI and literal objects"""
         message = MagicMock()
         message.metadata = MagicMock()
-        message.metadata.user = 'test_user'
         message.metadata.collection = 'test_collection'
         triple1 = Triple(
@@ -359,7 +354,7 @@ class TestFalkorDBStorageProcessor:
         with patch.object(processor, 'collection_exists', return_value=True):
-            await processor.store_triples(message)
+            await processor.store_triples('test_workspace', message)
         # Verify total number of queries (3 per triple)
         assert processor.io.query.call_count == 6
@@ -450,13 +445,13 @@ class TestFalkorDBStorageProcessor:
         processor.io.query.return_value = mock_result
-        processor.create_node(test_uri, 'test_user', 'test_collection')
+        processor.create_node(test_uri, 'test_workspace', 'test_collection')
         processor.io.query.assert_called_once_with(
-            "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
+            "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
            params={
                "uri": test_uri,
-                "user": 'test_user',
+                "workspace": 'test_workspace',
                "collection": 'test_collection',
            },
        )
@@ -470,13 +465,13 @@ class TestFalkorDBStorageProcessor:
         processor.io.query.return_value = mock_result
-        processor.create_literal(test_value, 'test_user', 'test_collection')
+        processor.create_literal(test_value, 'test_workspace', 'test_collection')
         processor.io.query.assert_called_once_with(
-            "MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
+            "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
            params={
                "value": test_value,
-                "user": 'test_user',
+                "workspace": 'test_workspace',
                "collection": 'test_collection',
            },
        )

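Both the FalkorDB and the Memgraph assertions expect `workspace` to replace `user` in the Cypher node and relationship properties, so nodes from different workspaces can never `MERGE` together. A sketch of how such a query and its parameter dict might be assembled, with the string shape copied from the tests' expectations; the builder function itself is hypothetical, not part of the real processors:

```python
# Hypothetical builder mirroring the MERGE shape the tests assert on.
# Scoping every node by workspace and collection keeps graphs from
# different workspaces fully disjoint.
def merge_node_query(uri, workspace, collection):
    query = (
        "MERGE (n:Node {uri: $uri, workspace: $workspace, "
        "collection: $collection})"
    )
    params = {"uri": uri, "workspace": workspace, "collection": collection}
    return query, params

q, p = merge_node_query(
    "http://example.com/subject", "test_workspace", "test_collection"
)
```

Using bound parameters (`$workspace` etc.) rather than interpolated values is what lets the tests assert one stable query string across workspaces.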

@@ -17,7 +17,6 @@ class TestMemgraphStorageProcessor:
         """Create a mock message for testing"""
         message = MagicMock()
         message.metadata = MagicMock()
-        message.metadata.user = 'test_user'
         message.metadata.collection = 'test_collection'
         # Create a test triple
@@ -43,7 +42,7 @@ class TestMemgraphStorageProcessor:
            taskgroup=MagicMock(),
            id='test-memgraph-storage',
            graph_host='bolt://localhost:7687',
-            username='test_user',
+            username='test_workspace',
            password='test_pass',
            database='test_db'
        )
@@ -105,9 +104,9 @@ class TestMemgraphStorageProcessor:
            "CREATE INDEX ON :Node(uri)",
            "CREATE INDEX ON :Literal",
            "CREATE INDEX ON :Literal(value)",
-            "CREATE INDEX ON :Node(user)",
+            "CREATE INDEX ON :Node(workspace)",
            "CREATE INDEX ON :Node(collection)",
-            "CREATE INDEX ON :Literal(user)",
+            "CREATE INDEX ON :Literal(workspace)",
            "CREATE INDEX ON :Literal(collection)"
        ]
@@ -145,12 +144,12 @@ class TestMemgraphStorageProcessor:
         processor.io.execute_query.return_value = mock_result
-        processor.create_node(test_uri, "test_user", "test_collection")
+        processor.create_node(test_uri, "test_workspace", "test_collection")
         processor.io.execute_query.assert_called_once_with(
-            "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
+            "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
            uri=test_uri,
-            user="test_user",
+            workspace="test_workspace",
            collection="test_collection",
            database_=processor.db
        )
@@ -166,12 +165,12 @@ class TestMemgraphStorageProcessor:
         processor.io.execute_query.return_value = mock_result
-        processor.create_literal(test_value, "test_user", "test_collection")
+        processor.create_literal(test_value, "test_workspace", "test_collection")
         processor.io.execute_query.assert_called_once_with(
-            "MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
+            "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
            value=test_value,
-            user="test_user",
+            workspace="test_workspace",
            collection="test_collection",
            database_=processor.db
        )
@@ -190,14 +189,14 @@ class TestMemgraphStorageProcessor:
         processor.io.execute_query.return_value = mock_result
-        processor.relate_node(src_uri, pred_uri, dest_uri, "test_user", "test_collection")
+        processor.relate_node(src_uri, pred_uri, dest_uri, "test_workspace", "test_collection")
         processor.io.execute_query.assert_called_once_with(
-            "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
-            "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
-            "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
+            "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
+            "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
+            "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
            src=src_uri, dest=dest_uri, uri=pred_uri,
-            user="test_user", collection="test_collection",
+            workspace="test_workspace", collection="test_collection",
            database_=processor.db
        )
@@ -215,14 +214,14 @@ class TestMemgraphStorageProcessor:
         processor.io.execute_query.return_value = mock_result
-        processor.relate_literal(src_uri, pred_uri, literal_value, "test_user", "test_collection")
+        processor.relate_literal(src_uri, pred_uri, literal_value, "test_workspace", "test_collection")
         processor.io.execute_query.assert_called_once_with(
-            "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
-            "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
-            "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
+            "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
+            "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
+            "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
            src=src_uri, dest=literal_value, uri=pred_uri,
-            user="test_user", collection="test_collection",
+            workspace="test_workspace", collection="test_collection",
            database_=processor.db
        )
@@ -236,22 +235,22 @@ class TestMemgraphStorageProcessor:
            o=Term(type=IRI, iri='http://example.com/object')
        )
-        processor.create_triple(mock_tx, triple, "test_user", "test_collection")
+        processor.create_triple(mock_tx, triple, "test_workspace", "test_collection")
         # Verify transaction calls
         expected_calls = [
            # Create subject node
-            ("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
-             {'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}),
+            ("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
+             {'uri': 'http://example.com/subject', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
            # Create object node
-            ("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
-             {'uri': 'http://example.com/object', 'user': 'test_user', 'collection': 'test_collection'}),
+            ("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
+             {'uri': 'http://example.com/object', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
            # Create relationship
-            ("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
-             "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
-             "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
+            ("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
+             "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
+             "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
             {'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate',
-              'user': 'test_user', 'collection': 'test_collection'})
+              'workspace': 'test_workspace', 'collection': 'test_collection'})
        ]
         assert mock_tx.run.call_count == 3
@@ -270,22 +269,22 @@ class TestMemgraphStorageProcessor:
            o=Term(type=LITERAL, value='literal object')
        )
-        processor.create_triple(mock_tx, triple, "test_user", "test_collection")
+        processor.create_triple(mock_tx, triple, "test_workspace", "test_collection")
         # Verify transaction calls
         expected_calls = [
            # Create subject node
-            ("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
-             {'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}),
+            ("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
+             {'uri': 'http://example.com/subject', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
            # Create literal object
-            ("MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
-             {'value': 'literal object', 'user': 'test_user', 'collection': 'test_collection'}),
+            ("MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
+             {'value': 'literal object', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
            # Create relationship
-            ("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
-             "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
-             "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
+            ("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
+             "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
+             "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
             {'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate',
-              'user': 'test_user', 'collection': 'test_collection'})
+              'workspace': 'test_workspace', 'collection': 'test_collection'})
        ]
         assert mock_tx.run.call_count == 3
@@ -314,8 +313,8 @@ class TestMemgraphStorageProcessor:
         with patch.object(processor, 'collection_exists', return_value=True):
-            await processor.store_triples(mock_message)
await processor.store_triples('test_workspace', mock_message)
# Verify execute_query was called for create_node, create_literal, and relate_literal
# (since mock_message has a literal object)
assert processor.io.execute_query.call_count == 3
@ -323,7 +322,7 @@ class TestMemgraphStorageProcessor:
# Verify workspace/collection parameters were included
for call in processor.io.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
assert 'user' in call_kwargs
assert 'workspace' in call_kwargs
assert 'collection' in call_kwargs
@pytest.mark.asyncio
@ -343,7 +342,6 @@ class TestMemgraphStorageProcessor:
# Create message with multiple triples
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple1 = Triple(
@ -364,7 +362,7 @@ class TestMemgraphStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
await processor.store_triples('test_workspace', message)
# Verify execute_query was called:
# Triple1: create_node(s) + create_literal(o) + relate_literal = 3 calls
@ -375,7 +373,7 @@ class TestMemgraphStorageProcessor:
# Verify workspace/collection parameters were included in all calls
for call in processor.io.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
assert call_kwargs['user'] == 'test_user'
assert call_kwargs['workspace'] == 'test_workspace'
assert call_kwargs['collection'] == 'test_collection'
@pytest.mark.asyncio
@ -389,7 +387,6 @@ class TestMemgraphStorageProcessor:
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.triples = []
@ -399,7 +396,7 @@ class TestMemgraphStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
await processor.store_triples('test_workspace', message)
# Verify no session calls were made (no triples to process)
processor.io.session.assert_not_called()
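The Memgraph assertions above all check the same parameter shape: the Cypher query text and its parameter dict must both carry `workspace` (from the trusted `flow.workspace`) where `user` used to be. A minimal sketch of that shape — `relate_node_params` is a hypothetical helper name, not the actual implementation:

```python
def relate_node_params(workspace: str, collection: str,
                       src: str, dest: str, uri: str) -> dict:
    """Build the Cypher parameters the tests above assert on.

    After the migration, workspace replaces user as the tenancy key;
    the query string and this parameter dict must stay in sync.
    """
    return {
        "src": src,
        "dest": dest,
        "uri": uri,
        "workspace": workspace,
        "collection": collection,
    }
```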


@ -68,9 +68,9 @@ class TestNeo4jStorageProcessor:
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
"CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
"CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)",
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)",
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)",
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)",
"CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)",
"CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)",
"CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)",
"CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)"
]
@ -116,12 +116,12 @@ class TestNeo4jStorageProcessor:
processor = Processor(taskgroup=taskgroup_mock)
# Test create_node
processor.create_node("http://example.com/node", "test_user", "test_collection")
processor.create_node("http://example.com/node", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/node",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
@ -146,12 +146,12 @@ class TestNeo4jStorageProcessor:
processor = Processor(taskgroup=taskgroup_mock)
# Test create_literal
processor.create_literal("literal value", "test_user", "test_collection")
processor.create_literal("literal value", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="literal value",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
@ -180,18 +180,18 @@ class TestNeo4jStorageProcessor:
"http://example.com/subject",
"http://example.com/predicate",
"http://example.com/object",
"test_user",
"test_workspace",
"test_collection"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject",
dest="http://example.com/object",
uri="http://example.com/predicate",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
@ -220,18 +220,18 @@ class TestNeo4jStorageProcessor:
"http://example.com/subject",
"http://example.com/predicate",
"literal value",
"test_user",
"test_workspace",
"test_collection"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject",
dest="literal value",
uri="http://example.com/predicate",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
@ -268,36 +268,35 @@ class TestNeo4jStorageProcessor:
# Create mock message with metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Verify create_node was called for subject and object
# Verify relate_node was called
expected_calls = [
# Subject node creation
(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
{"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
),
# Object node creation
(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
{"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{"uri": "http://example.com/object", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
),
# Relationship creation
(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
{
"src": "http://example.com/subject",
"dest": "http://example.com/object",
"uri": "http://example.com/predicate",
"user": "test_user",
"workspace": "test_workspace",
"collection": "test_collection",
"database_": "neo4j"
}
@ -340,12 +339,11 @@ class TestNeo4jStorageProcessor:
# Create mock message with metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Verify create_node was called for subject
# Verify create_literal was called for object
@ -353,24 +351,24 @@ class TestNeo4jStorageProcessor:
expected_calls = [
# Subject node creation
(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
{"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
),
# Literal creation
(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
{"value": "literal value", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
{"value": "literal value", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
),
# Relationship creation
(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
{
"src": "http://example.com/subject",
"dest": "literal value",
"uri": "http://example.com/predicate",
"user": "test_user",
"workspace": "test_workspace",
"collection": "test_collection",
"database_": "neo4j"
}
@ -421,12 +419,11 @@ class TestNeo4jStorageProcessor:
# Create mock message with metadata
mock_message = MagicMock()
mock_message.triples = [triple1, triple2]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Should have processed both triples
# Triple1: 2 nodes + 1 relationship = 3 calls
@ -449,12 +446,11 @@ class TestNeo4jStorageProcessor:
# Create mock message with empty triples and metadata
mock_message = MagicMock()
mock_message.triples = []
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Should not have made any execute_query calls beyond index creation
# Only index creation calls should have been made during initialization
@ -568,38 +564,37 @@ class TestNeo4jStorageProcessor:
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Verify the triple was processed with special characters preserved
mock_driver.execute_query.assert_any_call(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/subject with spaces",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
mock_driver.execute_query.assert_any_call(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value='literal with "quotes" and unicode: ñáéíóú',
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
mock_driver.execute_query.assert_any_call(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject with spaces",
dest='literal with "quotes" and unicode: ñáéíóú',
uri="http://example.com/predicate:with/symbols",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)


@ -24,11 +24,10 @@ def _make_processor(qdrant_client=None):
return proc
def _make_request(vector=None, user="test-user", collection="test-col",
def _make_request(vector=None, collection="test-col",
schema_name="customers", limit=10, index_name=None):
return RowEmbeddingsRequest(
vector=vector or [0.1, 0.2, 0.3],
user=user,
collection=collection,
schema_name=schema_name,
limit=limit,
@ -36,6 +35,14 @@ def _make_request(vector=None, user="test-user", collection="test-col",
)
def _make_flow(workspace="test-workspace", pub=None):
"""Make a mock flow object that is callable and has .workspace."""
flow = MagicMock()
flow.return_value = pub if pub is not None else AsyncMock()
flow.workspace = workspace
return flow
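The helper above can be exercised in isolation; a self-contained sketch (with `make_flow` as a stand-in name mirroring `_make_flow`) shows the two properties the tests rely on — the mock is callable like the real flow object, and it carries the trusted `.workspace` attribute that processors now read instead of a request field:

```python
from unittest.mock import AsyncMock, MagicMock

def make_flow(workspace="test-workspace", pub=None):
    """Mock flow: callable, and carrying the trusted .workspace attribute."""
    flow = MagicMock()
    # Any call like flow("response") returns the same publisher mock.
    flow.return_value = pub if pub is not None else AsyncMock()
    flow.workspace = workspace
    return flow

flow = make_flow()
# Processors read tenancy from the flow, never from client-supplied fields:
workspace = flow.workspace
```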
def _make_search_point(index_name, index_value, text, score):
point = MagicMock()
point.payload = {
@ -85,34 +92,33 @@ class TestFindCollection:
def test_finds_matching_collection(self):
proc = _make_processor()
mock_coll = MagicMock()
mock_coll.name = "rows_test_user_test_col_customers_384"
mock_coll.name = "rows_test_workspace_test_col_customers_384"
mock_collections = MagicMock()
mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-user", "test-col", "customers")
result = proc.find_collection("test-workspace", "test-col", "customers")
# Prefix: rows_test_workspace_test_col_customers_
assert result == "rows_test_user_test_col_customers_384"
assert result == "rows_test_workspace_test_col_customers_384"
def test_returns_none_when_no_match(self):
proc = _make_processor()
mock_coll = MagicMock()
mock_coll.name = "rows_other_user_other_col_schema_768"
mock_coll.name = "rows_other_workspace_other_col_schema_768"
mock_collections = MagicMock()
mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-user", "test-col", "customers")
result = proc.find_collection("test-workspace", "test-col", "customers")
assert result is None
def test_returns_none_on_error(self):
proc = _make_processor()
proc.qdrant.get_collections.side_effect = Exception("connection error")
result = proc.find_collection("user", "col", "schema")
result = proc.find_collection("workspace", "col", "schema")
assert result is None
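The collection names asserted above follow a `rows_<workspace>_<collection>_<schema>_<dimension>` pattern, with hyphens apparently normalised to underscores ("test-workspace" becomes "test_workspace" — an inference from the fixtures, not a documented contract). A hedged sketch of the prefix match `find_collection` appears to perform:

```python
def rows_collection_prefix(workspace: str, collection: str,
                           schema_name: str) -> str:
    # Hypothetical helper: normalise hyphens as the fixture names suggest,
    # leaving the trailing dimension open for the prefix match.
    ws = workspace.replace("-", "_")
    col = collection.replace("-", "_")
    return f"rows_{ws}_{col}_{schema_name}_"

def find_matching_collection(names, workspace, collection, schema_name):
    """Return the first collection name matching the prefix, else None."""
    prefix = rows_collection_prefix(workspace, collection, schema_name)
    return next((n for n in names if n.startswith(prefix)), None)
```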
@ -127,7 +133,7 @@ class TestQueryRowEmbeddings:
proc = _make_processor()
request = _make_request(vector=[])
result = await proc.query_row_embeddings(request)
result = await proc.query_row_embeddings("test-workspace", request)
assert result == []
@pytest.mark.asyncio
@ -136,13 +142,13 @@ class TestQueryRowEmbeddings:
proc.find_collection = MagicMock(return_value=None)
request = _make_request()
result = await proc.query_row_embeddings(request)
result = await proc.query_row_embeddings("test-workspace", request)
assert result == []
@pytest.mark.asyncio
async def test_successful_query_returns_matches(self):
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
points = [
_make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
@ -153,7 +159,7 @@ class TestQueryRowEmbeddings:
proc.qdrant.query_points.return_value = mock_result
request = _make_request()
result = await proc.query_row_embeddings(request)
result = await proc.query_row_embeddings("test-workspace", request)
assert len(result) == 2
assert isinstance(result[0], RowIndexMatch)
@ -166,14 +172,14 @@ class TestQueryRowEmbeddings:
async def test_index_name_filter_applied(self):
"""When index_name is specified, a Qdrant filter should be used."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
mock_result = MagicMock()
mock_result.points = []
proc.qdrant.query_points.return_value = mock_result
request = _make_request(index_name="address")
await proc.query_row_embeddings(request)
await proc.query_row_embeddings("test-workspace", request)
call_kwargs = proc.qdrant.query_points.call_args[1]
assert call_kwargs["query_filter"] is not None
@ -182,14 +188,14 @@ class TestQueryRowEmbeddings:
async def test_no_index_name_no_filter(self):
"""When index_name is empty, no filter should be applied."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
mock_result = MagicMock()
mock_result.points = []
proc.qdrant.query_points.return_value = mock_result
request = _make_request(index_name="")
await proc.query_row_embeddings(request)
await proc.query_row_embeddings("test-workspace", request)
call_kwargs = proc.qdrant.query_points.call_args[1]
assert call_kwargs["query_filter"] is None
@ -198,7 +204,7 @@ class TestQueryRowEmbeddings:
async def test_missing_payload_fields_default(self):
"""Points with missing payload fields should use defaults."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
point = MagicMock()
point.payload = {} # Empty payload
@ -209,7 +215,7 @@ class TestQueryRowEmbeddings:
proc.qdrant.query_points.return_value = mock_result
request = _make_request()
result = await proc.query_row_embeddings(request)
result = await proc.query_row_embeddings("test-workspace", request)
assert len(result) == 1
assert result[0].index_name == ""
@ -219,13 +225,13 @@ class TestQueryRowEmbeddings:
@pytest.mark.asyncio
async def test_qdrant_error_propagates(self):
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
proc.qdrant.query_points.side_effect = Exception("qdrant down")
request = _make_request()
with pytest.raises(Exception, match="qdrant down"):
await proc.query_row_embeddings(request)
await proc.query_row_embeddings("test-workspace", request)
# ---------------------------------------------------------------------------
@ -243,7 +249,7 @@ class TestOnMessage:
])
mock_pub = AsyncMock()
flow = lambda name: mock_pub
flow = _make_flow(pub=mock_pub)
msg = MagicMock()
msg.value.return_value = _make_request()
@ -264,7 +270,7 @@ class TestOnMessage:
)
mock_pub = AsyncMock()
flow = lambda name: mock_pub
flow = _make_flow(pub=mock_pub)
msg = MagicMock()
msg.value.return_value = _make_request()
@ -284,7 +290,7 @@ class TestOnMessage:
proc.query_row_embeddings = AsyncMock(return_value=[])
mock_pub = AsyncMock()
flow = lambda name: mock_pub
flow = _make_flow(pub=mock_pub)
msg = MagicMock()
msg.value.return_value = _make_request()
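The `TestOnMessage` changes all replace a bare `lambda name: mock_pub` with `_make_flow`, because the handler now needs `flow.workspace` as well as the publisher. A minimal self-contained sketch of that handler shape (the names `handle_message` and `query_row_embeddings` are assumptions about the structure, not the actual code):

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock

async def handle_message(msg, flow, query_row_embeddings):
    # Workspace comes from the trusted flow, not from client-supplied fields.
    request = msg.value()
    matches = await query_row_embeddings(flow.workspace, request)
    # The flow is callable and yields the response publisher.
    pub = flow("response")
    await pub(matches)
    return matches
```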


@ -45,12 +45,9 @@ class TestGetGraphEmbeddings:
with `vector=` (singular) the schema field name. A previous
version used `vectors=` and raised a TypeError at runtime.
"""
# Arrange — fake row matching the get_triples_stmt result shape:
# row[0..2] are unused by the method, row[3] is the entities blob
fake_row = (
None, None, None,
[
# ((value, is_uri), vector)
(("http://example.org/alice", True), [0.1, 0.2, 0.3]),
(("http://example.org/bob", True), [0.4, 0.5, 0.6]),
(("a literal entity", False), [0.7, 0.8, 0.9]),
@ -67,14 +64,8 @@ class TestGetGraphEmbeddings:
async def receiver(msg):
received.append(msg)
# Act
await store.get_graph_embeddings(
user="alice",
document_id="doc-1",
receiver=receiver,
)
await store.get_graph_embeddings("alice", "doc-1", receiver)
# Assert
mock_async_execute.assert_called_once_with(
store.cassandra,
store.get_graph_embeddings_stmt,
@ -86,7 +77,6 @@ class TestGetGraphEmbeddings:
assert isinstance(ge, GraphEmbeddings)
assert isinstance(ge.metadata, Metadata)
assert ge.metadata.id == "doc-1"
assert ge.metadata.user == "alice"
assert len(ge.entities) == 3
assert all(isinstance(e, EntityEmbeddings) for e in ge.entities)
@ -122,7 +112,7 @@ class TestGetGraphEmbeddings:
async def receiver(msg):
received.append(msg)
await store.get_graph_embeddings("u", "d", receiver)
await store.get_graph_embeddings("w", "d", receiver)
assert len(received) == 1
assert received[0].entities == []
@ -149,7 +139,7 @@ class TestGetGraphEmbeddings:
async def receiver(msg):
received.append(msg)
await store.get_graph_embeddings("u", "d", receiver)
await store.get_graph_embeddings("w", "d", receiver)
assert len(received) == 2
assert received[0].entities[0].entity.iri == "http://example.org/a"
@ -194,7 +184,6 @@ class TestGetTriples:
assert isinstance(triples_msg, Triples)
assert isinstance(triples_msg.metadata, Metadata)
assert triples_msg.metadata.id == "doc-1"
assert triples_msg.metadata.user == "alice"
assert len(triples_msg.triples) == 1
t = triples_msg.triples[0]


@ -30,7 +30,6 @@ def sample():
metadata=Metadata(
id="doc-1",
root="",
user="alice",
collection="testcoll",
),
chunks=[
@ -56,7 +55,6 @@ class TestDocumentEmbeddingsTranslator:
assert isinstance(decoded, DocumentEmbeddings)
assert isinstance(decoded.metadata, Metadata)
assert decoded.metadata.id == "doc-1"
assert decoded.metadata.user == "alice"
assert decoded.metadata.collection == "testcoll"
assert len(decoded.chunks) == 2


@ -41,7 +41,7 @@ def translator():
def graph_embeddings_request():
return KnowledgeRequest(
operation="put-kg-core",
user="alice",
workspace="alice",
id="doc-1",
flow="default",
collection="testcoll",
@ -49,7 +49,6 @@ def graph_embeddings_request():
metadata=Metadata(
id="doc-1",
root="",
user="alice",
collection="testcoll",
),
entities=[
@ -70,7 +69,6 @@ def graph_embeddings_request():
def triples_request():
return KnowledgeRequest(
operation="put-kg-core",
user="alice",
id="doc-1",
flow="default",
collection="testcoll",
@ -78,7 +76,6 @@ def triples_request():
metadata=Metadata(
id="doc-1",
root="",
user="alice",
collection="testcoll",
),
triples=[
@ -113,7 +110,7 @@ class TestKnowledgeRequestTranslatorGraphEmbeddings:
assert isinstance(decoded, KnowledgeRequest)
assert decoded.operation == "put-kg-core"
assert decoded.user == "alice"
assert decoded.workspace == "alice"
assert decoded.id == "doc-1"
assert decoded.flow == "default"
assert decoded.collection == "testcoll"
@ -123,7 +120,6 @@ class TestKnowledgeRequestTranslatorGraphEmbeddings:
assert isinstance(ge, GraphEmbeddings)
assert isinstance(ge.metadata, Metadata)
assert ge.metadata.id == "doc-1"
assert ge.metadata.user == "alice"
assert ge.metadata.collection == "testcoll"
assert len(ge.entities) == 2
@ -143,7 +139,6 @@ class TestKnowledgeRequestTranslatorTriples:
assert decoded.triples is not None
assert isinstance(decoded.triples.metadata, Metadata)
assert decoded.triples.metadata.id == "doc-1"
assert decoded.triples.metadata.user == "alice"
assert decoded.triples.metadata.collection == "testcoll"
assert len(decoded.triples.triples) == 1
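The translator assertions above reduce to one rule: `user` is no longer serialised anywhere in the metadata. A hedged sketch of that encoding rule (field names taken from the fixtures; `encode_metadata` itself is a hypothetical illustration, not the translator API):

```python
def encode_metadata(metadata: dict) -> dict:
    """Serialise only the fields the translators still carry;
    any client-supplied 'user' key is dropped."""
    keep = ("id", "root", "collection")
    return {k: metadata[k] for k in keep if k in metadata}
```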