mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-29 18:36:22 +02:00
feat: workspace-based multi-tenancy, replacing user as tenancy axis (#840)
Introduces `workspace` as the isolation boundary for config, flows,
library, and knowledge data. Removes `user` as a schema-level field
throughout the code, API specs, and tests; workspace provides the
same separation more cleanly at the trusted flow.workspace layer
rather than through client-supplied message fields.
Design
------
- IAM tech spec (docs/tech-specs/iam.md) documents current state,
proposed auth/access model, and migration direction.
- Data ownership model (docs/tech-specs/data-ownership-model.md)
captures the workspace/collection/flow hierarchy.
Schema + messaging
------------------
- Drop `user` field from AgentRequest/Step, GraphRagQuery,
DocumentRagQuery, Triples/Graph/Document/Row EmbeddingsRequest,
Sparql/Rows/Structured QueryRequest, ToolServiceRequest.
- Keep collection/workspace routing via flow.workspace at the
service layer.
- Translators updated to not serialise/deserialise user.
API specs
---------
- OpenAPI schemas and path examples cleaned of user fields.
- Websocket async-api messages updated.
- Removed the unused parameters/User.yaml.
Services + base
---------------
- Librarian, collection manager, knowledge, config: all operations
scoped by workspace. Config client API takes workspace as first
positional arg.
- `flow.workspace` set at flow start time by the infrastructure;
no longer pass-through from clients.
- Tool service drops user-personalisation passthrough.
CLI + SDK
---------
- tg-init-workspace and workspace-aware import/export.
- All tg-* commands drop user args; accept --workspace.
- Python API/SDK (flow, socket_client, async_*, explainability,
library) drop user kwargs from every method signature.
MCP server
----------
- All tool endpoints drop user parameters; socket_manager no longer
keyed per user.
Flow service
------------
- Closure-based topic cleanup on flow stop: only delete topics
whose blueprint template was parameterised AND no remaining
live flow (across all workspaces) still resolves to that topic.
Three scopes fall out naturally from template analysis:
* {id} -> per-flow, deleted on stop
* {blueprint} -> per-blueprint, kept while any flow of the
same blueprint exists
* {workspace} -> per-workspace, kept while any flow in the
workspace exists
* literal -> global, never deleted (e.g. tg.request.librarian)
Fixes a bug where stopping a flow silently destroyed the global
librarian exchange, wedging all library operations until manual
restart.
RabbitMQ backend
----------------
- heartbeat=60, blocked_connection_timeout=300. Catches silently
dead connections (broker restart, orphaned channels, network
partitions) within ~2 heartbeat windows, so the consumer
reconnects and re-binds its queue rather than sitting forever
on a zombie connection.
Tests
-----
- Full test refresh: unit, integration, contract, provenance.
- Dropped user-field assertions and constructor kwargs across
~100 test files.
- Renamed user-collection isolation tests to workspace-collection.
This commit is contained in:
parent
9332089b3d
commit
d35473f7f7
377 changed files with 6868 additions and 5785 deletions
|
|
@ -37,6 +37,9 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup mock agent manager
|
||||
mock_agent_instance = AsyncMock()
|
||||
mock_agent_manager_class.return_value = mock_agent_instance
|
||||
mock_agent_instance.tools = {}
|
||||
mock_agent_instance.additional_context = ""
|
||||
processor.agents["default"] = mock_agent_instance
|
||||
|
||||
# Mock react to call think and observe callbacks
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
|
||||
|
|
@ -50,7 +53,6 @@ class TestAgentServiceNonStreaming:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = AgentRequest(
|
||||
question="What is 2 + 2?",
|
||||
user="trustgraph",
|
||||
streaming=False # Non-streaming mode
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
|
@ -58,6 +60,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup flow mock
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
|
||||
|
|
@ -129,6 +132,9 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup mock agent manager
|
||||
mock_agent_instance = AsyncMock()
|
||||
mock_agent_manager_class.return_value = mock_agent_instance
|
||||
mock_agent_instance.tools = {}
|
||||
mock_agent_instance.additional_context = ""
|
||||
processor.agents["default"] = mock_agent_instance
|
||||
|
||||
# Mock react to return Final directly
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
|
||||
|
|
@ -140,7 +146,6 @@ class TestAgentServiceNonStreaming:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = AgentRequest(
|
||||
question="What is 2 + 2?",
|
||||
user="trustgraph",
|
||||
streaming=False # Non-streaming mode
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
|
@ -148,6 +153,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Setup flow mock
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
|
||||
|
|
|
|||
|
|
@ -11,13 +11,12 @@ from trustgraph.schema import AgentRequest, AgentStep
|
|||
from trustgraph.agent.orchestrator.aggregator import Aggregator
|
||||
|
||||
|
||||
def _make_request(question="Test question", user="testuser",
|
||||
def _make_request(question="Test question",
|
||||
collection="default", streaming=False,
|
||||
session_id="parent-session", task_type="research",
|
||||
framing="test framing", conversation_id="conv-1"):
|
||||
return AgentRequest(
|
||||
question=question,
|
||||
user=user,
|
||||
collection=collection,
|
||||
streaming=streaming,
|
||||
session_id=session_id,
|
||||
|
|
@ -127,7 +126,6 @@ class TestBuildSynthesisRequest:
|
|||
req = agg.build_synthesis_request(
|
||||
"corr-1",
|
||||
original_question="Original question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
|
||||
|
|
@ -148,7 +146,7 @@ class TestBuildSynthesisRequest:
|
|||
agg.record_completion("corr-1", "goal-b", "answer-b")
|
||||
|
||||
req = agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
"corr-1", "question", "default",
|
||||
)
|
||||
|
||||
# Last history step should be the synthesis step
|
||||
|
|
@ -168,7 +166,7 @@ class TestBuildSynthesisRequest:
|
|||
agg.record_completion("corr-1", "goal-a", "answer-a")
|
||||
|
||||
agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
"corr-1", "question", "default",
|
||||
)
|
||||
|
||||
# Entry should be removed
|
||||
|
|
@ -178,7 +176,7 @@ class TestBuildSynthesisRequest:
|
|||
agg = Aggregator()
|
||||
with pytest.raises(RuntimeError, match="No results"):
|
||||
agg.build_synthesis_request(
|
||||
"unknown", "question", "user", "default",
|
||||
"unknown", "question", "default",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from trustgraph.agent.orchestrator.aggregator import Aggregator
|
|||
def _make_request(**kwargs):
|
||||
defaults = dict(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
|
|
@ -130,7 +129,6 @@ class TestAggregatorIntegration:
|
|||
synth = agg.build_synthesis_request(
|
||||
"corr-1",
|
||||
original_question="Original question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
|
||||
|
|
@ -160,7 +158,7 @@ class TestAggregatorIntegration:
|
|||
agg.record_completion("corr-1", "goal", "answer")
|
||||
|
||||
synth = agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
"corr-1", "question", "default",
|
||||
)
|
||||
|
||||
# correlation_id must be empty so it's not intercepted
|
||||
|
|
|
|||
|
|
@ -126,7 +126,6 @@ def make_base_request(**kwargs):
|
|||
state="",
|
||||
group=[],
|
||||
history=[],
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id="test-session-123",
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ class MockProcessor:
|
|||
def _make_request(**kwargs):
|
||||
defaults = dict(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
|
|
|
|||
|
|
@ -167,39 +167,28 @@ class TestToolServiceRequest:
|
|||
"""Test cases for tool service request format"""
|
||||
|
||||
def test_request_format(self):
|
||||
"""Test that request is properly formatted with user, config, and arguments"""
|
||||
# Arrange
|
||||
user = "alice"
|
||||
"""Test that request is properly formatted with config and arguments"""
|
||||
config_values = {"style": "pun", "collection": "jokes"}
|
||||
arguments = {"topic": "programming"}
|
||||
|
||||
# Act - simulate request building
|
||||
request = {
|
||||
"user": user,
|
||||
"config": json.dumps(config_values),
|
||||
"arguments": json.dumps(arguments)
|
||||
}
|
||||
|
||||
# Assert
|
||||
assert request["user"] == "alice"
|
||||
assert json.loads(request["config"]) == {"style": "pun", "collection": "jokes"}
|
||||
assert json.loads(request["arguments"]) == {"topic": "programming"}
|
||||
|
||||
def test_request_with_empty_config(self):
|
||||
"""Test request when no config values are provided"""
|
||||
# Arrange
|
||||
user = "bob"
|
||||
config_values = {}
|
||||
arguments = {"query": "test"}
|
||||
|
||||
# Act
|
||||
request = {
|
||||
"user": user,
|
||||
"config": json.dumps(config_values) if config_values else "{}",
|
||||
"arguments": json.dumps(arguments) if arguments else "{}"
|
||||
}
|
||||
|
||||
# Assert
|
||||
assert request["config"] == "{}"
|
||||
assert json.loads(request["arguments"]) == {"query": "test"}
|
||||
|
||||
|
|
@ -386,18 +375,13 @@ class TestJokeServiceLogic:
|
|||
assert map_topic_to_category("random topic") == "default"
|
||||
assert map_topic_to_category("") == "default"
|
||||
|
||||
def test_joke_response_personalization(self):
|
||||
"""Test that joke responses include user personalization"""
|
||||
# Arrange
|
||||
user = "alice"
|
||||
def test_joke_response_format(self):
|
||||
"""Test that joke response is formatted as expected"""
|
||||
style = "pun"
|
||||
joke = "Why do programmers prefer dark mode? Because light attracts bugs!"
|
||||
|
||||
# Act
|
||||
response = f"Hey {user}! Here's a {style} for you:\n\n{joke}"
|
||||
response = f"Here's a {style} for you:\n\n{joke}"
|
||||
|
||||
# Assert
|
||||
assert "Hey alice!" in response
|
||||
assert "pun" in response
|
||||
assert joke in response
|
||||
|
||||
|
|
@ -439,20 +423,14 @@ class TestDynamicToolServiceBase:
|
|||
|
||||
def test_request_parsing(self):
|
||||
"""Test parsing of incoming request"""
|
||||
# Arrange
|
||||
request_data = {
|
||||
"user": "alice",
|
||||
"config": '{"style": "pun"}',
|
||||
"arguments": '{"topic": "programming"}'
|
||||
}
|
||||
|
||||
# Act
|
||||
user = request_data.get("user", "trustgraph")
|
||||
config = json.loads(request_data["config"]) if request_data["config"] else {}
|
||||
arguments = json.loads(request_data["arguments"]) if request_data["arguments"] else {}
|
||||
|
||||
# Assert
|
||||
assert user == "alice"
|
||||
assert config == {"style": "pun"}
|
||||
assert arguments == {"topic": "programming"}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Tests for tool service lifecycle, invoke contract, streaming responses,
|
||||
multi-tenancy, and error propagation.
|
||||
and error propagation.
|
||||
|
||||
Tests the actual DynamicToolService, ToolService, and ToolServiceClient
|
||||
classes rather than plain dicts.
|
||||
|
|
@ -31,7 +31,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
await svc.invoke("user", {}, {})
|
||||
await svc.invoke({}, {})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_calls_invoke_with_parsed_args(self):
|
||||
|
|
@ -44,8 +44,8 @@ class TestDynamicToolServiceInvokeContract:
|
|||
|
||||
calls = []
|
||||
|
||||
async def tracking_invoke(user, config, arguments):
|
||||
calls.append({"user": user, "config": config, "arguments": arguments})
|
||||
async def tracking_invoke(config, arguments):
|
||||
calls.append({"config": config, "arguments": arguments})
|
||||
return "ok"
|
||||
|
||||
svc.invoke = tracking_invoke
|
||||
|
|
@ -56,7 +56,6 @@ class TestDynamicToolServiceInvokeContract:
|
|||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(
|
||||
user="alice",
|
||||
config='{"style": "pun"}',
|
||||
arguments='{"topic": "cats"}',
|
||||
)
|
||||
|
|
@ -65,39 +64,9 @@ class TestDynamicToolServiceInvokeContract:
|
|||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["user"] == "alice"
|
||||
assert calls[0]["config"] == {"style": "pun"}
|
||||
assert calls[0]["arguments"] == {"topic": "cats"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_empty_user_defaults_to_trustgraph(self):
|
||||
"""Empty user field should default to 'trustgraph'."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
received_user = None
|
||||
|
||||
async def capture_invoke(user, config, arguments):
|
||||
nonlocal received_user
|
||||
received_user = user
|
||||
return "ok"
|
||||
|
||||
svc.invoke = capture_invoke
|
||||
|
||||
if not hasattr(DynamicToolService, "tool_service_metric"):
|
||||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="", config="", arguments="")
|
||||
msg.properties.return_value = {"id": "req-2"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
assert received_user == "trustgraph"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_string_response_sent_directly(self):
|
||||
"""String return from invoke → response field is the string."""
|
||||
|
|
@ -107,7 +76,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def string_invoke(user, config, arguments):
|
||||
async def string_invoke(config, arguments):
|
||||
return "hello world"
|
||||
|
||||
svc.invoke = string_invoke
|
||||
|
|
@ -116,7 +85,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r1"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -136,7 +105,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def dict_invoke(user, config, arguments):
|
||||
async def dict_invoke(config, arguments):
|
||||
return {"result": 42}
|
||||
|
||||
svc.invoke = dict_invoke
|
||||
|
|
@ -145,7 +114,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r2"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -162,13 +131,13 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def failing_invoke(user, config, arguments):
|
||||
async def failing_invoke(config, arguments):
|
||||
raise ValueError("bad input")
|
||||
|
||||
svc.invoke = failing_invoke
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r3"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -188,13 +157,13 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def rate_limited_invoke(user, config, arguments):
|
||||
async def rate_limited_invoke(config, arguments):
|
||||
raise TooManyRequests("rate limited")
|
||||
|
||||
svc.invoke = rate_limited_invoke
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r4"}
|
||||
|
||||
with pytest.raises(TooManyRequests):
|
||||
|
|
@ -209,7 +178,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def ok_invoke(user, config, arguments):
|
||||
async def ok_invoke(config, arguments):
|
||||
return "ok"
|
||||
|
||||
svc.invoke = ok_invoke
|
||||
|
|
@ -218,7 +187,7 @@ class TestDynamicToolServiceInvokeContract:
|
|||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "unique-42"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
|
@ -241,7 +210,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def mock_invoke(name, params):
|
||||
async def mock_invoke(workspace, name, params):
|
||||
return "tool result"
|
||||
|
||||
svc.invoke_tool = mock_invoke
|
||||
|
|
@ -260,6 +229,7 @@ class TestToolServiceOnRequest:
|
|||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
flow_callable.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}')
|
||||
|
|
@ -280,7 +250,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def mock_invoke(name, params):
|
||||
async def mock_invoke(workspace, name, params):
|
||||
return {"data": [1, 2, 3]}
|
||||
|
||||
svc.invoke_tool = mock_invoke
|
||||
|
|
@ -298,6 +268,7 @@ class TestToolServiceOnRequest:
|
|||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
flow_callable.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
|
||||
|
|
@ -317,7 +288,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def failing_invoke(name, params):
|
||||
async def failing_invoke(workspace, name, params):
|
||||
raise RuntimeError("tool broke")
|
||||
|
||||
svc.invoke_tool = failing_invoke
|
||||
|
|
@ -330,6 +301,7 @@ class TestToolServiceOnRequest:
|
|||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
flow_callable.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
|
||||
|
|
@ -350,7 +322,7 @@ class TestToolServiceOnRequest:
|
|||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def rate_limited(name, params):
|
||||
async def rate_limited(workspace, name, params):
|
||||
raise TooManyRequests("slow down")
|
||||
|
||||
svc.invoke_tool = rate_limited
|
||||
|
|
@ -362,6 +334,7 @@ class TestToolServiceOnRequest:
|
|||
flow = MagicMock()
|
||||
flow.producer = {"response": AsyncMock()}
|
||||
flow.name = "test-flow"
|
||||
flow.workspace = "default"
|
||||
|
||||
with pytest.raises(TooManyRequests):
|
||||
await svc.on_request(msg, MagicMock(), flow)
|
||||
|
|
@ -376,7 +349,8 @@ class TestToolServiceOnRequest:
|
|||
|
||||
received = {}
|
||||
|
||||
async def capture_invoke(name, params):
|
||||
async def capture_invoke(workspace, name, params):
|
||||
received["workspace"] = workspace
|
||||
received["name"] = name
|
||||
received["params"] = params
|
||||
return "ok"
|
||||
|
|
@ -390,6 +364,7 @@ class TestToolServiceOnRequest:
|
|||
flow = lambda name: mock_pub
|
||||
flow.producer = {"response": mock_pub}
|
||||
flow.name = "f"
|
||||
flow.workspace = "default"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(
|
||||
|
|
@ -421,7 +396,6 @@ class TestToolServiceClientCall:
|
|||
))
|
||||
|
||||
result = await client.call(
|
||||
user="alice",
|
||||
config={"style": "pun"},
|
||||
arguments={"topic": "cats"},
|
||||
)
|
||||
|
|
@ -430,7 +404,6 @@ class TestToolServiceClientCall:
|
|||
|
||||
req = client.request.call_args[0][0]
|
||||
assert isinstance(req, ToolServiceRequest)
|
||||
assert req.user == "alice"
|
||||
assert json.loads(req.config) == {"style": "pun"}
|
||||
assert json.loads(req.arguments) == {"topic": "cats"}
|
||||
|
||||
|
|
@ -446,7 +419,7 @@ class TestToolServiceClientCall:
|
|||
))
|
||||
|
||||
with pytest.raises(RuntimeError, match="service down"):
|
||||
await client.call(user="u", config={}, arguments={})
|
||||
await client.call(config={}, arguments={})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_empty_config_sends_empty_json(self):
|
||||
|
|
@ -458,7 +431,7 @@ class TestToolServiceClientCall:
|
|||
error=None, response="ok",
|
||||
))
|
||||
|
||||
await client.call(user="u", config=None, arguments=None)
|
||||
await client.call(config=None, arguments=None)
|
||||
|
||||
req = client.request.call_args[0][0]
|
||||
assert req.config == "{}"
|
||||
|
|
@ -474,7 +447,7 @@ class TestToolServiceClientCall:
|
|||
error=None, response="ok",
|
||||
))
|
||||
|
||||
await client.call(user="u", config={}, arguments={}, timeout=30)
|
||||
await client.call(config={}, arguments={}, timeout=30)
|
||||
|
||||
_, kwargs = client.request.call_args
|
||||
assert kwargs["timeout"] == 30
|
||||
|
|
@ -509,7 +482,7 @@ class TestToolServiceClientStreaming:
|
|||
received.append(text)
|
||||
|
||||
result = await client.call_streaming(
|
||||
user="u", config={}, arguments={}, callback=callback,
|
||||
config={}, arguments={}, callback=callback,
|
||||
)
|
||||
|
||||
assert result == "chunk1chunk2"
|
||||
|
|
@ -534,7 +507,7 @@ class TestToolServiceClientStreaming:
|
|||
|
||||
with pytest.raises(RuntimeError, match="stream failed"):
|
||||
await client.call_streaming(
|
||||
user="u", config={}, arguments={},
|
||||
config={}, arguments={},
|
||||
callback=AsyncMock(),
|
||||
)
|
||||
|
||||
|
|
@ -564,61 +537,9 @@ class TestToolServiceClientStreaming:
|
|||
received.append(text)
|
||||
|
||||
result = await client.call_streaming(
|
||||
user="u", config={}, arguments={}, callback=callback,
|
||||
config={}, arguments={}, callback=callback,
|
||||
)
|
||||
|
||||
# Empty response is falsy, so callback shouldn't be called for it
|
||||
assert result == "data"
|
||||
assert received == ["data"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-tenancy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMultiTenancy:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_propagated_to_invoke(self):
|
||||
"""User from request should reach the invoke method."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
users_seen = []
|
||||
|
||||
async def tracking(user, config, arguments):
|
||||
users_seen.append(user)
|
||||
return "ok"
|
||||
|
||||
svc.invoke = tracking
|
||||
|
||||
if not hasattr(DynamicToolService, "tool_service_metric"):
|
||||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
for u in ["tenant-a", "tenant-b", "tenant-c"]:
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(
|
||||
user=u, config="{}", arguments="{}",
|
||||
)
|
||||
msg.properties.return_value = {"id": f"req-{u}"}
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
assert users_seen == ["tenant-a", "tenant-b", "tenant-c"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_sends_user_in_request(self):
|
||||
"""ToolServiceClient.call should include user in request."""
|
||||
from trustgraph.base.tool_service_client import ToolServiceClient
|
||||
|
||||
client = ToolServiceClient.__new__(ToolServiceClient)
|
||||
client.request = AsyncMock(return_value=ToolServiceResponse(
|
||||
error=None, response="ok",
|
||||
))
|
||||
|
||||
await client.call(user="isolated-tenant", config={}, arguments={})
|
||||
|
||||
req = client.request.call_args[0][0]
|
||||
assert req.user == "isolated-tenant"
|
||||
|
|
|
|||
|
|
@ -1,17 +1,14 @@
|
|||
"""
|
||||
Tests for AsyncProcessor config notify pattern:
|
||||
- register_config_handler with types filtering
|
||||
- on_config_notify version comparison and type matching
|
||||
- fetch_config with short-lived client
|
||||
- fetch_and_apply_config retry logic
|
||||
- on_config_notify version comparison, type/workspace matching
|
||||
- fetch_and_apply_config retry logic over per-workspace fetches
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, Mock
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
|
||||
|
||||
# Patch heavy dependencies before importing AsyncProcessor
|
||||
@pytest.fixture
|
||||
def processor():
|
||||
"""Create an AsyncProcessor with mocked dependencies."""
|
||||
|
|
@ -68,6 +65,13 @@ class TestRegisterConfigHandler:
|
|||
assert len(processor.config_handlers) == 2
|
||||
|
||||
|
||||
def _notify_msg(version, changes):
|
||||
"""Build a Mock config-notify message with given version and changes dict."""
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=version, changes=changes)
|
||||
return msg
|
||||
|
||||
|
||||
class TestOnConfigNotify:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -77,9 +81,7 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=3, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(3, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
|
@ -91,9 +93,7 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=5, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(5, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
|
@ -105,9 +105,7 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["schema"])
|
||||
|
||||
msg = _notify_msg(2, {"schema": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
|
@ -121,40 +119,36 @@ class TestOnConfigNotify:
|
|||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
# Mock fetch_config
|
||||
mock_config = {"prompt": {"key": "value"}}
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
return_value={"key": "value"},
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_called_once_with(mock_config, 2)
|
||||
handler.assert_called_once_with(
|
||||
"default", {"prompt": {"key": "value"}}, 2
|
||||
)
|
||||
assert processor.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_without_types_always_called(self, processor):
|
||||
async def test_handler_without_types_ignored_on_notify(self, processor):
|
||||
"""Handlers registered without types never fire on notifications."""
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler) # No types = all
|
||||
processor.register_config_handler(handler) # No types
|
||||
|
||||
mock_config = {"anything": {}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["whatever"])
|
||||
msg = _notify_msg(2, {"whatever": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_called_once_with(mock_config, 2)
|
||||
handler.assert_not_called()
|
||||
# Version still advances past the notify
|
||||
assert processor.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_handlers_type_filtering(self, processor):
|
||||
|
|
@ -168,156 +162,149 @@ class TestOnConfigNotify:
|
|||
processor.register_config_handler(schema_handler, types=["schema"])
|
||||
processor.register_config_handler(all_handler)
|
||||
|
||||
mock_config = {"prompt": {}}
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
return_value={},
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["default"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
prompt_handler.assert_called_once()
|
||||
prompt_handler.assert_called_once_with(
|
||||
"default", {"prompt": {}}, 2
|
||||
)
|
||||
schema_handler.assert_not_called()
|
||||
all_handler.assert_called_once()
|
||||
all_handler.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_types_invokes_all(self, processor):
|
||||
"""Empty types list (startup signal) should invoke all handlers."""
|
||||
async def test_multi_workspace_notify_invokes_handler_per_ws(
|
||||
self, processor
|
||||
):
|
||||
"""Notify affecting multiple workspaces invokes handler once per workspace."""
|
||||
processor.config_version = 1
|
||||
|
||||
h1 = AsyncMock()
|
||||
h2 = AsyncMock()
|
||||
processor.register_config_handler(h1, types=["prompt"])
|
||||
processor.register_config_handler(h2, types=["schema"])
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
mock_config = {}
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
return_value={},
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=[])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["ws1", "ws2"]})
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
h1.assert_called_once()
|
||||
h2.assert_called_once()
|
||||
assert handler.call_count == 2
|
||||
called_workspaces = {c.args[0] for c in handler.call_args_list}
|
||||
assert called_workspaces == {"ws1", "ws2"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_failure_handled(self, processor):
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler)
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_workspace',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("Connection failed")
|
||||
side_effect=RuntimeError("Connection failed"),
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
msg = _notify_msg(2, {"prompt": ["default"]})
|
||||
# Should not raise
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
||||
|
||||
class TestFetchConfig:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_returns_config_and_version(self, processor):
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.config = {"prompt": {"key": "val"}}
|
||||
mock_resp.version = 42
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
config, version = await processor.fetch_config()
|
||||
|
||||
assert config == {"prompt": {"key": "val"}}
|
||||
assert version == 42
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_raises_on_error_response(self, processor):
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = Mock(message="not found")
|
||||
mock_resp.config = {}
|
||||
mock_resp.version = 0
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="Config error"):
|
||||
await processor.fetch_config()
|
||||
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_stops_client_on_exception(self, processor):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.side_effect = TimeoutError("timeout")
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
with pytest.raises(TimeoutError):
|
||||
await processor.fetch_config()
|
||||
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestFetchAndApplyConfig:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_applies_config_to_all_handlers(self, processor):
|
||||
h1 = AsyncMock()
|
||||
h2 = AsyncMock()
|
||||
processor.register_config_handler(h1, types=["prompt"])
|
||||
processor.register_config_handler(h2, types=["schema"])
|
||||
async def test_applies_config_per_workspace(self, processor):
|
||||
"""Startup fetch invokes handler once per workspace affected."""
|
||||
h = AsyncMock()
|
||||
processor.register_config_handler(h, types=["prompt"])
|
||||
|
||||
mock_client = AsyncMock()
|
||||
|
||||
async def fake_fetch_all(client, config_type):
|
||||
return {
|
||||
"ws1": {"k": "v1"},
|
||||
"ws2": {"k": "v2"},
|
||||
}, 10
|
||||
|
||||
mock_config = {"prompt": {}, "schema": {}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 10)
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_all_workspaces',
|
||||
new=fake_fetch_all,
|
||||
):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
# On startup, all handlers are invoked regardless of type
|
||||
h1.assert_called_once_with(mock_config, 10)
|
||||
h2.assert_called_once_with(mock_config, 10)
|
||||
assert h.call_count == 2
|
||||
call_map = {c.args[0]: c.args[1] for c in h.call_args_list}
|
||||
assert call_map["ws1"] == {"prompt": {"k": "v1"}}
|
||||
assert call_map["ws2"] == {"prompt": {"k": "v2"}}
|
||||
assert processor.config_version == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_failure(self, processor):
|
||||
call_count = 0
|
||||
mock_config = {"prompt": {}}
|
||||
async def test_handler_without_types_skipped_at_startup(self, processor):
|
||||
"""Handlers registered without types fetch nothing at startup."""
|
||||
typed = AsyncMock()
|
||||
untyped = AsyncMock()
|
||||
processor.register_config_handler(typed, types=["prompt"])
|
||||
processor.register_config_handler(untyped)
|
||||
|
||||
async def mock_fetch():
|
||||
mock_client = AsyncMock()
|
||||
|
||||
async def fake_fetch_all(client, config_type):
|
||||
return {"default": {}}, 1
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_all_workspaces',
|
||||
new=fake_fetch_all,
|
||||
):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
typed.assert_called_once()
|
||||
untyped.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_failure(self, processor):
|
||||
h = AsyncMock()
|
||||
processor.register_config_handler(h, types=["prompt"])
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def fake_fetch_all(client, config_type):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
raise RuntimeError("not ready")
|
||||
return mock_config, 5
|
||||
return {"default": {"k": "v"}}, 5
|
||||
|
||||
with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \
|
||||
patch('asyncio.sleep', new_callable=AsyncMock):
|
||||
mock_client = AsyncMock()
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
), patch.object(
|
||||
processor, '_fetch_type_all_workspaces',
|
||||
new=fake_fetch_all,
|
||||
), patch('asyncio.sleep', new_callable=AsyncMock):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
assert call_count == 3
|
||||
assert processor.config_version == 5
|
||||
h.assert_called_once_with(
|
||||
"default", {"prompt": {"k": "v"}}, 5
|
||||
)
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
result = await client.query(
|
||||
vector=vector,
|
||||
limit=10,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
timeout=30
|
||||
)
|
||||
|
|
@ -45,7 +44,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
assert isinstance(call_args, DocumentEmbeddingsRequest)
|
||||
assert call_args.vector == vector
|
||||
assert call_args.limit == 10
|
||||
assert call_args.user == "test_user"
|
||||
assert call_args.collection == "test_collection"
|
||||
|
||||
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
|
||||
|
|
@ -104,7 +102,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
client.request.assert_called_once()
|
||||
call_args = client.request.call_args[0][0]
|
||||
assert call_args.limit == 20 # Default limit
|
||||
assert call_args.user == "trustgraph" # Default user
|
||||
assert call_args.collection == "default" # Default collection
|
||||
|
||||
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
|
||||
|
|
|
|||
|
|
@ -40,10 +40,11 @@ def test_flow_initialization_calls_registered_specs():
|
|||
spec_two = MagicMock()
|
||||
processor = MagicMock(specifications=[spec_one, spec_two])
|
||||
|
||||
flow = Flow("processor-1", "flow-a", processor, {"answer": 42})
|
||||
flow = Flow("processor-1", "flow-a", "default", processor, {"answer": 42})
|
||||
|
||||
assert flow.id == "processor-1"
|
||||
assert flow.name == "flow-a"
|
||||
assert flow.workspace == "default"
|
||||
assert flow.producer == {}
|
||||
assert flow.consumer == {}
|
||||
assert flow.parameter == {}
|
||||
|
|
@ -54,7 +55,7 @@ def test_flow_initialization_calls_registered_specs():
|
|||
def test_flow_start_and_stop_visit_all_consumers():
|
||||
consumer_one = AsyncMock()
|
||||
consumer_two = AsyncMock()
|
||||
flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {})
|
||||
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
|
||||
flow.consumer = {"one": consumer_one, "two": consumer_two}
|
||||
|
||||
asyncio.run(flow.start())
|
||||
|
|
@ -67,7 +68,7 @@ def test_flow_start_and_stop_visit_all_consumers():
|
|||
|
||||
|
||||
def test_flow_call_returns_values_in_priority_order():
|
||||
flow = Flow("processor-1", "flow-a", MagicMock(specifications=[]), {})
|
||||
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
|
||||
flow.producer["shared"] = "producer-value"
|
||||
flow.consumer["consumer-only"] = "consumer-value"
|
||||
flow.consumer["shared"] = "consumer-value"
|
||||
|
|
|
|||
|
|
@ -172,10 +172,10 @@ class TestFlowParameterSpecs(IsolatedAsyncioTestCase):
|
|||
flow_defn = {'config': 'test-config'}
|
||||
|
||||
# Act
|
||||
await processor.start_flow(flow_name, flow_defn)
|
||||
await processor.start_flow("default", flow_name, flow_defn)
|
||||
|
||||
# Assert - Flow should be created with access to processor specifications
|
||||
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
|
||||
mock_flow_class.assert_called_once_with('test-processor', flow_name, "default", processor, flow_defn)
|
||||
|
||||
# The flow should have access to the processor's specifications
|
||||
# (The exact mechanism depends on Flow implementation)
|
||||
|
|
|
|||
|
|
@ -78,11 +78,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
flow_name = 'test-flow'
|
||||
flow_defn = {'config': 'test-config'}
|
||||
|
||||
await processor.start_flow(flow_name, flow_defn)
|
||||
await processor.start_flow("default", flow_name, flow_defn)
|
||||
|
||||
assert flow_name in processor.flows
|
||||
assert ("default", flow_name) in processor.flows
|
||||
mock_flow_class.assert_called_once_with(
|
||||
'test-processor', flow_name, processor, flow_defn
|
||||
'test-processor', flow_name, "default", processor, flow_defn
|
||||
)
|
||||
mock_flow.start.assert_called_once()
|
||||
|
||||
|
|
@ -103,11 +103,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
mock_flow_class.return_value = mock_flow
|
||||
|
||||
flow_name = 'test-flow'
|
||||
await processor.start_flow(flow_name, {'config': 'test-config'})
|
||||
await processor.start_flow("default", flow_name, {'config': 'test-config'})
|
||||
|
||||
await processor.stop_flow(flow_name)
|
||||
await processor.stop_flow("default", flow_name)
|
||||
|
||||
assert flow_name not in processor.flows
|
||||
assert ("default", flow_name) not in processor.flows
|
||||
mock_flow.stop.assert_called_once()
|
||||
|
||||
@with_async_processor_patches
|
||||
|
|
@ -120,7 +120,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
await processor.stop_flow('non-existent-flow')
|
||||
await processor.stop_flow("default", 'non-existent-flow')
|
||||
|
||||
assert processor.flows == {}
|
||||
|
||||
|
|
@ -146,11 +146,11 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
await processor.on_configure_flows("default", config_data, version=1)
|
||||
|
||||
assert 'test-flow' in processor.flows
|
||||
assert ("default", 'test-flow') in processor.flows
|
||||
mock_flow_class.assert_called_once_with(
|
||||
'test-processor', 'test-flow', processor,
|
||||
'test-processor', 'test-flow', "default", processor,
|
||||
{'config': 'test-config'}
|
||||
)
|
||||
mock_flow.start.assert_called_once()
|
||||
|
|
@ -171,7 +171,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
await processor.on_configure_flows("default", config_data, version=1)
|
||||
|
||||
assert processor.flows == {}
|
||||
|
||||
|
|
@ -189,7 +189,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
'other-data': 'some-value'
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
await processor.on_configure_flows("default", config_data, version=1)
|
||||
|
||||
assert processor.flows == {}
|
||||
|
||||
|
|
@ -216,7 +216,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data1, version=1)
|
||||
await processor.on_configure_flows("default", config_data1, version=1)
|
||||
|
||||
config_data2 = {
|
||||
'processor:test-processor': {
|
||||
|
|
@ -224,12 +224,12 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data2, version=2)
|
||||
await processor.on_configure_flows("default", config_data2, version=2)
|
||||
|
||||
assert 'flow1' not in processor.flows
|
||||
assert ("default", 'flow1') not in processor.flows
|
||||
mock_flow1.stop.assert_called_once()
|
||||
|
||||
assert 'flow2' in processor.flows
|
||||
assert ("default", 'flow2') in processor.flows
|
||||
mock_flow2.start.assert_called_once()
|
||||
|
||||
@with_async_processor_patches
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ def sample_text_document():
|
|||
"""Sample document with moderate length text."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-1",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
text = "The quick brown fox jumps over the lazy dog. " * 20
|
||||
|
|
@ -43,7 +42,6 @@ def long_text_document():
|
|||
"""Long document for testing multiple chunks."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-long",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
# Create a long text that will definitely be chunked
|
||||
|
|
@ -59,7 +57,6 @@ def unicode_text_document():
|
|||
"""Document with various unicode characters."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-unicode",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
text = """
|
||||
|
|
@ -84,7 +81,6 @@ def empty_text_document():
|
|||
"""Empty document for edge case testing."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-empty",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
return TextDocument(
|
||||
|
|
|
|||
|
|
@ -185,7 +185,6 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_text_doc = MagicMock()
|
||||
mock_text_doc.metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
mock_text_doc.text = b"This is test document content"
|
||||
|
|
|
|||
|
|
@ -185,7 +185,6 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_text_doc = MagicMock()
|
||||
mock_text_doc.metadata = Metadata(
|
||||
id="test-doc-456",
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
mock_text_doc.text = b"This is test document content for token chunking"
|
||||
|
|
|
|||
|
|
@ -109,7 +109,8 @@ class TestListConfigItems:
|
|||
url='http://custom.com',
|
||||
config_type='prompt',
|
||||
format_type='json',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_list_main_uses_defaults(self):
|
||||
|
|
@ -128,7 +129,8 @@ class TestListConfigItems:
|
|||
url='http://localhost:8088/',
|
||||
config_type='prompt',
|
||||
format_type='text',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -196,7 +198,8 @@ class TestGetConfigItem:
|
|||
config_type='prompt',
|
||||
key='template-1',
|
||||
format_type='json',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -253,7 +256,8 @@ class TestPutConfigItem:
|
|||
config_type='prompt',
|
||||
key='new-template',
|
||||
value='Custom prompt: {input}',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_put_main_with_stdin_arg(self):
|
||||
|
|
@ -278,7 +282,8 @@ class TestPutConfigItem:
|
|||
config_type='prompt',
|
||||
key='stdin-template',
|
||||
value=stdin_content,
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_put_main_mutually_exclusive_args(self):
|
||||
|
|
@ -334,7 +339,8 @@ class TestDeleteConfigItem:
|
|||
url='http://custom.com',
|
||||
config_type='prompt',
|
||||
key='old-template',
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ def knowledge_loader():
|
|||
return KnowledgeLoader(
|
||||
files=["test.ttl"],
|
||||
flow="test-flow",
|
||||
user="test-user",
|
||||
workspace="test-user",
|
||||
collection="test-collection",
|
||||
document_id="test-doc-123",
|
||||
url="http://test.example.com/",
|
||||
|
|
@ -64,7 +64,7 @@ class TestKnowledgeLoader:
|
|||
loader = KnowledgeLoader(
|
||||
files=["file1.ttl", "file2.ttl"],
|
||||
flow="my-flow",
|
||||
user="user1",
|
||||
workspace="user1",
|
||||
collection="col1",
|
||||
document_id="doc1",
|
||||
url="http://example.com/",
|
||||
|
|
@ -73,7 +73,7 @@ class TestKnowledgeLoader:
|
|||
|
||||
assert loader.files == ["file1.ttl", "file2.ttl"]
|
||||
assert loader.flow == "my-flow"
|
||||
assert loader.user == "user1"
|
||||
assert loader.workspace == "user1"
|
||||
assert loader.collection == "col1"
|
||||
assert loader.document_id == "doc1"
|
||||
assert loader.url == "http://example.com/"
|
||||
|
|
@ -126,7 +126,7 @@ ex:mary ex:knows ex:bob .
|
|||
loader = KnowledgeLoader(
|
||||
files=[f.name],
|
||||
flow="test-flow",
|
||||
user="test-user",
|
||||
workspace="test-user",
|
||||
collection="test-collection",
|
||||
document_id="test-doc",
|
||||
url="http://test.example.com/"
|
||||
|
|
@ -151,7 +151,7 @@ ex:mary ex:knows ex:bob .
|
|||
loader = KnowledgeLoader(
|
||||
files=[temp_turtle_file],
|
||||
flow="test-flow",
|
||||
user="test-user",
|
||||
workspace="test-user",
|
||||
collection="test-collection",
|
||||
document_id="test-doc",
|
||||
url="http://test.example.com/",
|
||||
|
|
@ -163,7 +163,8 @@ ex:mary ex:knows ex:bob .
|
|||
# Verify Api was created with correct parameters
|
||||
mock_api_class.assert_called_once_with(
|
||||
url="http://test.example.com/",
|
||||
token="test-token"
|
||||
token="test-token",
|
||||
workspace="test-user"
|
||||
)
|
||||
|
||||
# Verify bulk client was obtained
|
||||
|
|
@ -174,7 +175,6 @@ ex:mary ex:knows ex:bob .
|
|||
call_args = mock_bulk.import_triples.call_args
|
||||
assert call_args[1]['flow'] == "test-flow"
|
||||
assert call_args[1]['metadata']['id'] == "test-doc"
|
||||
assert call_args[1]['metadata']['user'] == "test-user"
|
||||
assert call_args[1]['metadata']['collection'] == "test-collection"
|
||||
|
||||
# Verify import_entity_contexts was called
|
||||
|
|
@ -198,7 +198,7 @@ class TestCLIArgumentParsing:
|
|||
'tg-load-knowledge',
|
||||
'-i', 'doc-123',
|
||||
'-f', 'my-flow',
|
||||
'-U', 'my-user',
|
||||
'-w', 'my-user',
|
||||
'-C', 'my-collection',
|
||||
'-u', 'http://custom.example.com/',
|
||||
'-t', 'my-token',
|
||||
|
|
@ -216,7 +216,7 @@ class TestCLIArgumentParsing:
|
|||
token='my-token',
|
||||
flow='my-flow',
|
||||
files=['file1.ttl', 'file2.ttl'],
|
||||
user='my-user',
|
||||
workspace='my-user',
|
||||
collection='my-collection'
|
||||
)
|
||||
|
||||
|
|
@ -242,7 +242,7 @@ class TestCLIArgumentParsing:
|
|||
# Verify defaults were used
|
||||
call_args = mock_loader_class.call_args[1]
|
||||
assert call_args['flow'] == 'default'
|
||||
assert call_args['user'] == 'trustgraph'
|
||||
assert call_args['workspace'] == 'default'
|
||||
assert call_args['collection'] == 'default'
|
||||
assert call_args['url'] == 'http://localhost:8088/'
|
||||
assert call_args['token'] is None
|
||||
|
|
@ -287,7 +287,7 @@ class TestErrorHandling:
|
|||
loader = KnowledgeLoader(
|
||||
files=[temp_turtle_file],
|
||||
flow="test-flow",
|
||||
user="test-user",
|
||||
workspace="test-user",
|
||||
collection="test-collection",
|
||||
document_id="test-doc",
|
||||
url="http://test.example.com/"
|
||||
|
|
|
|||
|
|
@ -145,7 +145,8 @@ class TestSetToolStructuredQuery:
|
|||
group=None,
|
||||
state=None,
|
||||
applicable_states=None,
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_set_main_structured_query_no_arguments_needed(self):
|
||||
|
|
@ -326,7 +327,8 @@ class TestSetToolRowEmbeddingsQuery:
|
|||
group=None,
|
||||
state=None,
|
||||
applicable_states=None,
|
||||
token=None
|
||||
token=None,
|
||||
workspace='default'
|
||||
)
|
||||
|
||||
def test_valid_types_includes_row_embeddings_query(self):
|
||||
|
|
@ -471,7 +473,7 @@ class TestShowToolsStructuredQuery:
|
|||
|
||||
show_main()
|
||||
|
||||
mock_show.assert_called_once_with(url='http://custom.com', token=None)
|
||||
mock_show.assert_called_once_with(url='http://custom.com', token=None, workspace='default')
|
||||
|
||||
|
||||
class TestShowToolsRowEmbeddingsQuery:
|
||||
|
|
|
|||
|
|
@ -73,7 +73,6 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
# Act
|
||||
result = client.request(
|
||||
vector=vector,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
limit=10,
|
||||
timeout=300
|
||||
|
|
@ -82,7 +81,6 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
# Assert
|
||||
assert result == ["chunk1", "chunk2", "chunk3"]
|
||||
client.call.assert_called_once_with(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
vector=vector,
|
||||
limit=10,
|
||||
|
|
@ -108,7 +106,6 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
# Assert
|
||||
assert result == ["test_chunk"]
|
||||
client.call.assert_called_once_with(
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
vector=vector,
|
||||
limit=10,
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ def _make_query(
|
|||
|
||||
query = Query(
|
||||
rag=rag,
|
||||
user="test-user",
|
||||
collection="test-collection",
|
||||
verbose=False,
|
||||
entity_limit=entity_limit,
|
||||
|
|
@ -208,7 +207,6 @@ class TestBatchTripleQueries:
|
|||
assert calls[0].kwargs["p"] is None
|
||||
assert calls[0].kwargs["o"] is None
|
||||
assert calls[0].kwargs["limit"] == 15
|
||||
assert calls[0].kwargs["user"] == "test-user"
|
||||
assert calls[0].kwargs["collection"] == "test-collection"
|
||||
assert calls[0].kwargs["batch_size"] == 20
|
||||
|
||||
|
|
|
|||
|
|
@ -28,10 +28,12 @@ def mock_flow_config():
|
|||
"""Mock flow configuration."""
|
||||
mock_config = Mock()
|
||||
mock_config.flows = {
|
||||
"test-flow": {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test-triples-queue"},
|
||||
"graph-embeddings-store": {"flow": "test-ge-queue"}
|
||||
"test-user": {
|
||||
"test-flow": {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test-triples-queue"},
|
||||
"graph-embeddings-store": {"flow": "test-ge-queue"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -43,7 +45,7 @@ def mock_flow_config():
|
|||
def mock_request():
|
||||
"""Mock knowledge load request."""
|
||||
request = Mock()
|
||||
request.user = "test-user"
|
||||
request.workspace = "test-user"
|
||||
request.id = "test-doc-id"
|
||||
request.collection = "test-collection"
|
||||
request.flow = "test-flow"
|
||||
|
|
@ -71,7 +73,6 @@ def sample_triples():
|
|||
return Triples(
|
||||
metadata=Metadata(
|
||||
id="test-doc-id",
|
||||
user="test-user",
|
||||
collection="default", # This should be overridden
|
||||
),
|
||||
triples=[
|
||||
|
|
@ -90,7 +91,6 @@ def sample_graph_embeddings():
|
|||
return GraphEmbeddings(
|
||||
metadata=Metadata(
|
||||
id="test-doc-id",
|
||||
user="test-user",
|
||||
collection="default", # This should be overridden
|
||||
),
|
||||
entities=[
|
||||
|
|
@ -146,7 +146,6 @@ class TestKnowledgeManagerLoadCore:
|
|||
mock_triples_pub.send.assert_called_once()
|
||||
sent_triples = mock_triples_pub.send.call_args[0][1]
|
||||
assert sent_triples.metadata.collection == "test-collection"
|
||||
assert sent_triples.metadata.user == "test-user"
|
||||
assert sent_triples.metadata.id == "test-doc-id"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -185,7 +184,6 @@ class TestKnowledgeManagerLoadCore:
|
|||
mock_ge_pub.send.assert_called_once()
|
||||
sent_ge = mock_ge_pub.send.call_args[0][1]
|
||||
assert sent_ge.metadata.collection == "test-collection"
|
||||
assert sent_ge.metadata.user == "test-user"
|
||||
assert sent_ge.metadata.id == "test-doc-id"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -193,7 +191,7 @@ class TestKnowledgeManagerLoadCore:
|
|||
"""Test that load_kg_core falls back to 'default' when request.collection is None."""
|
||||
# Create request with None collection
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
mock_request.id = "test-doc-id"
|
||||
mock_request.collection = None # Should fall back to "default"
|
||||
mock_request.flow = "test-flow"
|
||||
|
|
@ -269,7 +267,7 @@ class TestKnowledgeManagerLoadCore:
|
|||
"""Test that load_kg_core validates flow configuration before processing."""
|
||||
# Request with invalid flow
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
mock_request.id = "test-doc-id"
|
||||
mock_request.collection = "test-collection"
|
||||
mock_request.flow = "invalid-flow" # Not in mock_flow_config.flows
|
||||
|
|
@ -297,7 +295,7 @@ class TestKnowledgeManagerLoadCore:
|
|||
|
||||
# Test missing ID
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
mock_request.id = None # Missing
|
||||
mock_request.collection = "test-collection"
|
||||
mock_request.flow = "test-flow"
|
||||
|
|
@ -323,7 +321,7 @@ class TestKnowledgeManagerOtherMethods:
|
|||
async def test_get_kg_core_preserves_collection_from_store(self, knowledge_manager, sample_triples):
|
||||
"""Test that get_kg_core preserves collection field from stored data."""
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
mock_request.id = "test-doc-id"
|
||||
|
||||
mock_respond = AsyncMock()
|
||||
|
|
@ -354,7 +352,7 @@ class TestKnowledgeManagerOtherMethods:
|
|||
async def test_list_kg_cores(self, knowledge_manager):
|
||||
"""Test listing knowledge cores."""
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
|
||||
mock_respond = AsyncMock()
|
||||
|
||||
|
|
@ -376,7 +374,7 @@ class TestKnowledgeManagerOtherMethods:
|
|||
async def test_delete_kg_core(self, knowledge_manager):
|
||||
"""Test deleting knowledge cores."""
|
||||
mock_request = Mock()
|
||||
mock_request.user = "test-user"
|
||||
mock_request.workspace = "test-user"
|
||||
mock_request.id = "test-doc-id"
|
||||
|
||||
mock_respond = AsyncMock()
|
||||
|
|
|
|||
|
|
@ -237,7 +237,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
|
||||
# Mock message with inline data
|
||||
content = b"# Document Title\nBody text content."
|
||||
mock_metadata = Metadata(id="test-doc", user="testuser",
|
||||
mock_metadata = Metadata(id="test-doc",
|
||||
collection="default")
|
||||
mock_document = Document(
|
||||
metadata=mock_metadata,
|
||||
|
|
@ -294,7 +294,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
|
||||
# Mock message
|
||||
content = b"fake pdf"
|
||||
mock_metadata = Metadata(id="test-doc", user="testuser",
|
||||
mock_metadata = Metadata(id="test-doc",
|
||||
collection="default")
|
||||
mock_document = Document(
|
||||
metadata=mock_metadata,
|
||||
|
|
@ -345,7 +345,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
]
|
||||
|
||||
content = b"fake pdf"
|
||||
mock_metadata = Metadata(id="test-doc", user="testuser",
|
||||
mock_metadata = Metadata(id="test-doc",
|
||||
collection="default")
|
||||
mock_document = Document(
|
||||
metadata=mock_metadata,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_basic(self):
|
||||
"""Test basic collection name creation"""
|
||||
result = make_safe_collection_name(
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -21,7 +21,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_with_special_characters(self):
|
||||
"""Test collection name creation with special characters that need sanitization"""
|
||||
result = make_safe_collection_name(
|
||||
user="user@domain.com",
|
||||
workspace="user@domain.com",
|
||||
collection="test-collection.v2",
|
||||
prefix="entity"
|
||||
)
|
||||
|
|
@ -30,7 +30,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_with_unicode(self):
|
||||
"""Test collection name creation with Unicode characters"""
|
||||
result = make_safe_collection_name(
|
||||
user="测试用户",
|
||||
workspace="测试用户",
|
||||
collection="colección_española",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -39,7 +39,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_with_spaces(self):
|
||||
"""Test collection name creation with spaces"""
|
||||
result = make_safe_collection_name(
|
||||
user="test user",
|
||||
workspace="test user",
|
||||
collection="my test collection",
|
||||
prefix="entity"
|
||||
)
|
||||
|
|
@ -48,7 +48,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_with_multiple_consecutive_special_chars(self):
|
||||
"""Test collection name creation with multiple consecutive special characters"""
|
||||
result = make_safe_collection_name(
|
||||
user="user@@@domain!!!",
|
||||
workspace="user@@@domain!!!",
|
||||
collection="test---collection...v2",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -57,7 +57,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_with_leading_trailing_underscores(self):
|
||||
"""Test collection name creation with leading/trailing special characters"""
|
||||
result = make_safe_collection_name(
|
||||
user="__test_user__",
|
||||
workspace="__test_user__",
|
||||
collection="@@test_collection##",
|
||||
prefix="entity"
|
||||
)
|
||||
|
|
@ -66,7 +66,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_empty_user(self):
|
||||
"""Test collection name creation with empty user (should fallback to 'default')"""
|
||||
result = make_safe_collection_name(
|
||||
user="",
|
||||
workspace="",
|
||||
collection="test_collection",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -75,7 +75,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_empty_collection(self):
|
||||
"""Test collection name creation with empty collection (should fallback to 'default')"""
|
||||
result = make_safe_collection_name(
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -84,7 +84,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_both_empty(self):
|
||||
"""Test collection name creation with both user and collection empty"""
|
||||
result = make_safe_collection_name(
|
||||
user="",
|
||||
workspace="",
|
||||
collection="",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -93,7 +93,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_only_special_characters(self):
|
||||
"""Test collection name creation with only special characters (should fallback to 'default')"""
|
||||
result = make_safe_collection_name(
|
||||
user="@@@!!!",
|
||||
workspace="@@@!!!",
|
||||
collection="---###",
|
||||
prefix="entity"
|
||||
)
|
||||
|
|
@ -102,7 +102,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_whitespace_only(self):
|
||||
"""Test collection name creation with whitespace-only strings"""
|
||||
result = make_safe_collection_name(
|
||||
user=" \n\t ",
|
||||
workspace=" \n\t ",
|
||||
collection=" \r\n ",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -111,7 +111,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_mixed_valid_invalid_chars(self):
|
||||
"""Test collection name creation with mixed valid and invalid characters"""
|
||||
result = make_safe_collection_name(
|
||||
user="user123@test",
|
||||
workspace="user123@test",
|
||||
collection="coll_2023.v1",
|
||||
prefix="entity"
|
||||
)
|
||||
|
|
@ -147,7 +147,7 @@ class TestMilvusCollectionNaming:
|
|||
long_collection = "b" * 100
|
||||
|
||||
result = make_safe_collection_name(
|
||||
user=long_user,
|
||||
workspace=long_user,
|
||||
collection=long_collection,
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -159,7 +159,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_numeric_values(self):
|
||||
"""Test collection name creation with numeric user/collection values"""
|
||||
result = make_safe_collection_name(
|
||||
user="user123",
|
||||
workspace="user123",
|
||||
collection="collection456",
|
||||
prefix="doc"
|
||||
)
|
||||
|
|
@ -168,7 +168,7 @@ class TestMilvusCollectionNaming:
|
|||
def test_make_safe_collection_name_case_sensitivity(self):
|
||||
"""Test that collection name creation preserves case"""
|
||||
result = make_safe_collection_name(
|
||||
user="TestUser",
|
||||
workspace="TestUser",
|
||||
collection="TestCollection",
|
||||
prefix="Doc"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,9 +20,8 @@ def processor():
|
|||
)
|
||||
|
||||
|
||||
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1",
|
||||
user="test", collection="default"):
|
||||
metadata = Metadata(id=doc_id, user=user, collection=collection)
|
||||
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", collection="default"):
|
||||
metadata = Metadata(id=doc_id, collection=collection)
|
||||
value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id)
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = value
|
||||
|
|
@ -127,7 +126,7 @@ class TestDocumentEmbeddingsProcessor:
|
|||
@pytest.mark.asyncio
|
||||
async def test_metadata_preserved(self, processor):
|
||||
"""Output should carry the original metadata."""
|
||||
msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1")
|
||||
msg = _make_chunk_message(collection="reports", doc_id="d1")
|
||||
|
||||
mock_request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[[0.0]]
|
||||
|
|
@ -144,7 +143,6 @@ class TestDocumentEmbeddingsProcessor:
|
|||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
result = mock_output.send.call_args[0][0]
|
||||
assert result.metadata.user == "alice"
|
||||
assert result.metadata.collection == "reports"
|
||||
assert result.metadata.id == "d1"
|
||||
|
||||
|
|
|
|||
|
|
@ -27,8 +27,8 @@ def _make_entity_context(name, context, chunk_id="chunk-1"):
|
|||
return MagicMock(entity=entity, context=context, chunk_id=chunk_id)
|
||||
|
||||
|
||||
def _make_message(entities, doc_id="doc-1", user="test", collection="default"):
|
||||
metadata = Metadata(id=doc_id, user=user, collection=collection)
|
||||
def _make_message(entities, doc_id="doc-1", collection="default"):
|
||||
metadata = Metadata(id=doc_id, collection=collection)
|
||||
value = EntityContexts(metadata=metadata, entities=entities)
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = value
|
||||
|
|
@ -151,7 +151,7 @@ class TestGraphEmbeddingsBatchProcessing:
|
|||
_make_entity_context(f"E{i}", f"ctx {i}")
|
||||
for i in range(5)
|
||||
]
|
||||
msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main")
|
||||
msg = _make_message(entities, doc_id="doc-42", collection="main")
|
||||
|
||||
mock_embed = AsyncMock(return_value=[[0.0]] * 5)
|
||||
mock_output = AsyncMock()
|
||||
|
|
@ -168,7 +168,6 @@ class TestGraphEmbeddingsBatchProcessing:
|
|||
for call in mock_output.send.call_args_list:
|
||||
result = call[0][0]
|
||||
assert result.metadata.id == "doc-42"
|
||||
assert result.metadata.user == "alice"
|
||||
assert result.metadata.collection == "main"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -214,11 +214,11 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config_data, 1)
|
||||
await processor.on_schema_config("default", config_data, 1)
|
||||
|
||||
assert 'customers' in processor.schemas
|
||||
assert processor.schemas['customers'].name == 'customers'
|
||||
assert len(processor.schemas['customers'].fields) == 3
|
||||
assert 'customers' in processor.schemas["default"]
|
||||
assert processor.schemas["default"]['customers'].name == 'customers'
|
||||
assert len(processor.schemas["default"]['customers'].fields) == 3
|
||||
|
||||
async def test_on_schema_config_handles_missing_type(self):
|
||||
"""Test that missing schema type is handled gracefully"""
|
||||
|
|
@ -236,9 +236,9 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
'other_type': {}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config_data, 1)
|
||||
await processor.on_schema_config("default", config_data, 1)
|
||||
|
||||
assert processor.schemas == {}
|
||||
assert processor.schemas.get("default", {}) == {}
|
||||
|
||||
async def test_on_message_drops_unknown_collection(self):
|
||||
"""Test that messages for unknown collections are dropped"""
|
||||
|
|
@ -285,7 +285,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
processor.known_collections[('default', 'test_collection')] = {}
|
||||
# No schemas registered
|
||||
|
||||
metadata = MagicMock()
|
||||
|
|
@ -322,17 +322,19 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
processor.known_collections[('default', 'test_collection')] = {}
|
||||
|
||||
# Set up schema
|
||||
processor.schemas['customers'] = RowSchema(
|
||||
name='customers',
|
||||
description='Customer records',
|
||||
fields=[
|
||||
Field(name='id', type='text', primary=True),
|
||||
Field(name='name', type='text', indexed=True),
|
||||
]
|
||||
)
|
||||
processor.schemas["default"] = {
|
||||
'customers': RowSchema(
|
||||
name='customers',
|
||||
description='Customer records',
|
||||
fields=[
|
||||
Field(name='id', type='text', primary=True),
|
||||
Field(name='name', type='text', indexed=True),
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
|
|
@ -372,6 +374,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
return MagicMock()
|
||||
|
||||
mock_flow = MagicMock(side_effect=flow_factory)
|
||||
mock_flow.workspace = "default"
|
||||
|
||||
await processor.on_message(mock_msg, MagicMock(), mock_flow)
|
||||
|
||||
|
|
|
|||
|
|
@ -34,11 +34,10 @@ def _make_defn(entity, definition):
|
|||
return {"entity": entity, "definition": definition}
|
||||
|
||||
|
||||
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
|
||||
user="user-1", collection="col-1", document_id=""):
|
||||
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
|
||||
chunk = Chunk(
|
||||
metadata=Metadata(
|
||||
id=meta_id, root=root, user=user, collection=collection,
|
||||
id=meta_id, root=root, collection=collection,
|
||||
),
|
||||
chunk=text.encode("utf-8"),
|
||||
document_id=document_id,
|
||||
|
|
@ -229,8 +228,7 @@ class TestMetadataPreservation:
|
|||
defs = [_make_defn("X", "def X")]
|
||||
flow, triples_pub, _, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg(
|
||||
"text", meta_id="c-1", root="r-1",
|
||||
user="u-1", collection="coll-1",
|
||||
"text", meta_id="c-1", root="r-1", collection="coll-1",
|
||||
)
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
|
@ -238,7 +236,6 @@ class TestMetadataPreservation:
|
|||
for triples_msg in _sent_triples(triples_pub):
|
||||
assert triples_msg.metadata.id == "c-1"
|
||||
assert triples_msg.metadata.root == "r-1"
|
||||
assert triples_msg.metadata.user == "u-1"
|
||||
assert triples_msg.metadata.collection == "coll-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -247,8 +244,7 @@ class TestMetadataPreservation:
|
|||
defs = [_make_defn("X", "def X")]
|
||||
flow, _, ecs_pub, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg(
|
||||
"text", meta_id="c-2", root="r-2",
|
||||
user="u-2", collection="coll-2",
|
||||
"text", meta_id="c-2", root="r-2", collection="coll-2",
|
||||
)
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
|
|
|||
|
|
@ -38,12 +38,11 @@ def _make_rel(subject, predicate, obj, object_entity=True):
|
|||
}
|
||||
|
||||
|
||||
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
|
||||
user="user-1", collection="col-1", document_id=""):
|
||||
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
|
||||
"""Build a mock message wrapping a Chunk."""
|
||||
chunk = Chunk(
|
||||
metadata=Metadata(
|
||||
id=meta_id, root=root, user=user, collection=collection,
|
||||
id=meta_id, root=root, collection=collection,
|
||||
),
|
||||
chunk=text.encode("utf-8"),
|
||||
document_id=document_id,
|
||||
|
|
@ -189,8 +188,7 @@ class TestMetadataPreservation:
|
|||
rels = [_make_rel("X", "rel", "Y")]
|
||||
flow, pub, _ = _make_flow(rels)
|
||||
msg = _make_chunk_msg(
|
||||
"text", meta_id="c-1", root="r-1",
|
||||
user="u-1", collection="coll-1",
|
||||
"text", meta_id="c-1", root="r-1", collection="coll-1",
|
||||
)
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
|
@ -198,7 +196,6 @@ class TestMetadataPreservation:
|
|||
for triples_msg in _sent_triples(pub):
|
||||
assert triples_msg.metadata.id == "c-1"
|
||||
assert triples_msg.metadata.root == "r-1"
|
||||
assert triples_msg.metadata.user == "u-1"
|
||||
assert triples_msg.metadata.collection == "coll-1"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,12 @@ _real_config_loader = ConfigReceiver.config_loader
|
|||
ConfigReceiver.config_loader = Mock()
|
||||
|
||||
|
||||
def _notify(version, changes):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=version, changes=changes)
|
||||
return msg
|
||||
|
||||
|
||||
class TestConfigReceiver:
|
||||
"""Test cases for ConfigReceiver class"""
|
||||
|
||||
|
|
@ -47,98 +53,70 @@ class TestConfigReceiver:
|
|||
assert handler2 in config_receiver.flow_handlers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_new_version(self):
|
||||
"""Test on_config_notify triggers fetch for newer version"""
|
||||
async def test_on_config_notify_new_version_fetches_per_workspace(self):
|
||||
"""Notify with newer version fetches each affected workspace."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 1
|
||||
|
||||
# Mock fetch_and_apply
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with newer version
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=2, types=["flow"])
|
||||
async def mock_fetch(workspace, retry=False):
|
||||
fetch_calls.append(workspace)
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
config_receiver.fetch_and_apply_workspace = mock_fetch
|
||||
|
||||
assert len(fetch_calls) == 1
|
||||
msg = _notify(2, {"flow": ["ws1", "ws2"]})
|
||||
await config_receiver.on_config_notify(msg, None, None)
|
||||
|
||||
assert set(fetch_calls) == {"ws1", "ws2"}
|
||||
assert config_receiver.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_old_version_ignored(self):
|
||||
"""Test on_config_notify ignores older versions"""
|
||||
"""Older-version notifies are ignored."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 5
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with older version
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=3, types=["flow"])
|
||||
async def mock_fetch(workspace, retry=False):
|
||||
fetch_calls.append(workspace)
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
config_receiver.fetch_and_apply_workspace = mock_fetch
|
||||
|
||||
assert len(fetch_calls) == 0
|
||||
msg = _notify(3, {"flow": ["ws1"]})
|
||||
await config_receiver.on_config_notify(msg, None, None)
|
||||
|
||||
assert fetch_calls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_irrelevant_types_ignored(self):
|
||||
"""Test on_config_notify ignores types the gateway doesn't care about"""
|
||||
"""Notifies without flow changes advance version but skip fetch."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 1
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with non-flow type
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
async def mock_fetch(workspace, retry=False):
|
||||
fetch_calls.append(workspace)
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
config_receiver.fetch_and_apply_workspace = mock_fetch
|
||||
|
||||
# Version should be updated but no fetch
|
||||
assert len(fetch_calls) == 0
|
||||
msg = _notify(2, {"prompt": ["ws1"]})
|
||||
await config_receiver.on_config_notify(msg, None, None)
|
||||
|
||||
assert fetch_calls == []
|
||||
assert config_receiver.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_flow_type_triggers_fetch(self):
|
||||
"""Test on_config_notify fetches for flow-related types"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
config_receiver.config_version = 1
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
for type_name in ["flow"]:
|
||||
fetch_calls.clear()
|
||||
config_receiver.config_version = 1
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=2, types=[type_name])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_exception_handling(self):
|
||||
"""Test on_config_notify handles exceptions gracefully"""
|
||||
"""on_config_notify swallows exceptions from message decode."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Create notify message that causes an exception
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.side_effect = Exception("Test exception")
|
||||
|
||||
|
|
@ -146,19 +124,18 @@ class TestConfigReceiver:
|
|||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_new_flows(self):
|
||||
"""Test fetch_and_apply starts new flows"""
|
||||
async def test_fetch_and_apply_workspace_starts_new_flows(self):
|
||||
"""fetch_and_apply_workspace starts newly-configured flows."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock _create_config_client to return a mock client
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1"}',
|
||||
"flow2": '{"name": "test_flow_2"}'
|
||||
"flow2": '{"name": "test_flow_2"}',
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -167,36 +144,39 @@ class TestConfigReceiver:
|
|||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
start_flow_calls = []
|
||||
async def mock_start_flow(id, flow):
|
||||
start_flow_calls.append((id, flow))
|
||||
|
||||
async def mock_start_flow(workspace, id, flow):
|
||||
start_flow_calls.append((workspace, id, flow))
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert config_receiver.config_version == 5
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow1" in config_receiver.flows["default"]
|
||||
assert "flow2" in config_receiver.flows["default"]
|
||||
assert len(start_flow_calls) == 2
|
||||
assert all(c[0] == "default" for c in start_flow_calls)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_removed_flows(self):
|
||||
"""Test fetch_and_apply stops removed flows"""
|
||||
async def test_fetch_and_apply_workspace_stops_removed_flows(self):
|
||||
"""fetch_and_apply_workspace stops flows no longer configured."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"}
|
||||
"default": {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"},
|
||||
}
|
||||
}
|
||||
|
||||
# Config now only has flow1
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1"}'
|
||||
"flow1": '{"name": "test_flow_1"}',
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -205,20 +185,22 @@ class TestConfigReceiver:
|
|||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
stop_flow_calls = []
|
||||
async def mock_stop_flow(id, flow):
|
||||
stop_flow_calls.append((id, flow))
|
||||
|
||||
async def mock_stop_flow(workspace, id, flow):
|
||||
stop_flow_calls.append((workspace, id, flow))
|
||||
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" not in config_receiver.flows
|
||||
assert "flow1" in config_receiver.flows["default"]
|
||||
assert "flow2" not in config_receiver.flows["default"]
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0][0] == "flow2"
|
||||
assert stop_flow_calls[0][:2] == ("default", "flow2")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_no_flows(self):
|
||||
"""Test fetch_and_apply with empty config"""
|
||||
async def test_fetch_and_apply_workspace_with_no_flows(self):
|
||||
"""Empty workspace config clears any local flow state."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
|
|
@ -231,88 +213,100 @@ class TestConfigReceiver:
|
|||
mock_client.request.return_value = mock_resp
|
||||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert config_receiver.flows == {}
|
||||
assert config_receiver.flows.get("default", {}) == {}
|
||||
assert config_receiver.config_version == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_with_handlers(self):
|
||||
"""Test start_flow method with multiple handlers"""
|
||||
"""start_flow fans out to every registered flow handler."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler1 = Mock()
|
||||
handler1.start_flow = Mock()
|
||||
handler1.start_flow = AsyncMock()
|
||||
handler2 = Mock()
|
||||
handler2.start_flow = Mock()
|
||||
handler2.start_flow = AsyncMock()
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
await config_receiver.start_flow("default", "flow1", flow_data)
|
||||
|
||||
handler1.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler1.start_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
handler2.start_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_with_handler_exception(self):
|
||||
"""Test start_flow method handles handler exceptions"""
|
||||
"""Handler exceptions in start_flow do not propagate."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler = Mock()
|
||||
handler.start_flow = Mock(side_effect=Exception("Handler error"))
|
||||
handler.start_flow = AsyncMock(side_effect=Exception("Handler error"))
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
await config_receiver.start_flow("default", "flow1", flow_data)
|
||||
|
||||
handler.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler.start_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow_with_handlers(self):
|
||||
"""Test stop_flow method with multiple handlers"""
|
||||
"""stop_flow fans out to every registered flow handler."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler1 = Mock()
|
||||
handler1.stop_flow = Mock()
|
||||
handler1.stop_flow = AsyncMock()
|
||||
handler2 = Mock()
|
||||
handler2.stop_flow = Mock()
|
||||
handler2.stop_flow = AsyncMock()
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
await config_receiver.stop_flow("default", "flow1", flow_data)
|
||||
|
||||
handler1.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler1.stop_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
handler2.stop_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow_with_handler_exception(self):
|
||||
"""Test stop_flow method handles handler exceptions"""
|
||||
"""Handler exceptions in stop_flow do not propagate."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
handler = Mock()
|
||||
handler.stop_flow = Mock(side_effect=Exception("Handler error"))
|
||||
handler.stop_flow = AsyncMock(side_effect=Exception("Handler error"))
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
await config_receiver.stop_flow("default", "flow1", flow_data)
|
||||
|
||||
handler.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler.stop_flow.assert_awaited_once_with(
|
||||
"default", "flow1", flow_data
|
||||
)
|
||||
|
||||
@patch('asyncio.create_task')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -329,25 +323,25 @@ class TestConfigReceiver:
|
|||
mock_create_task.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_mixed_flow_operations(self):
|
||||
"""Test fetch_and_apply with mixed add/remove operations"""
|
||||
async def test_fetch_and_apply_workspace_mixed_flow_operations(self):
|
||||
"""fetch_and_apply_workspace adds, keeps and removes flows in one pass."""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"}
|
||||
"default": {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"},
|
||||
}
|
||||
}
|
||||
|
||||
# Config removes flow1, keeps flow2, adds flow3
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow2": '{"name": "test_flow_2"}',
|
||||
"flow3": '{"name": "test_flow_3"}'
|
||||
"flow3": '{"name": "test_flow_3"}',
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -358,20 +352,22 @@ class TestConfigReceiver:
|
|||
start_calls = []
|
||||
stop_calls = []
|
||||
|
||||
async def mock_start_flow(id, flow):
|
||||
start_calls.append((id, flow))
|
||||
async def mock_stop_flow(id, flow):
|
||||
stop_calls.append((id, flow))
|
||||
async def mock_start_flow(workspace, id, flow):
|
||||
start_calls.append((workspace, id, flow))
|
||||
|
||||
async def mock_stop_flow(workspace, id, flow):
|
||||
stop_calls.append((workspace, id, flow))
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
await config_receiver.fetch_and_apply_workspace("default")
|
||||
|
||||
assert "flow1" not in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow3" in config_receiver.flows
|
||||
ws_flows = config_receiver.flows["default"]
|
||||
assert "flow1" not in ws_flows
|
||||
assert "flow2" in ws_flows
|
||||
assert "flow3" in ws_flows
|
||||
assert len(start_calls) == 1
|
||||
assert start_calls[0][0] == "flow3"
|
||||
assert start_calls[0][:2] == ("default", "flow3")
|
||||
assert len(stop_calls) == 1
|
||||
assert stop_calls[0][0] == "flow1"
|
||||
assert stop_calls[0][:2] == ("default", "flow1")
|
||||
|
|
|
|||
|
|
@ -36,7 +36,6 @@ def _ge_response_dict():
|
|||
"metadata": {
|
||||
"id": "doc-1",
|
||||
"root": "",
|
||||
"user": "alice",
|
||||
"collection": "testcoll",
|
||||
},
|
||||
"entities": [
|
||||
|
|
@ -59,7 +58,6 @@ def _triples_response_dict():
|
|||
"metadata": {
|
||||
"id": "doc-1",
|
||||
"root": "",
|
||||
"user": "alice",
|
||||
"collection": "testcoll",
|
||||
},
|
||||
"triples": [
|
||||
|
|
@ -73,9 +71,9 @@ def _triples_response_dict():
|
|||
}
|
||||
|
||||
|
||||
def _make_request(id_="doc-1", user="alice"):
|
||||
def _make_request(id_="doc-1", workspace="alice"):
|
||||
request = Mock()
|
||||
request.query = {"id": id_, "user": user}
|
||||
request.query = {"id": id_, "workspace": workspace}
|
||||
return request
|
||||
|
||||
|
||||
|
|
@ -149,12 +147,8 @@ class TestCoreExportWireFormat:
|
|||
msg_type, payload = items[0]
|
||||
assert msg_type == "ge"
|
||||
|
||||
# Metadata envelope: only id/user/collection — no stale `m["m"]`.
|
||||
assert payload["m"] == {
|
||||
"i": "doc-1",
|
||||
"u": "alice",
|
||||
"c": "testcoll",
|
||||
}
|
||||
# Metadata envelope: only id/collection — no stale `m["m"]`.
|
||||
assert payload["m"] == {"i": "doc-1", "c": "testcoll"}
|
||||
|
||||
# Entities: each carries the *singular* `v` and the term envelope
|
||||
assert len(payload["e"]) == 2
|
||||
|
|
@ -202,11 +196,7 @@ class TestCoreExportWireFormat:
|
|||
|
||||
msg_type, payload = items[0]
|
||||
assert msg_type == "t"
|
||||
assert payload["m"] == {
|
||||
"i": "doc-1",
|
||||
"u": "alice",
|
||||
"c": "testcoll",
|
||||
}
|
||||
assert payload["m"] == {"i": "doc-1", "c": "testcoll"}
|
||||
assert len(payload["t"]) == 1
|
||||
|
||||
|
||||
|
|
@ -240,7 +230,7 @@ class TestCoreImportWireFormat:
|
|||
payload = msgpack.packb((
|
||||
"ge",
|
||||
{
|
||||
"m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
|
||||
"m": {"i": "doc-1", "c": "testcoll"},
|
||||
"e": [
|
||||
{
|
||||
"e": {"t": "i", "i": "http://example.org/alice"},
|
||||
|
|
@ -266,7 +256,7 @@ class TestCoreImportWireFormat:
|
|||
|
||||
req = captured[0]
|
||||
assert req["operation"] == "put-kg-core"
|
||||
assert req["user"] == "alice"
|
||||
assert req["workspace"] == "alice"
|
||||
assert req["id"] == "doc-1"
|
||||
|
||||
ge = req["graph-embeddings"]
|
||||
|
|
@ -275,7 +265,6 @@ class TestCoreImportWireFormat:
|
|||
assert "metadata" not in ge["metadata"]
|
||||
assert ge["metadata"] == {
|
||||
"id": "doc-1",
|
||||
"user": "alice",
|
||||
"collection": "default",
|
||||
}
|
||||
|
||||
|
|
@ -302,7 +291,7 @@ class TestCoreImportWireFormat:
|
|||
payload = msgpack.packb((
|
||||
"t",
|
||||
{
|
||||
"m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
|
||||
"m": {"i": "doc-1", "c": "testcoll"},
|
||||
"t": [
|
||||
{
|
||||
"s": {"t": "i", "i": "http://example.org/alice"},
|
||||
|
|
@ -407,11 +396,10 @@ class TestCoreImportExportRoundTrip:
|
|||
original = _ge_response_dict()["graph-embeddings"]
|
||||
|
||||
ge = req["graph-embeddings"]
|
||||
# The import side overrides id/user from the URL query (intentional),
|
||||
# The import side overrides id from the URL query (intentional),
|
||||
# so we only round-trip the entity payload itself.
|
||||
assert ge["metadata"]["id"] == original["metadata"]["id"]
|
||||
assert ge["metadata"]["user"] == original["metadata"]["user"]
|
||||
|
||||
|
||||
assert len(ge["entities"]) == len(original["entities"])
|
||||
for got, want in zip(ge["entities"], original["entities"]):
|
||||
assert got["vector"] == want["vector"]
|
||||
|
|
|
|||
|
|
@ -72,10 +72,10 @@ class TestDispatcherManager:
|
|||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await manager.start_flow("flow1", flow_data)
|
||||
|
||||
assert "flow1" in manager.flows
|
||||
assert manager.flows["flow1"] == flow_data
|
||||
await manager.start_flow("default", "flow1", flow_data)
|
||||
|
||||
assert ("default", "flow1") in manager.flows
|
||||
assert manager.flows[("default", "flow1")] == flow_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow(self):
|
||||
|
|
@ -86,11 +86,11 @@ class TestDispatcherManager:
|
|||
|
||||
# Pre-populate with a flow
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
manager.flows["flow1"] = flow_data
|
||||
|
||||
await manager.stop_flow("flow1", flow_data)
|
||||
|
||||
assert "flow1" not in manager.flows
|
||||
manager.flows[("default", "flow1")] = flow_data
|
||||
|
||||
await manager.stop_flow("default", "flow1", flow_data)
|
||||
|
||||
assert ("default", "flow1") not in manager.flows
|
||||
|
||||
def test_dispatch_global_service_returns_wrapper(self):
|
||||
"""Test dispatch_global_service returns DispatcherWrapper"""
|
||||
|
|
@ -275,12 +275,12 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
|
|
@ -290,7 +290,7 @@ class TestDispatcherManager:
|
|||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
mock_dispatchers.__contains__.return_value = True
|
||||
|
||||
|
||||
params = {"flow": "test_flow", "kind": "triples"}
|
||||
result = await manager.process_flow_import("ws", "running", params)
|
||||
|
||||
|
|
@ -326,12 +326,12 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers:
|
||||
mock_dispatchers.__contains__.return_value = False
|
||||
|
||||
|
|
@ -348,12 +348,12 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"flow": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.export_dispatchers') as mock_dispatchers, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
|
|
@ -404,7 +404,7 @@ class TestDispatcherManager:
|
|||
params = {"flow": "test_flow", "kind": "agent"}
|
||||
result = await manager.process_flow_service("data", "responder", params)
|
||||
|
||||
manager.invoke_flow_service.assert_called_once_with("data", "responder", "test_flow", "agent")
|
||||
manager.invoke_flow_service.assert_called_once_with("data", "responder", "default", "test_flow", "agent")
|
||||
assert result == "flow_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -415,14 +415,14 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Add flow to the flows dictionary
|
||||
manager.flows["test_flow"] = {"services": {"agent": {}}}
|
||||
|
||||
manager.flows[("default", "test_flow")] = {"services": {"agent": {}}}
|
||||
|
||||
# Pre-populate with existing dispatcher
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="cached_result")
|
||||
manager.dispatchers[("test_flow", "agent")] = mock_dispatcher
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
manager.dispatchers[("default", "test_flow", "agent")] = mock_dispatcher
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
assert result == "cached_result"
|
||||
|
|
@ -435,7 +435,7 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"agent": {
|
||||
"request": "agent_request_queue",
|
||||
|
|
@ -443,7 +443,7 @@ class TestDispatcherManager:
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_dispatchers:
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
|
|
@ -452,23 +452,23 @@ class TestDispatcherManager:
|
|||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
mock_dispatchers.__contains__.return_value = True
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
|
||||
# Verify dispatcher was created with correct parameters
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
backend=mock_backend,
|
||||
request_queue="agent_request_queue",
|
||||
response_queue="agent_response_queue",
|
||||
timeout=120,
|
||||
consumer="api-gateway-test_flow-agent-request",
|
||||
subscriber="api-gateway-test_flow-agent-request"
|
||||
consumer="api-gateway-default-test_flow-agent-request",
|
||||
subscriber="api-gateway-default-test_flow-agent-request"
|
||||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
|
||||
|
||||
# Verify dispatcher was cached
|
||||
assert manager.dispatchers[("test_flow", "agent")] == mock_dispatcher
|
||||
assert manager.dispatchers[("default", "test_flow", "agent")] == mock_dispatcher
|
||||
assert result == "new_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -479,26 +479,26 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"text-load": {"flow": "text_load_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \
|
||||
patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers:
|
||||
mock_rr_dispatchers.__contains__.return_value = False
|
||||
mock_sender_dispatchers.__contains__.return_value = True
|
||||
|
||||
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.start = AsyncMock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="sender_result")
|
||||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "text-load")
|
||||
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "text-load")
|
||||
|
||||
# Verify dispatcher was created with correct parameters
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
backend=mock_backend,
|
||||
|
|
@ -506,9 +506,9 @@ class TestDispatcherManager:
|
|||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
|
||||
|
||||
# Verify dispatcher was cached
|
||||
assert manager.dispatchers[("test_flow", "text-load")] == mock_dispatcher
|
||||
assert manager.dispatchers[("default", "test_flow", "text-load")] == mock_dispatcher
|
||||
assert result == "sender_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -519,7 +519,7 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid flow"):
|
||||
await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent")
|
||||
await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_unsupported_kind_by_flow(self):
|
||||
|
|
@ -529,14 +529,14 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow without agent interface
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"text-completion": {"request": "req", "response": "resp"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with pytest.raises(RuntimeError, match="This kind not supported by flow"):
|
||||
await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_invalid_kind(self):
|
||||
|
|
@ -546,7 +546,7 @@ class TestDispatcherManager:
|
|||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
# Setup test flow with interface but unsupported kind
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"invalid-kind": {"request": "req", "response": "resp"}
|
||||
}
|
||||
|
|
@ -558,7 +558,7 @@ class TestDispatcherManager:
|
|||
mock_sender_dispatchers.__contains__.return_value = False
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid kind"):
|
||||
await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind")
|
||||
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "invalid-kind")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self):
|
||||
|
|
@ -608,7 +608,7 @@ class TestDispatcherManager:
|
|||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||
|
||||
manager.flows["test_flow"] = {
|
||||
manager.flows[("default", "test_flow")] = {
|
||||
"interfaces": {
|
||||
"agent": {
|
||||
"request": "agent_request_queue",
|
||||
|
|
@ -630,7 +630,7 @@ class TestDispatcherManager:
|
|||
mock_rr_dispatchers.__contains__.return_value = True
|
||||
|
||||
results = await asyncio.gather(*[
|
||||
manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
|
||||
for _ in range(5)
|
||||
])
|
||||
|
||||
|
|
@ -638,5 +638,5 @@ class TestDispatcherManager:
|
|||
"Dispatcher class instantiated more than once — duplicate consumer bug"
|
||||
)
|
||||
assert mock_dispatcher.start.call_count == 1
|
||||
assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher
|
||||
assert manager.dispatchers[("default", "test_flow", "agent")] is mock_dispatcher
|
||||
assert all(r == "result" for r in results)
|
||||
|
|
@ -186,7 +186,6 @@ class TestEntityContextsImportMessageProcessing:
|
|||
assert isinstance(sent, EntityContexts)
|
||||
assert isinstance(sent.metadata, Metadata)
|
||||
assert sent.metadata.id == "doc-123"
|
||||
assert sent.metadata.user == "testuser"
|
||||
assert sent.metadata.collection == "testcollection"
|
||||
|
||||
assert len(sent.entities) == 2
|
||||
|
|
|
|||
|
|
@ -188,7 +188,6 @@ class TestGraphEmbeddingsImportMessageProcessing:
|
|||
assert isinstance(sent, GraphEmbeddings)
|
||||
assert isinstance(sent.metadata, Metadata)
|
||||
assert sent.metadata.id == "doc-123"
|
||||
assert sent.metadata.user == "testuser"
|
||||
assert sent.metadata.collection == "testcollection"
|
||||
|
||||
assert len(sent.entities) == 2
|
||||
|
|
|
|||
|
|
@ -235,7 +235,6 @@ class TestRowsImportMessageProcessing:
|
|||
|
||||
# Check metadata
|
||||
assert sent_object.metadata.id == "obj-123"
|
||||
assert sent_object.metadata.user == "testuser"
|
||||
assert sent_object.metadata.collection == "testcollection"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ class TestTextDocumentTranslator:
|
|||
)
|
||||
|
||||
assert msg.metadata.id == "doc-1"
|
||||
assert msg.metadata.user == "alice"
|
||||
assert msg.metadata.collection == "research"
|
||||
assert msg.text == payload.encode("utf-8")
|
||||
|
||||
|
|
|
|||
|
|
@ -29,10 +29,9 @@ class Triple:
|
|||
self.o = o
|
||||
|
||||
class Metadata:
|
||||
def __init__(self, id, user, collection, root=""):
|
||||
def __init__(self, id, collection, root=""):
|
||||
self.id = id
|
||||
self.root = root
|
||||
self.user = user
|
||||
self.collection = collection
|
||||
|
||||
class Triples:
|
||||
|
|
@ -108,7 +107,6 @@ def sample_triples(sample_triple):
|
|||
"""Sample Triples batch object"""
|
||||
metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -123,7 +121,6 @@ def sample_chunk():
|
|||
"""Sample text chunk for processing"""
|
||||
metadata = Metadata(
|
||||
id="test-chunk-456",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -322,7 +322,6 @@ This is not JSON at all
|
|||
assert isinstance(sent_triples, Triples)
|
||||
# Check metadata fields individually since implementation creates new Metadata object
|
||||
assert sent_triples.metadata.id == sample_metadata.id
|
||||
assert sent_triples.metadata.user == sample_metadata.user
|
||||
assert sent_triples.metadata.collection == sample_metadata.collection
|
||||
assert len(sent_triples.triples) == 1
|
||||
assert sent_triples.triples[0].s.iri == "test:subject"
|
||||
|
|
@ -346,7 +345,6 @@ This is not JSON at all
|
|||
assert isinstance(sent_contexts, EntityContexts)
|
||||
# Check metadata fields individually since implementation creates new Metadata object
|
||||
assert sent_contexts.metadata.id == sample_metadata.id
|
||||
assert sent_contexts.metadata.user == sample_metadata.user
|
||||
assert sent_contexts.metadata.collection == sample_metadata.collection
|
||||
assert len(sent_contexts.entities) == 1
|
||||
assert sent_contexts.entities[0].entity.iri == "test:entity"
|
||||
|
|
|
|||
|
|
@ -311,8 +311,7 @@ class TestObjectExtractionBusinessLogic:
|
|||
"""Test ExtractedObject creation and properties"""
|
||||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="test-extraction-001",
|
||||
user="test_user",
|
||||
id="test-extraction-001",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -337,7 +336,6 @@ class TestObjectExtractionBusinessLogic:
|
|||
assert extracted_obj.values[0]["customer_id"] == "CUST001"
|
||||
assert extracted_obj.confidence == 0.95
|
||||
assert "John Doe" in extracted_obj.source_span
|
||||
assert extracted_obj.metadata.user == "test_user"
|
||||
|
||||
def test_config_parsing_error_handling(self):
|
||||
"""Test configuration parsing with invalid JSON"""
|
||||
|
|
|
|||
|
|
@ -371,7 +371,6 @@ class TestTripleConstructionLogic:
|
|||
|
||||
metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
|
@ -384,7 +383,6 @@ class TestTripleConstructionLogic:
|
|||
# Assert
|
||||
assert isinstance(triples_batch, Triples)
|
||||
assert triples_batch.metadata.id == "test-doc-123"
|
||||
assert triples_batch.metadata.user == "test_user"
|
||||
assert triples_batch.metadata.collection == "test_collection"
|
||||
assert len(triples_batch.triples) == 2
|
||||
|
||||
|
|
|
|||
|
|
@ -33,12 +33,12 @@ def _make_librarian(min_chunk_size=1):
|
|||
|
||||
|
||||
def _make_doc_metadata(
|
||||
doc_id="doc-1", kind="application/pdf", user="alice", title="Test Doc"
|
||||
doc_id="doc-1", kind="application/pdf", workspace="alice", title="Test Doc"
|
||||
):
|
||||
meta = MagicMock()
|
||||
meta.id = doc_id
|
||||
meta.kind = kind
|
||||
meta.user = user
|
||||
meta.workspace = workspace
|
||||
meta.title = title
|
||||
meta.time = 1700000000
|
||||
meta.comments = ""
|
||||
|
|
@ -47,27 +47,27 @@ def _make_doc_metadata(
|
|||
|
||||
|
||||
def _make_begin_request(
|
||||
doc_id="doc-1", kind="application/pdf", user="alice",
|
||||
doc_id="doc-1", kind="application/pdf", workspace="alice",
|
||||
total_size=10_000_000, chunk_size=0
|
||||
):
|
||||
req = MagicMock()
|
||||
req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, user=user)
|
||||
req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, workspace=workspace)
|
||||
req.total_size = total_size
|
||||
req.chunk_size = chunk_size
|
||||
return req
|
||||
|
||||
|
||||
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, user="alice", content=b"data"):
|
||||
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, workspace="alice", content=b"data"):
|
||||
req = MagicMock()
|
||||
req.upload_id = upload_id
|
||||
req.chunk_index = chunk_index
|
||||
req.user = user
|
||||
req.workspace = workspace
|
||||
req.content = base64.b64encode(content)
|
||||
return req
|
||||
|
||||
|
||||
def _make_session(
|
||||
user="alice", total_chunks=5, chunk_size=2_000_000,
|
||||
workspace="alice", total_chunks=5, chunk_size=2_000_000,
|
||||
total_size=10_000_000, chunks_received=None, object_id="obj-1",
|
||||
s3_upload_id="s3-up-1", document_metadata=None, document_id="doc-1",
|
||||
):
|
||||
|
|
@ -76,11 +76,11 @@ def _make_session(
|
|||
if document_metadata is None:
|
||||
document_metadata = json.dumps({
|
||||
"id": document_id, "kind": "application/pdf",
|
||||
"user": user, "title": "Test", "time": 1700000000,
|
||||
"workspace": workspace, "title": "Test", "time": 1700000000,
|
||||
"comments": "", "tags": [],
|
||||
})
|
||||
return {
|
||||
"user": user,
|
||||
"workspace": workspace,
|
||||
"total_chunks": total_chunks,
|
||||
"chunk_size": chunk_size,
|
||||
"total_size": total_size,
|
||||
|
|
@ -259,10 +259,10 @@ class TestUploadChunk:
|
|||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
session = _make_session(workspace="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = _make_upload_chunk_request(user="bob")
|
||||
req = _make_upload_chunk_request(workspace="bob")
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.upload_chunk(req)
|
||||
|
||||
|
|
@ -353,7 +353,7 @@ class TestCompleteUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.complete_upload(req)
|
||||
|
||||
|
|
@ -375,7 +375,7 @@ class TestCompleteUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
await lib.complete_upload(req)
|
||||
|
||||
|
|
@ -394,7 +394,7 @@ class TestCompleteUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
with pytest.raises(RequestError, match="Missing chunks"):
|
||||
await lib.complete_upload(req)
|
||||
|
|
@ -406,7 +406,7 @@ class TestCompleteUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-gone"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
with pytest.raises(RequestError, match="not found"):
|
||||
await lib.complete_upload(req)
|
||||
|
|
@ -414,12 +414,12 @@ class TestCompleteUpload:
|
|||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
session = _make_session(workspace="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "bob"
|
||||
req.workspace = "bob"
|
||||
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.complete_upload(req)
|
||||
|
|
@ -439,7 +439,7 @@ class TestAbortUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.abort_upload(req)
|
||||
|
||||
|
|
@ -456,7 +456,7 @@ class TestAbortUpload:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-gone"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
with pytest.raises(RequestError, match="not found"):
|
||||
await lib.abort_upload(req)
|
||||
|
|
@ -464,12 +464,12 @@ class TestAbortUpload:
|
|||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
session = _make_session(workspace="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "bob"
|
||||
req.workspace = "bob"
|
||||
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.abort_upload(req)
|
||||
|
|
@ -492,7 +492,7 @@ class TestGetUploadStatus:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.get_upload_status(req)
|
||||
|
||||
|
|
@ -510,7 +510,7 @@ class TestGetUploadStatus:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-expired"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.get_upload_status(req)
|
||||
|
||||
|
|
@ -527,7 +527,7 @@ class TestGetUploadStatus:
|
|||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.get_upload_status(req)
|
||||
|
||||
|
|
@ -539,12 +539,12 @@ class TestGetUploadStatus:
|
|||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
session = _make_session(workspace="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "bob"
|
||||
req.workspace = "bob"
|
||||
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.get_upload_status(req)
|
||||
|
|
@ -564,7 +564,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 2000
|
||||
|
||||
|
|
@ -587,7 +587,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 2000
|
||||
|
||||
|
|
@ -608,7 +608,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 2000
|
||||
|
||||
|
|
@ -630,7 +630,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_range = AsyncMock(return_value=b"x")
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 0 # Should use default 1MB
|
||||
|
||||
|
|
@ -649,7 +649,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_range = AsyncMock(return_value=raw)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 1000
|
||||
|
||||
|
|
@ -666,7 +666,7 @@ class TestStreamDocument:
|
|||
lib.blob_store.get_size = AsyncMock(return_value=5000)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 512
|
||||
|
||||
|
|
@ -698,7 +698,7 @@ class TestListUploads:
|
|||
]
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.list_uploads(req)
|
||||
|
||||
|
|
@ -713,7 +713,7 @@ class TestListUploads:
|
|||
lib.table_store.list_upload_sessions.return_value = []
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.workspace = "alice"
|
||||
|
||||
resp = await lib.list_uploads(req)
|
||||
|
||||
|
|
|
|||
|
|
@ -239,7 +239,7 @@ def _make_processor(tools=None):
|
|||
agent = MagicMock()
|
||||
agent.tools = tools or {}
|
||||
agent.additional_context = ""
|
||||
processor.agent = agent
|
||||
processor.agents = {"default": agent}
|
||||
processor.aggregator = MagicMock()
|
||||
|
||||
return processor
|
||||
|
|
@ -254,6 +254,7 @@ def _make_flow():
|
|||
return producers[name]
|
||||
|
||||
flow = MagicMock(side_effect=factory)
|
||||
flow.workspace = "default"
|
||||
return flow
|
||||
|
||||
|
||||
|
|
@ -299,7 +300,7 @@ class TestAgentReactDagStructure:
|
|||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agent = processor.agent
|
||||
service.agents = processor.agents
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
|
|
@ -344,7 +345,6 @@ class TestAgentReactDagStructure:
|
|||
|
||||
request1 = AgentRequest(
|
||||
question="What is 6x7?",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id=session_id,
|
||||
|
|
@ -433,7 +433,7 @@ class TestAgentPlanDagStructure:
|
|||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agent = processor.agent
|
||||
service.agents = processor.agents
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
|
|
@ -480,7 +480,6 @@ class TestAgentPlanDagStructure:
|
|||
# Iteration 1: planning
|
||||
request1 = AgentRequest(
|
||||
question="Test?",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id=session_id,
|
||||
|
|
@ -537,7 +536,7 @@ class TestAgentSupervisorDagStructure:
|
|||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agent = processor.agent
|
||||
service.agents = processor.agents
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
|
|
@ -563,7 +562,6 @@ class TestAgentSupervisorDagStructure:
|
|||
|
||||
request = AgentRequest(
|
||||
question="Research quantum computing",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id=str(uuid.uuid4()),
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
def mock_query_request(self):
|
||||
"""Create a mock query request for testing"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=10
|
||||
|
|
@ -69,7 +68,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_single_vector(self, processor):
|
||||
"""Test querying document embeddings with a single vector"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -83,7 +81,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with correct parameters including user/collection
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -101,7 +99,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_longer_vector(self, processor):
|
||||
"""Test querying document embeddings with a longer vector"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=3
|
||||
|
|
@ -115,7 +112,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called once with the full vector
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -133,7 +130,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_with_limit(self, processor):
|
||||
"""Test querying document embeddings respects limit parameter"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=2
|
||||
|
|
@ -148,7 +144,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with the specified limit
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -162,13 +158,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_empty_vectors(self, processor):
|
||||
"""Test querying document embeddings with empty vectors list"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[],
|
||||
limit=5
|
||||
)
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify no search was called
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
|
@ -180,7 +175,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_empty_search_results(self, processor):
|
||||
"""Test querying document embeddings with empty search results"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -189,7 +183,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
# Mock empty search results
|
||||
processor.vecstore.search.return_value = []
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -203,7 +197,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_unicode_documents(self, processor):
|
||||
"""Test querying document embeddings with Unicode document content"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -217,7 +210,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify Unicode content is preserved in ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
|
|
@ -230,7 +223,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_large_documents(self, processor):
|
||||
"""Test querying document embeddings with large document content"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -244,7 +236,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify large content is preserved in ChunkMatch objects
|
||||
assert len(result) == 2
|
||||
|
|
@ -256,7 +248,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_special_characters(self, processor):
|
||||
"""Test querying document embeddings with special characters in documents"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -270,7 +261,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify special characters are preserved in ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
|
|
@ -283,13 +274,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_zero_limit(self, processor):
|
||||
"""Test querying document embeddings with zero limit"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=0
|
||||
)
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify no search was called (optimization for zero limit)
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
|
@ -301,13 +291,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_negative_limit(self, processor):
|
||||
"""Test querying document embeddings with negative limit"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=-1
|
||||
)
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify no search was called (optimization for negative limit)
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
|
@ -319,7 +308,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_exception_handling(self, processor):
|
||||
"""Test exception handling during query processing"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -330,13 +318,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Milvus connection failed"):
|
||||
await processor.query_document_embeddings(query)
|
||||
await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test querying document embeddings with different vector dimensions"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector
|
||||
limit=5
|
||||
|
|
@ -349,7 +336,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with the vector
|
||||
processor.vecstore.search.assert_called_once()
|
||||
|
|
@ -364,7 +351,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_multiple_results(self, processor):
|
||||
"""Test querying document embeddings with multiple results"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=5
|
||||
|
|
@ -378,7 +364,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
result = await processor.query_document_embeddings('test_user', query)
|
||||
|
||||
# Verify results are ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
|
|
|
|||
|
|
@ -117,7 +117,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify index was accessed correctly (with dimension suffix)
|
||||
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
|
||||
|
|
@ -160,7 +160,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.side_effect = [mock_results1, mock_results2]
|
||||
|
||||
chunks = await processor.query_document_embeddings(mock_query_message)
|
||||
chunks = await processor.query_document_embeddings('default', mock_query_message)
|
||||
|
||||
# Verify both queries were made
|
||||
assert mock_index.query.call_count == 2
|
||||
|
|
@ -191,7 +191,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify limit is passed to query
|
||||
mock_index.query.assert_called_once()
|
||||
|
|
@ -213,7 +213,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no query was made and empty result returned
|
||||
mock_index.query.assert_not_called()
|
||||
|
|
@ -231,7 +231,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no query was made and empty result returned
|
||||
mock_index.query.assert_not_called()
|
||||
|
|
@ -259,7 +259,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify different indexes used for different dimensions
|
||||
assert processor.pinecone.Index.call_count == 2
|
||||
|
|
@ -287,7 +287,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no queries were made and empty result returned
|
||||
processor.pinecone.Index.assert_not_called()
|
||||
|
|
@ -310,7 +310,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
mock_results.matches = []
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify empty results
|
||||
assert chunks == []
|
||||
|
|
@ -334,7 +334,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify Unicode content is properly handled
|
||||
assert len(chunks) == 2
|
||||
|
|
@ -361,7 +361,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify large content is properly handled
|
||||
assert len(chunks) == 1
|
||||
|
|
@ -389,7 +389,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify all content types are properly handled
|
||||
assert len(chunks) == 5
|
||||
|
|
@ -413,7 +413,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
mock_index.query.side_effect = Exception("Query failed")
|
||||
|
||||
with pytest.raises(Exception, match="Query failed"):
|
||||
await processor.query_document_embeddings(message)
|
||||
await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_index_access_failure(self, processor):
|
||||
|
|
@ -427,7 +427,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
processor.pinecone.Index.side_effect = Exception("Index access failed")
|
||||
|
||||
with pytest.raises(Exception, match="Index access failed"):
|
||||
await processor.query_document_embeddings(message)
|
||||
await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_vector_accumulation(self, processor):
|
||||
|
|
@ -462,7 +462,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.side_effect = [mock_results1, mock_results2, mock_results3]
|
||||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
chunks = await processor.query_document_embeddings('test_user', message)
|
||||
|
||||
# Verify all queries were made
|
||||
assert mock_index.query.call_count == 3
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'test_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('test_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with correct parameters (with dimension suffix)
|
||||
|
|
@ -158,7 +158,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'multi_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('multi_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called once
|
||||
|
|
@ -212,7 +212,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'limit_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('limit_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with exact limit (no multiplication)
|
||||
|
|
@ -252,7 +252,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'empty_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('empty_user', mock_message)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
|
@ -291,7 +291,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'dim_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('dim_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called once with correct collection
|
||||
|
|
@ -342,7 +342,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'utf8_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('utf8_user', mock_message)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
|
|
@ -380,7 +380,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
await processor.query_document_embeddings(mock_message)
|
||||
await processor.query_document_embeddings('error_user', mock_message)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
@ -413,7 +413,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'zero_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('zero_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should still query (with limit 0)
|
||||
|
|
@ -460,7 +460,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'large_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
result = await processor.query_document_embeddings('large_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should query with full limit
|
||||
|
|
@ -512,7 +512,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
# Act & Assert
|
||||
# This should raise a KeyError when trying to access payload['chunk_id']
|
||||
with pytest.raises(KeyError):
|
||||
await processor.query_document_embeddings(mock_message)
|
||||
await processor.query_document_embeddings('payload_user', mock_message)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
def mock_query_request(self):
|
||||
"""Create a mock query request for testing"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=10
|
||||
|
|
@ -117,7 +116,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_single_vector(self, processor):
|
||||
"""Test querying graph embeddings with a single vector"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -131,7 +129,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with correct parameters including user/collection
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -154,7 +152,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_multiple_results(self, processor):
|
||||
"""Test querying graph embeddings returns multiple results"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=5
|
||||
|
|
@ -168,7 +165,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called once with the full vector
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -186,7 +183,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_with_limit(self, processor):
|
||||
"""Test querying graph embeddings respects limit parameter"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=2
|
||||
|
|
@ -201,7 +197,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with 2*limit for better deduplication
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -215,7 +211,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_preserves_order(self, processor):
|
||||
"""Test that query results preserve order from the vector store"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=5
|
||||
|
|
@ -229,7 +224,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify results are in the same order as returned by the store
|
||||
assert len(result) == 3
|
||||
|
|
@ -241,7 +236,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_results_limited(self, processor):
|
||||
"""Test that results are properly limited when store returns more than requested"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=2
|
||||
|
|
@ -255,7 +249,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called with the full vector
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -269,13 +263,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_empty_vectors(self, processor):
|
||||
"""Test querying graph embeddings with empty vectors list"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[],
|
||||
limit=5
|
||||
)
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify no search was called
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
|
@ -287,7 +280,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_empty_search_results(self, processor):
|
||||
"""Test querying graph embeddings with empty search results"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -296,7 +288,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
# Mock empty search results
|
||||
processor.vecstore.search.return_value = []
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
|
|
@ -310,7 +302,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_mixed_uri_literal_results(self, processor):
|
||||
"""Test querying graph embeddings with mixed URI and literal results"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -325,7 +316,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify all results are properly typed
|
||||
assert len(result) == 4
|
||||
|
|
@ -348,7 +339,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_exception_handling(self, processor):
|
||||
"""Test exception handling during query processing"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
|
|
@ -359,7 +349,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Milvus connection failed"):
|
||||
await processor.query_graph_embeddings(query)
|
||||
await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
@ -430,13 +420,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_zero_limit(self, processor):
|
||||
"""Test querying graph embeddings with zero limit"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=0
|
||||
)
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify no search was called (optimization for zero limit)
|
||||
processor.vecstore.search.assert_not_called()
|
||||
|
|
@ -448,7 +437,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_longer_vector(self, processor):
|
||||
"""Test querying graph embeddings with a longer vector"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
||||
limit=5
|
||||
|
|
@ -461,7 +449,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
result = await processor.query_graph_embeddings('test_user', query)
|
||||
|
||||
# Verify search was called once with the full vector
|
||||
processor.vecstore.search.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify index was accessed correctly (with dimension suffix)
|
||||
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
|
||||
|
|
@ -185,7 +185,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(mock_query_message)
|
||||
entities = await processor.query_graph_embeddings('default', mock_query_message)
|
||||
|
||||
# Verify query was made once
|
||||
assert mock_index.query.call_count == 1
|
||||
|
|
@ -216,7 +216,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify limit is respected
|
||||
assert len(entities) == 2
|
||||
|
|
@ -233,7 +233,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no query was made and empty result returned
|
||||
mock_index.query.assert_not_called()
|
||||
|
|
@ -251,7 +251,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no query was made and empty result returned
|
||||
mock_index.query.assert_not_called()
|
||||
|
|
@ -276,7 +276,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify correct index used for 2D vector
|
||||
processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
|
||||
|
|
@ -300,7 +300,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no queries were made and empty result returned
|
||||
processor.pinecone.Index.assert_not_called()
|
||||
|
|
@ -323,7 +323,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
mock_results.matches = []
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify empty results
|
||||
assert entities == []
|
||||
|
|
@ -352,7 +352,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Should get exactly 3 unique entities (respecting limit)
|
||||
assert len(entities) == 3
|
||||
|
|
@ -380,7 +380,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
entities = await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
# Should only return 2 entities (respecting limit)
|
||||
mock_index.query.assert_called_once()
|
||||
|
|
@ -400,7 +400,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
mock_index.query.side_effect = Exception("Query failed")
|
||||
|
||||
with pytest.raises(Exception, match="Query failed"):
|
||||
await processor.query_graph_embeddings(message)
|
||||
await processor.query_graph_embeddings('test_user', message)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
|
|||
|
|
@ -173,7 +173,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'test_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('test_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with correct parameters (with dimension suffix)
|
||||
|
|
@ -230,7 +230,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'multi_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('multi_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called once
|
||||
|
|
@ -283,7 +283,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'limit_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('limit_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with limit * 2
|
||||
|
|
@ -323,7 +323,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'empty_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('empty_user', mock_message)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
|
@ -364,7 +364,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'dim_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('dim_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called once
|
||||
|
|
@ -415,7 +415,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'uri_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('uri_user', mock_message)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
|
|
@ -460,7 +460,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
await processor.query_graph_embeddings(mock_message)
|
||||
await processor.query_graph_embeddings('error_user', mock_message)
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -493,7 +493,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_message.collection = 'zero_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
result = await processor.query_graph_embeddings('zero_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should still query (with limit 0)
|
||||
|
|
|
|||
|
|
@ -9,12 +9,12 @@ from trustgraph.query.triples.memgraph.service import Processor
|
|||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
|
||||
class TestMemgraphQueryUserCollectionIsolation:
|
||||
class TestMemgraphQueryWorkspaceCollectionIsolation:
|
||||
"""Test cases for Memgraph query service with user/collection isolation"""
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_spo_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_spo_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SPO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -22,7 +22,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -32,13 +31,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SPO query for literal includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN $src as src "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -48,14 +47,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
src="http://example.com/s",
|
||||
rel="http://example.com/p",
|
||||
value="test_object",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_sp_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_sp_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SP query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -63,7 +62,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -73,13 +71,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SP query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN dest.value as dest "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -88,14 +86,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
expected_literal_query,
|
||||
src="http://example.com/s",
|
||||
rel="http://example.com/p",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_so_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_so_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -103,7 +101,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -113,13 +110,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SO query for nodes includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -128,14 +125,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
expected_query,
|
||||
src="http://example.com/s",
|
||||
uri="http://example.com/o",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_s_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_s_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test S-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -143,7 +140,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -153,13 +149,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify S query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -167,14 +163,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
src="http://example.com/s",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_po_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_po_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test PO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -182,7 +178,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -192,13 +187,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify PO query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -207,14 +202,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
expected_query,
|
||||
uri="http://example.com/p",
|
||||
value="literal",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_p_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_p_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test P-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -222,7 +217,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -232,13 +226,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify P query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.value as dest "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -246,14 +240,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
uri="http://example.com/p",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_o_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_o_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test O-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -261,7 +255,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -271,13 +264,13 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify O query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
|
@ -285,14 +278,14 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
value="test_value",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_wildcard_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test wildcard query (all None) includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -300,7 +293,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -310,36 +302,36 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify wildcard query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_literal_query,
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
||||
# Verify wildcard query for nodes includes user/collection
|
||||
expected_node_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Node {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT 1000"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_node_query,
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='memgraph'
|
||||
)
|
||||
|
|
@ -363,7 +355,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('default', query)
|
||||
|
||||
# Verify defaults were used
|
||||
calls = mock_driver.execute_query.call_args_list
|
||||
|
|
@ -383,7 +375,6 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -410,7 +401,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
([mock_record2], MagicMock(), MagicMock()) # Node query
|
||||
]
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify results are proper Triple objects
|
||||
assert len(result) == 2
|
||||
|
|
@ -9,12 +9,12 @@ from trustgraph.query.triples.neo4j.service import Processor
|
|||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
|
||||
class TestNeo4jQueryUserCollectionIsolation:
|
||||
class TestNeo4jQueryWorkspaceCollectionIsolation:
|
||||
"""Test cases for Neo4j query service with user/collection isolation"""
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_spo_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_spo_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SPO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -22,7 +22,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -32,13 +31,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SPO query for literal includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN $src as src "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -48,14 +47,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
src="http://example.com/s",
|
||||
rel="http://example.com/p",
|
||||
value="test_object",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_sp_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_sp_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SP query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -63,7 +62,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -73,13 +71,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SP query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -88,16 +86,16 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
expected_literal_query,
|
||||
src="http://example.com/s",
|
||||
rel="http://example.com/p",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
# Verify SP query for nodes includes user/collection
|
||||
expected_node_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Node {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN dest.uri as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -106,14 +104,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
expected_node_query,
|
||||
src="http://example.com/s",
|
||||
rel="http://example.com/p",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_so_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_so_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test SO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -121,7 +119,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -131,13 +128,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify SO query for nodes includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -146,14 +143,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
expected_query,
|
||||
src="http://example.com/s",
|
||||
uri="http://example.com/o",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_s_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_s_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test S-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -161,7 +158,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -171,13 +167,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify S query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -185,14 +181,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
src="http://example.com/s",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_po_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_po_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test PO query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -200,7 +196,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -210,13 +205,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify PO query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -225,14 +220,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
expected_query,
|
||||
uri="http://example.com/p",
|
||||
value="literal",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_p_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_p_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test P-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -240,7 +235,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
|
|
@ -250,13 +244,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify P query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -264,14 +258,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
uri="http://example.com/p",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_o_only_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_o_only_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test O-only query pattern includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -279,7 +273,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -289,13 +282,13 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify O query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
|
@ -303,14 +296,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
value="test_value",
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
|
||||
async def test_wildcard_query_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test wildcard query (all None) includes user/collection filtering"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
|
@ -318,7 +311,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -328,36 +320,36 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify wildcard query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_literal_query,
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
# Verify wildcard query for nodes includes user/collection
|
||||
expected_node_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Node {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_node_query,
|
||||
user="test_user",
|
||||
workspace="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
|
@ -381,7 +373,7 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('default', query)
|
||||
|
||||
# Verify defaults were used
|
||||
calls = mock_driver.execute_query.call_args_list
|
||||
|
|
@ -401,7 +393,6 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
|
|
@ -428,7 +419,7 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
([mock_record2], MagicMock(), MagicMock()) # Node query
|
||||
]
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples("test_user", query)
|
||||
|
||||
# Verify results are proper Triple objects
|
||||
assert len(result) == 2
|
||||
|
|
@ -91,11 +91,10 @@ class TestRowsGraphQLQueryLogic:
|
|||
"""Test parsing of schema configuration"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.schema_builders = {}
|
||||
processor.graphql_schemas = {}
|
||||
processor.config_key = "schema"
|
||||
processor.schema_builder = MagicMock()
|
||||
processor.schema_builder.clear = MagicMock()
|
||||
processor.schema_builder.add_schema = MagicMock()
|
||||
processor.schema_builder.build = MagicMock(return_value=MagicMock())
|
||||
processor.query_cassandra = MagicMock()
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Create test config
|
||||
|
|
@ -129,11 +128,11 @@ class TestRowsGraphQLQueryLogic:
|
|||
}
|
||||
|
||||
# Process config
|
||||
await processor.on_schema_config(schema_config, version=1)
|
||||
await processor.on_schema_config("default", schema_config, version=1)
|
||||
|
||||
# Verify schema was loaded
|
||||
assert "customer" in processor.schemas
|
||||
schema = processor.schemas["customer"]
|
||||
assert "customer" in processor.schemas["default"]
|
||||
schema = processor.schemas["default"]["customer"]
|
||||
assert schema.name == "customer"
|
||||
assert len(schema.fields) == 3
|
||||
|
||||
|
|
@ -147,39 +146,40 @@ class TestRowsGraphQLQueryLogic:
|
|||
status_field = next(f for f in schema.fields if f.name == "status")
|
||||
assert status_field.enum_values == ["active", "inactive"]
|
||||
|
||||
# Verify schema builder was called
|
||||
processor.schema_builder.add_schema.assert_called_once()
|
||||
processor.schema_builder.build.assert_called_once()
|
||||
# Verify per-workspace schema builder was created and graphql schema built
|
||||
assert "default" in processor.schema_builders
|
||||
assert "default" in processor.graphql_schemas
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graphql_context_handling(self):
|
||||
"""Test GraphQL execution context setup"""
|
||||
processor = MagicMock()
|
||||
processor.graphql_schema = AsyncMock()
|
||||
graphql_schema = AsyncMock()
|
||||
processor.graphql_schemas = {"default": graphql_schema}
|
||||
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
|
||||
|
||||
# Mock schema execution
|
||||
mock_result = MagicMock()
|
||||
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
|
||||
mock_result.errors = None
|
||||
processor.graphql_schema.execute.return_value = mock_result
|
||||
graphql_schema.execute.return_value = mock_result
|
||||
|
||||
result = await processor.execute_graphql_query(
|
||||
workspace="default",
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
operation_name=None,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify schema.execute was called with correct context
|
||||
processor.graphql_schema.execute.assert_called_once()
|
||||
call_args = processor.graphql_schema.execute.call_args
|
||||
graphql_schema.execute.assert_called_once()
|
||||
call_args = graphql_schema.execute.call_args
|
||||
|
||||
# Verify context was passed
|
||||
context = call_args[1]['context_value']
|
||||
assert context["processor"] == processor
|
||||
assert context["user"] == "test_user"
|
||||
assert context["workspace"] == "default"
|
||||
assert context["collection"] == "test_collection"
|
||||
|
||||
# Verify result structure
|
||||
|
|
@ -190,7 +190,8 @@ class TestRowsGraphQLQueryLogic:
|
|||
async def test_error_handling_graphql_errors(self):
|
||||
"""Test GraphQL error handling and conversion"""
|
||||
processor = MagicMock()
|
||||
processor.graphql_schema = AsyncMock()
|
||||
graphql_schema = AsyncMock()
|
||||
processor.graphql_schemas = {"default": graphql_schema}
|
||||
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
|
||||
|
||||
# Create a simple object to simulate GraphQL error
|
||||
|
|
@ -212,13 +213,13 @@ class TestRowsGraphQLQueryLogic:
|
|||
mock_result = MagicMock()
|
||||
mock_result.data = None
|
||||
mock_result.errors = [mock_error]
|
||||
processor.graphql_schema.execute.return_value = mock_result
|
||||
graphql_schema.execute.return_value = mock_result
|
||||
|
||||
result = await processor.execute_graphql_query(
|
||||
workspace="default",
|
||||
query='{ customers { invalid_field } }',
|
||||
variables={},
|
||||
operation_name=None,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -248,7 +249,6 @@ class TestRowsGraphQLQueryLogic:
|
|||
# Create mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_request = RowsQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
|
|
@ -259,6 +259,7 @@ class TestRowsGraphQLQueryLogic:
|
|||
|
||||
# Mock flow
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.workspace = "default"
|
||||
mock_response_flow = AsyncMock()
|
||||
mock_flow.return_value = mock_response_flow
|
||||
|
||||
|
|
@ -267,10 +268,10 @@ class TestRowsGraphQLQueryLogic:
|
|||
|
||||
# Verify query was executed
|
||||
processor.execute_graphql_query.assert_called_once_with(
|
||||
workspace="default",
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
operation_name=None,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -297,7 +298,6 @@ class TestRowsGraphQLQueryLogic:
|
|||
# Create mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_request = RowsQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
query='{ invalid_query }',
|
||||
variables={},
|
||||
|
|
@ -357,7 +357,7 @@ class TestUnifiedTableQueries:
|
|||
|
||||
# Query with filter on indexed field
|
||||
results = await processor.query_cassandra(
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
schema_name="products",
|
||||
row_schema=schema,
|
||||
|
|
@ -374,7 +374,7 @@ class TestUnifiedTableQueries:
|
|||
query = call_args[0][1]
|
||||
params = call_args[0][2]
|
||||
|
||||
assert "SELECT data, source FROM test_user.rows" in query
|
||||
assert "SELECT data, source FROM test_workspace.rows" in query
|
||||
assert "collection = %s" in query
|
||||
assert "schema_name = %s" in query
|
||||
assert "index_name = %s" in query
|
||||
|
|
@ -421,7 +421,7 @@ class TestUnifiedTableQueries:
|
|||
|
||||
# Query with filter on non-indexed field
|
||||
results = await processor.query_cassandra(
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
schema_name="products",
|
||||
row_schema=schema,
|
||||
|
|
|
|||
|
|
@ -95,7 +95,6 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
# Create query request with all SPO values
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -103,7 +102,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify KnowledgeGraph was created with correct parameters
|
||||
mock_kg_class.assert_called_once_with(
|
||||
|
|
@ -170,7 +169,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -178,7 +176,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=50
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
||||
assert len(result) == 1
|
||||
|
|
@ -207,7 +205,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
|
|
@ -215,7 +212,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=25
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
||||
assert len(result) == 1
|
||||
|
|
@ -244,7 +241,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -252,7 +248,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=10
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
||||
assert len(result) == 1
|
||||
|
|
@ -281,7 +277,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -289,7 +284,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=75
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
||||
assert len(result) == 1
|
||||
|
|
@ -319,7 +314,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -327,7 +321,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=1000
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
|
||||
assert len(result) == 1
|
||||
|
|
@ -425,7 +419,6 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -433,7 +426,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify KnowledgeGraph was created with authentication
|
||||
mock_kg_class.assert_called_once_with(
|
||||
|
|
@ -463,7 +456,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -472,11 +464,11 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
# First query should create TrustGraph
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
assert mock_kg_class.call_count == 1
|
||||
|
||||
# Second query with same table should reuse TrustGraph
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
assert mock_kg_class.call_count == 1 # Should not increase
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -504,7 +496,6 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
# First query
|
||||
query1 = TriplesQueryRequest(
|
||||
user='user1',
|
||||
collection='collection1',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
|
|
@ -512,12 +503,11 @@ class TestCassandraQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query1)
|
||||
await processor.query_triples('user1', query1)
|
||||
assert processor.table == 'user1'
|
||||
|
||||
# Second query with different table
|
||||
query2 = TriplesQueryRequest(
|
||||
user='user2',
|
||||
collection='collection2',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
|
|
@ -525,7 +515,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query2)
|
||||
await processor.query_triples('user2', query2)
|
||||
assert processor.table == 'user2'
|
||||
|
||||
# Verify TrustGraph was created twice
|
||||
|
|
@ -544,7 +534,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -553,7 +542,7 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
with pytest.raises(Exception, match="Query failed"):
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
|
|
@ -582,7 +571,6 @@ class TestCassandraQueryProcessor:
|
|||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -590,7 +578,7 @@ class TestCassandraQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].o.value == 'object1'
|
||||
|
|
@ -621,7 +609,6 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
# PO query pattern (predicate + object, find subjects)
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
|
|
@ -629,7 +616,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
limit=50
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify get_po was called (should use optimized po_table)
|
||||
mock_tg_instance.get_po.assert_called_once_with(
|
||||
|
|
@ -662,7 +649,6 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
# OS query pattern (object + subject, find predicates)
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
|
|
@ -670,7 +656,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
limit=25
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify get_os was called (should use optimized subject_table with clustering)
|
||||
mock_tg_instance.get_os.assert_called_once_with(
|
||||
|
|
@ -721,7 +707,6 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
mock_tg_instance.reset_mock()
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=LITERAL, value=s) if s else None,
|
||||
p=Term(type=LITERAL, value=p) if p else None,
|
||||
|
|
@ -729,7 +714,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
limit=10
|
||||
)
|
||||
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify the correct method was called
|
||||
method = getattr(mock_tg_instance, expected_method)
|
||||
|
|
@ -780,7 +765,6 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
# This is the query pattern that was slow with ALLOW FILTERING
|
||||
query = TriplesQueryRequest(
|
||||
user='large_dataset_user',
|
||||
collection='massive_collection',
|
||||
s=None,
|
||||
p=Term(type=IRI, iri='http://www.w3.org/1999/02/22-rdf-syntax-ns#type'),
|
||||
|
|
@ -788,7 +772,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
limit=1000
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('large_dataset_user', query)
|
||||
|
||||
# Verify optimized get_po was used (no ALLOW FILTERING needed!)
|
||||
mock_tg_instance.get_po.assert_called_once_with(
|
||||
|
|
|
|||
|
|
@ -123,7 +123,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -131,7 +130,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -164,7 +163,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -172,7 +170,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -209,7 +207,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
|
|
@ -217,7 +214,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -254,7 +251,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
|
|
@ -262,7 +258,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -299,7 +295,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -307,7 +302,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -344,7 +339,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -352,7 +346,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -389,7 +383,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -397,7 +390,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -434,7 +427,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -442,7 +434,7 @@ class TestFalkorDBQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_graph.query.call_count == 2
|
||||
|
|
@ -474,7 +466,6 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
|
|
@ -484,7 +475,7 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Database connection failed"):
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
|
|||
|
|
@ -122,7 +122,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -130,7 +129,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -164,7 +163,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -172,7 +170,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -210,7 +208,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
|
|
@ -218,7 +215,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -256,7 +253,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
|
|
@ -264,7 +260,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -302,7 +298,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -310,7 +305,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -348,7 +343,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -356,7 +350,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -394,7 +388,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -402,7 +395,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -440,7 +433,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -448,7 +440,7 @@ class TestMemgraphQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -478,7 +470,6 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
|
|
@ -488,7 +479,7 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Database connection failed"):
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
|
|||
|
|
@ -122,7 +122,6 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -130,7 +129,7 @@ class TestNeo4jQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -164,7 +163,6 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
|
|
@ -172,7 +170,7 @@ class TestNeo4jQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -210,7 +208,6 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
|
|
@ -218,7 +215,7 @@ class TestNeo4jQueryProcessor:
|
|||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
result = await processor.query_triples('test_user', query)
|
||||
|
||||
# Verify both literal and URI queries were executed
|
||||
assert mock_driver.execute_query.call_count == 2
|
||||
|
|
@ -248,7 +245,6 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Create query request
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
|
|
@ -258,7 +254,7 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Database connection failed"):
|
||||
await processor.query_triples(query)
|
||||
await processor.query_triples('test_user', query)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class TestDocumentMetadataTranslator:
|
|||
"title": "Test Document",
|
||||
"comments": "No comments",
|
||||
"metadata": [],
|
||||
"user": "alice",
|
||||
"workspace": "alice",
|
||||
"tags": ["finance", "q4"],
|
||||
"parent-id": "doc-100",
|
||||
"document-type": "page",
|
||||
|
|
@ -40,14 +40,14 @@ class TestDocumentMetadataTranslator:
|
|||
assert obj.time == 1710000000
|
||||
assert obj.kind == "application/pdf"
|
||||
assert obj.title == "Test Document"
|
||||
assert obj.user == "alice"
|
||||
assert obj.workspace == "alice"
|
||||
assert obj.tags == ["finance", "q4"]
|
||||
assert obj.parent_id == "doc-100"
|
||||
assert obj.document_type == "page"
|
||||
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["id"] == "doc-123"
|
||||
assert wire["user"] == "alice"
|
||||
assert wire["workspace"] == "alice"
|
||||
assert wire["parent-id"] == "doc-100"
|
||||
assert wire["document-type"] == "page"
|
||||
|
||||
|
|
@ -80,10 +80,10 @@ class TestDocumentMetadataTranslator:
|
|||
|
||||
def test_falsy_fields_omitted_from_wire(self):
|
||||
"""Empty string fields should be omitted from wire format."""
|
||||
obj = DocumentMetadata(id="", time=0, user="")
|
||||
obj = DocumentMetadata(id="", time=0, workspace="")
|
||||
wire = self.tx.encode(obj)
|
||||
assert "id" not in wire
|
||||
assert "user" not in wire
|
||||
assert "workspace" not in wire
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -101,7 +101,7 @@ class TestProcessingMetadataTranslator:
|
|||
"document-id": "doc-123",
|
||||
"time": 1710000000,
|
||||
"flow": "default",
|
||||
"user": "alice",
|
||||
"workspace": "alice",
|
||||
"collection": "my-collection",
|
||||
"tags": ["tag1"],
|
||||
}
|
||||
|
|
@ -109,20 +109,20 @@ class TestProcessingMetadataTranslator:
|
|||
assert obj.id == "proc-1"
|
||||
assert obj.document_id == "doc-123"
|
||||
assert obj.flow == "default"
|
||||
assert obj.user == "alice"
|
||||
assert obj.workspace == "alice"
|
||||
assert obj.collection == "my-collection"
|
||||
assert obj.tags == ["tag1"]
|
||||
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["id"] == "proc-1"
|
||||
assert wire["document-id"] == "doc-123"
|
||||
assert wire["user"] == "alice"
|
||||
assert wire["workspace"] == "alice"
|
||||
assert wire["collection"] == "my-collection"
|
||||
|
||||
def test_missing_fields_use_defaults(self):
|
||||
obj = self.tx.decode({})
|
||||
assert obj.id is None
|
||||
assert obj.user is None
|
||||
assert obj.workspace is None
|
||||
assert obj.collection is None
|
||||
|
||||
def test_tags_none_omitted(self):
|
||||
|
|
@ -135,10 +135,10 @@ class TestProcessingMetadataTranslator:
|
|||
wire = self.tx.encode(obj)
|
||||
assert wire["tags"] == []
|
||||
|
||||
def test_user_and_collection_preserved(self):
|
||||
def test_workspace_and_collection_preserved(self):
|
||||
"""Core pipeline routing fields must survive round-trip."""
|
||||
data = {"user": "bob", "collection": "research"}
|
||||
data = {"workspace": "bob", "collection": "research"}
|
||||
obj = self.tx.decode(data)
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["user"] == "bob"
|
||||
assert wire["workspace"] == "bob"
|
||||
assert wire["collection"] == "research"
|
||||
|
|
|
|||
|
|
@ -61,7 +61,6 @@ class TestDocEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
emb = MagicMock()
|
||||
|
|
@ -69,7 +68,7 @@ class TestDocEmbeddingsNullProtection:
|
|||
emb.vector = [] # Empty vector
|
||||
msg.chunks = [emb]
|
||||
|
||||
await proc.store_document_embeddings(msg)
|
||||
await proc.store_document_embeddings("user1", msg)
|
||||
|
||||
# No upsert should be called
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
|
@ -83,7 +82,6 @@ class TestDocEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
emb = MagicMock()
|
||||
|
|
@ -91,7 +89,7 @@ class TestDocEmbeddingsNullProtection:
|
|||
emb.vector = None # None vector
|
||||
msg.chunks = [emb]
|
||||
|
||||
await proc.store_document_embeddings(msg)
|
||||
await proc.store_document_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -103,7 +101,6 @@ class TestDocEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
emb = MagicMock()
|
||||
|
|
@ -111,7 +108,7 @@ class TestDocEmbeddingsNullProtection:
|
|||
emb.vector = [0.1, 0.2, 0.3]
|
||||
msg.chunks = [emb]
|
||||
|
||||
await proc.store_document_embeddings(msg)
|
||||
await proc.store_document_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -124,7 +121,6 @@ class TestDocEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
emb = MagicMock()
|
||||
|
|
@ -132,7 +128,7 @@ class TestDocEmbeddingsNullProtection:
|
|||
emb.vector = [0.1, 0.2, 0.3]
|
||||
msg.chunks = [emb]
|
||||
|
||||
await proc.store_document_embeddings(msg)
|
||||
await proc.store_document_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -146,7 +142,6 @@ class TestDocEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "alice"
|
||||
msg.metadata.collection = "docs"
|
||||
|
||||
emb = MagicMock()
|
||||
|
|
@ -154,7 +149,7 @@ class TestDocEmbeddingsNullProtection:
|
|||
emb.vector = [0.0] * 384 # 384-dim vector
|
||||
msg.chunks = [emb]
|
||||
|
||||
await proc.store_document_embeddings(msg)
|
||||
await proc.store_document_embeddings("alice", msg)
|
||||
|
||||
call_args = proc.qdrant.upsert.call_args
|
||||
assert "d_alice_docs_384" in call_args[1]["collection_name"]
|
||||
|
|
@ -175,7 +170,6 @@ class TestGraphEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
entity = MagicMock()
|
||||
|
|
@ -183,7 +177,7 @@ class TestGraphEmbeddingsNullProtection:
|
|||
entity.vector = [0.1, 0.2, 0.3]
|
||||
msg.entities = [entity]
|
||||
|
||||
await proc.store_graph_embeddings(msg)
|
||||
await proc.store_graph_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -195,7 +189,6 @@ class TestGraphEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
entity = MagicMock()
|
||||
|
|
@ -203,7 +196,7 @@ class TestGraphEmbeddingsNullProtection:
|
|||
entity.vector = [0.1, 0.2, 0.3]
|
||||
msg.entities = [entity]
|
||||
|
||||
await proc.store_graph_embeddings(msg)
|
||||
await proc.store_graph_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -215,7 +208,6 @@ class TestGraphEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
entity = MagicMock()
|
||||
|
|
@ -223,7 +215,7 @@ class TestGraphEmbeddingsNullProtection:
|
|||
entity.vector = [] # Empty vector
|
||||
msg.entities = [entity]
|
||||
|
||||
await proc.store_graph_embeddings(msg)
|
||||
await proc.store_graph_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -236,7 +228,6 @@ class TestGraphEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "col1"
|
||||
|
||||
entity = MagicMock()
|
||||
|
|
@ -245,7 +236,7 @@ class TestGraphEmbeddingsNullProtection:
|
|||
entity.chunk_id = "c1"
|
||||
msg.entities = [entity]
|
||||
|
||||
await proc.store_graph_embeddings(msg)
|
||||
await proc.store_graph_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -258,7 +249,6 @@ class TestGraphEmbeddingsNullProtection:
|
|||
proc.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "alice"
|
||||
msg.metadata.collection = "graphs"
|
||||
|
||||
entity = MagicMock()
|
||||
|
|
@ -267,7 +257,7 @@ class TestGraphEmbeddingsNullProtection:
|
|||
entity.chunk_id = ""
|
||||
msg.entities = [entity]
|
||||
|
||||
await proc.store_graph_embeddings(msg)
|
||||
await proc.store_graph_embeddings("alice", msg)
|
||||
|
||||
# Collection should be created with correct dimension
|
||||
proc.qdrant.create_collection.assert_called_once()
|
||||
|
|
@ -290,11 +280,10 @@ class TestCollectionValidation:
|
|||
proc.collection_exists = MagicMock(return_value=False)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "deleted-col"
|
||||
msg.chunks = [MagicMock()]
|
||||
|
||||
await proc.store_document_embeddings(msg)
|
||||
await proc.store_document_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -306,9 +295,8 @@ class TestCollectionValidation:
|
|||
proc.collection_exists = MagicMock(return_value=False)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.metadata.user = "user1"
|
||||
msg.metadata.collection = "deleted-col"
|
||||
msg.entities = [MagicMock()]
|
||||
|
||||
await proc.store_graph_embeddings(msg)
|
||||
await proc.store_graph_embeddings("user1", msg)
|
||||
proc.qdrant.upsert.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -92,14 +92,13 @@ class TestQuery:
|
|||
# Initialize Query with defaults
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "test_user"
|
||||
assert query.collection == "test_collection"
|
||||
assert query.verbose is False
|
||||
assert query.doc_limit == 20 # Default value
|
||||
|
|
@ -112,7 +111,7 @@ class TestQuery:
|
|||
# Initialize Query with custom doc_limit
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="custom_user",
|
||||
workspace="test_workspace",
|
||||
collection="custom_collection",
|
||||
verbose=True,
|
||||
doc_limit=50
|
||||
|
|
@ -120,7 +119,6 @@ class TestQuery:
|
|||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "custom_user"
|
||||
assert query.collection == "custom_collection"
|
||||
assert query.verbose is True
|
||||
assert query.doc_limit == 50
|
||||
|
|
@ -137,7 +135,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -162,7 +160,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -184,7 +182,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -223,7 +221,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
doc_limit=15
|
||||
|
|
@ -240,7 +238,6 @@ class TestQuery:
|
|||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=15,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -286,7 +283,6 @@ class TestQuery:
|
|||
|
||||
result = await document_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=10
|
||||
)
|
||||
|
|
@ -304,7 +300,6 @@ class TestQuery:
|
|||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=10,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
|
@ -350,7 +345,6 @@ class TestQuery:
|
|||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=[[0.1, 0.2]],
|
||||
limit=20, # Default doc_limit
|
||||
user="trustgraph", # Default user
|
||||
collection="default" # Default collection
|
||||
)
|
||||
|
||||
|
|
@ -380,7 +374,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=True,
|
||||
doc_limit=5
|
||||
|
|
@ -453,7 +447,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -509,7 +503,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=True
|
||||
)
|
||||
|
|
@ -558,7 +552,6 @@ class TestQuery:
|
|||
|
||||
result = await document_rag.query(
|
||||
query=query_text,
|
||||
user="research_user",
|
||||
collection="ml_knowledge",
|
||||
doc_limit=25
|
||||
)
|
||||
|
|
@ -619,7 +612,7 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
doc_limit=10
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Unit test for DocumentRAG service parameter passing fix.
|
||||
Tests that user and collection parameters from the message are correctly
|
||||
Tests that the collection parameter from the message is correctly
|
||||
passed to the DocumentRag.query() method.
|
||||
"""
|
||||
|
||||
|
|
@ -16,13 +16,13 @@ class TestDocumentRagService:
|
|||
|
||||
@patch('trustgraph.retrieval.document_rag.rag.DocumentRag')
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_and_collection_parameters_passed_to_query(self, mock_document_rag_class):
|
||||
async def test_collection_parameter_passed_to_query(self, mock_document_rag_class):
|
||||
"""
|
||||
Test that user and collection from message are passed to DocumentRag.query().
|
||||
|
||||
This is a regression test for the bug where user/collection parameters
|
||||
were ignored, causing wrong collection names like 'd_trustgraph_default_384'
|
||||
instead of 'd_my_user_test_coll_1_384'.
|
||||
Test that collection from message is passed to DocumentRag.query().
|
||||
|
||||
This is a regression test for the bug where the collection parameter
|
||||
was ignored, causing wrong collection names like 'd_trustgraph_default_384'
|
||||
instead of one that reflects the requested collection.
|
||||
"""
|
||||
# Setup processor
|
||||
processor = Processor(
|
||||
|
|
@ -30,17 +30,16 @@ class TestDocumentRagService:
|
|||
id="test-processor",
|
||||
doc_limit=10
|
||||
)
|
||||
|
||||
|
||||
# Setup mock DocumentRag instance
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_document_rag_class.return_value = mock_rag_instance
|
||||
mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None})
|
||||
|
||||
# Setup message with custom user/collection
|
||||
|
||||
# Setup message with custom collection
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = DocumentRagQuery(
|
||||
query="test query",
|
||||
user="my_user", # Custom user (not default "trustgraph")
|
||||
collection="test_coll_1", # Custom collection (not default "default")
|
||||
doc_limit=5
|
||||
)
|
||||
|
|
@ -64,7 +63,7 @@ class TestDocumentRagService:
|
|||
# Verify: DocumentRag.query was called with correct parameters
|
||||
mock_rag_instance.query.assert_called_once_with(
|
||||
"test query",
|
||||
user="my_user", # Must be from message, not hardcoded default
|
||||
workspace=ANY, # Workspace comes from flow.workspace (mock)
|
||||
collection="test_coll_1", # Must be from message, not hardcoded default
|
||||
doc_limit=5,
|
||||
explain_callback=ANY, # Explainability callback is always passed
|
||||
|
|
@ -103,7 +102,6 @@ class TestDocumentRagService:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = DocumentRagQuery(
|
||||
query="What is a cat?",
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
doc_limit=10,
|
||||
streaming=False # Non-streaming mode
|
||||
|
|
|
|||
|
|
@ -78,14 +78,12 @@ class TestQuery:
|
|||
# Initialize Query with defaults
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "test_user"
|
||||
assert query.collection == "test_collection"
|
||||
assert query.verbose is False
|
||||
assert query.entity_limit == 50 # Default value
|
||||
|
|
@ -101,7 +99,6 @@ class TestQuery:
|
|||
# Initialize Query with custom parameters
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="custom_user",
|
||||
collection="custom_collection",
|
||||
verbose=True,
|
||||
entity_limit=100,
|
||||
|
|
@ -112,7 +109,6 @@ class TestQuery:
|
|||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "custom_user"
|
||||
assert query.collection == "custom_collection"
|
||||
assert query.verbose is True
|
||||
assert query.entity_limit == 100
|
||||
|
|
@ -133,7 +129,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -156,7 +151,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=True
|
||||
)
|
||||
|
|
@ -177,7 +171,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -201,7 +194,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -244,7 +236,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
entity_limit=25
|
||||
|
|
@ -269,7 +260,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -277,7 +267,7 @@ class TestQuery:
|
|||
result = await query.maybe_label("entity1")
|
||||
|
||||
assert result == "Entity One Label"
|
||||
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
|
||||
mock_cache.get.assert_called_once_with("test_collection:entity1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_label_lookup(self):
|
||||
|
|
@ -295,7 +285,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -307,13 +296,12 @@ class TestQuery:
|
|||
p="http://www.w3.org/2000/01/rdf-schema#label",
|
||||
o=None,
|
||||
limit=1,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
g=""
|
||||
)
|
||||
|
||||
assert result == "Human Readable Label"
|
||||
cache_key = "test_user:test_collection:http://example.com/entity"
|
||||
cache_key = "test_collection:http://example.com/entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -330,7 +318,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -342,13 +329,12 @@ class TestQuery:
|
|||
p="http://www.w3.org/2000/01/rdf-schema#label",
|
||||
o=None,
|
||||
limit=1,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
g=""
|
||||
)
|
||||
|
||||
assert result == "unlabeled_entity"
|
||||
cache_key = "test_user:test_collection:unlabeled_entity"
|
||||
cache_key = "test_collection:unlabeled_entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -375,7 +361,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
triple_limit=10
|
||||
|
|
@ -388,15 +373,15 @@ class TestQuery:
|
|||
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s="entity1", p=None, o=None, limit=10,
|
||||
user="test_user", collection="test_collection", batch_size=20, g=""
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p="entity1", o=None, limit=10,
|
||||
user="test_user", collection="test_collection", batch_size=20, g=""
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p=None, o="entity1", limit=10,
|
||||
user="test_user", collection="test_collection", batch_size=20, g=""
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
|
||||
expected_subgraph = {
|
||||
|
|
@ -415,7 +400,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -435,7 +419,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=2
|
||||
|
|
@ -455,7 +438,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_path_length=1
|
||||
|
|
@ -493,7 +475,6 @@ class TestQuery:
|
|||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=100
|
||||
|
|
@ -601,7 +582,6 @@ class TestQuery:
|
|||
try:
|
||||
response = await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
entity_limit=25,
|
||||
triple_limit=15,
|
||||
|
|
|
|||
|
|
@ -120,7 +120,6 @@ class TestGraphRagServiceExplainTriples:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = GraphRagQuery(
|
||||
query="What is quantum computing?",
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -52,7 +52,6 @@ class TestGraphRagService:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = GraphRagQuery(
|
||||
query="What is a cat?",
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
entity_limit=50,
|
||||
triple_limit=30,
|
||||
|
|
@ -123,7 +122,6 @@ class TestGraphRagService:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = GraphRagQuery(
|
||||
query="What is a cat?",
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
entity_limit=50,
|
||||
triple_limit=30,
|
||||
|
|
@ -190,7 +188,6 @@ class TestGraphRagService:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = GraphRagQuery(
|
||||
query="Test query",
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
streaming=False
|
||||
)
|
||||
|
|
|
|||
|
|
@ -286,11 +286,11 @@ class TestNLPQueryProcessor:
|
|||
}
|
||||
|
||||
# Act
|
||||
await processor.on_schema_config(config, "v1")
|
||||
await processor.on_schema_config("default", config, "v1")
|
||||
|
||||
# Assert
|
||||
assert "test_schema" in processor.schemas
|
||||
schema = processor.schemas["test_schema"]
|
||||
assert "test_schema" in processor.schemas["default"]
|
||||
schema = processor.schemas["default"]["test_schema"]
|
||||
assert schema.name == "test_schema"
|
||||
assert schema.description == "Test schema"
|
||||
assert len(schema.fields) == 2
|
||||
|
|
@ -308,10 +308,10 @@ class TestNLPQueryProcessor:
|
|||
}
|
||||
|
||||
# Act
|
||||
await processor.on_schema_config(config, "v1")
|
||||
await processor.on_schema_config("default", config, "v1")
|
||||
|
||||
# Assert - bad schema should be ignored
|
||||
assert "bad_schema" not in processor.schemas
|
||||
assert "bad_schema" not in processor.schemas.get("default", {})
|
||||
|
||||
def test_processor_initialization(self, mock_pulsar_client):
|
||||
"""Test processor initialization with correct specifications"""
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ def service(mock_schemas):
|
|||
taskgroup=MagicMock(),
|
||||
id="test-processor"
|
||||
)
|
||||
service.schemas = mock_schemas
|
||||
service.schemas = {"default": dict(mock_schemas)}
|
||||
return service
|
||||
|
||||
|
||||
|
|
@ -109,6 +109,7 @@ def service(mock_schemas):
|
|||
def mock_flow():
|
||||
"""Create mock flow with prompt service"""
|
||||
flow = MagicMock()
|
||||
flow.workspace = "default"
|
||||
prompt_request_flow = AsyncMock()
|
||||
flow.return_value.request = prompt_request_flow
|
||||
return flow, prompt_request_flow
|
||||
|
|
|
|||
|
|
@ -44,7 +44,6 @@ class TestStructuredQueryProcessor:
|
|||
# Arrange
|
||||
request = StructuredQueryRequest(
|
||||
question="Show me all customers from New York",
|
||||
user="trustgraph",
|
||||
collection="default"
|
||||
)
|
||||
|
||||
|
|
@ -110,7 +109,6 @@ class TestStructuredQueryProcessor:
|
|||
assert isinstance(objects_call_args, RowsQueryRequest)
|
||||
assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }'
|
||||
assert objects_call_args.variables == {"state": "NY"}
|
||||
assert objects_call_args.user == "trustgraph"
|
||||
assert objects_call_args.collection == "default"
|
||||
|
||||
# Verify response
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create test document embeddings
|
||||
|
|
@ -80,7 +79,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings for a single chunk"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -89,7 +87,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify insert was called once for the single chunk with its vector
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
|
|
@ -99,14 +97,14 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
|
||||
"""Test storing document embeddings for multiple chunks"""
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('test_workspace', mock_message)
|
||||
|
||||
# Verify insert was called once per chunk with user/collection parameters
|
||||
# Verify insert was called once per chunk with workspace/collection parameters
|
||||
expected_calls = [
|
||||
# Chunk 1 - single vector
|
||||
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'),
|
||||
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_workspace', 'test_collection'),
|
||||
# Chunk 2 - single vector
|
||||
([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'),
|
||||
([0.7, 0.8, 0.9], "This is the second document chunk", 'test_workspace', 'test_collection'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 2
|
||||
|
|
@ -122,7 +120,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with empty chunk (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -131,7 +128,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no insert was called for empty chunk
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
|
@ -141,7 +138,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with None chunk_id"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -150,7 +146,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Note: Implementation passes through None chunk_ids (only skips empty string "")
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
|
|
@ -162,7 +158,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with mix of valid and empty chunks"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
valid_chunk = ChunkEmbeddings(
|
||||
|
|
@ -179,7 +174,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [valid_chunk, empty_chunk, another_valid]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify valid chunks were inserted, empty string chunk was skipped
|
||||
expected_calls = [
|
||||
|
|
@ -200,11 +195,10 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with empty chunks list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.chunks = []
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no insert was called
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
|
@ -214,7 +208,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings for chunk with no vectors"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -223,7 +216,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no insert was called (no vectors to insert)
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
|
@ -233,7 +226,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with different vector dimensions"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Each chunk has a single vector of different dimensions
|
||||
|
|
@ -251,7 +243,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk1, chunk2, chunk3]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify all vectors were inserted regardless of dimension with user/collection parameters
|
||||
expected_calls = [
|
||||
|
|
@ -273,7 +265,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with Unicode content in chunk_id"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -282,7 +273,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify Unicode chunk_id was stored correctly with user/collection parameters
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
|
|
@ -294,7 +285,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with long chunk_id"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create a long chunk_id
|
||||
|
|
@ -305,7 +295,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify long chunk_id was inserted with user/collection parameters
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
|
|
@ -317,7 +307,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with whitespace-only chunk"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -326,7 +315,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify whitespace content was inserted (not filtered out) with user/collection parameters
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
|
|
@ -343,25 +332,24 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
('test@domain.com', 'test-collection.v1'),
|
||||
]
|
||||
|
||||
for user, collection in test_cases:
|
||||
for workspace, collection in test_cases:
|
||||
processor.vecstore.reset_mock() # Reset mock for each test case
|
||||
|
||||
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = user
|
||||
message.metadata.collection = collection
|
||||
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk_id="Test content",
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify insert was called with the correct user/collection
|
||||
|
||||
await processor.store_document_embeddings(workspace, message)
|
||||
|
||||
# Verify insert was called with the correct workspace/collection
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], "Test content", user, collection
|
||||
[0.1, 0.2, 0.3], "Test content", workspace, collection
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -370,7 +358,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
# Store embeddings for user1/collection1
|
||||
message1 = MagicMock()
|
||||
message1.metadata = MagicMock()
|
||||
message1.metadata.user = 'user1'
|
||||
message1.metadata.collection = 'collection1'
|
||||
chunk1 = ChunkEmbeddings(
|
||||
chunk_id="User1 content",
|
||||
|
|
@ -381,7 +368,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
# Store embeddings for user2/collection2
|
||||
message2 = MagicMock()
|
||||
message2.metadata = MagicMock()
|
||||
message2.metadata.user = 'user2'
|
||||
message2.metadata.collection = 'collection2'
|
||||
chunk2 = ChunkEmbeddings(
|
||||
chunk_id="User2 content",
|
||||
|
|
@ -389,8 +375,8 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message2.chunks = [chunk2]
|
||||
|
||||
await processor.store_document_embeddings(message1)
|
||||
await processor.store_document_embeddings(message2)
|
||||
await processor.store_document_embeddings('user1', message1)
|
||||
await processor.store_document_embeddings('user2', message2)
|
||||
|
||||
# Verify both calls were made with correct parameters
|
||||
expected_calls = [
|
||||
|
|
@ -411,18 +397,17 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with special characters in user/collection names"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'user@domain.com' # Email-like user
|
||||
message.metadata.collection = 'test-collection.v1' # Collection with special chars
|
||||
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk_id="Special chars test",
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify the exact user/collection strings are passed (sanitization happens in DocVectors)
|
||||
|
||||
await processor.store_document_embeddings('user@domain.com', message)
|
||||
|
||||
# Verify the exact workspace/collection strings are passed (sanitization happens in DocVectors)
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], "Special chars test", 'user@domain.com', 'test-collection.v1'
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create test document embeddings
|
||||
|
|
@ -120,7 +119,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings for a single chunk"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -135,7 +133,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify index name and operations (with dimension suffix)
|
||||
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
|
||||
|
|
@ -185,7 +183,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test that writing to non-existent index creates it lazily"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -200,7 +197,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify index was created with correct dimension
|
||||
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
|
||||
|
|
@ -217,7 +214,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with empty chunk (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -229,7 +225,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no upsert was called for empty chunk
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
|
@ -239,7 +235,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with None chunk (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -251,7 +246,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no upsert was called for None chunk
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
|
@ -261,7 +256,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with chunk that decodes to empty string"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -273,7 +267,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no upsert was called for empty decoded chunk
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
|
@ -283,7 +277,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with different vector dimensions"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Each chunk has a single vector of different dimensions
|
||||
|
|
@ -325,14 +318,13 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with empty chunks list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.chunks = []
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no operations were performed
|
||||
processor.pinecone.Index.assert_not_called()
|
||||
|
|
@ -343,7 +335,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings for chunk with no vectors"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -355,7 +346,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify no upsert was called (no vectors to insert)
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
|
@ -365,7 +356,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test that lazy creation happens when index doesn't exist"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -380,7 +370,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify index was created
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
|
|
@ -390,7 +380,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test that lazy creation works correctly"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -405,7 +394,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify index was created and used
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
|
|
@ -416,7 +405,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with Unicode content"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
|
|
@ -430,7 +418,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify Unicode content was properly decoded and stored
|
||||
call_args = mock_index.upsert.call_args
|
||||
|
|
@ -442,7 +430,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
"""Test storing document embeddings with large document chunks"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create a large document chunk
|
||||
|
|
@ -458,7 +445,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
await processor.store_document_embeddings('test_user', message)
|
||||
|
||||
# Verify large content was stored
|
||||
call_args = mock_index.upsert.call_args
|
||||
|
|
|
|||
|
|
@ -84,7 +84,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with chunks and vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
|
|
@ -94,7 +93,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('test_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify collection existence was checked (with dimension suffix)
|
||||
|
|
@ -138,7 +137,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with multiple chunks
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
|
|
@ -152,7 +150,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('multi_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called twice (once per chunk)
|
||||
|
|
@ -198,7 +196,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with multiple chunks, each having a single vector
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'vector_user'
|
||||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
|
|
@ -216,7 +213,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('vector_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called 3 times (once per chunk)
|
||||
|
|
@ -255,7 +252,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with empty chunk_id
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_chunk_empty = MagicMock()
|
||||
|
|
@ -265,7 +261,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.chunks = [mock_chunk_empty]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('empty_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should not call upsert for empty chunk_ids
|
||||
|
|
@ -298,7 +294,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'new_user'
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
|
|
@ -308,7 +303,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('new_user', mock_message)
|
||||
|
||||
# Assert - collection should be lazily created
|
||||
expected_collection = 'd_new_user_new_collection_5' # 5 dimensions
|
||||
|
|
@ -350,7 +345,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'error_user'
|
||||
mock_message.metadata.collection = 'error_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
|
|
@ -361,7 +355,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Act & Assert - should propagate the creation error
|
||||
with pytest.raises(Exception, match="Connection error"):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('error_user', mock_message)
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
|
|
@ -388,7 +382,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create first mock message
|
||||
mock_message1 = MagicMock()
|
||||
mock_message1.metadata.user = 'cache_user'
|
||||
mock_message1.metadata.collection = 'cache_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
|
|
@ -398,7 +391,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message1.chunks = [mock_chunk1]
|
||||
|
||||
# First call
|
||||
await processor.store_document_embeddings(mock_message1)
|
||||
await processor.store_document_embeddings('cache_user', mock_message1)
|
||||
|
||||
# Reset mock to track second call
|
||||
mock_qdrant_instance.reset_mock()
|
||||
|
|
@ -406,7 +399,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create second mock message with same dimensions
|
||||
mock_message2 = MagicMock()
|
||||
mock_message2.metadata.user = 'cache_user'
|
||||
mock_message2.metadata.collection = 'cache_collection'
|
||||
|
||||
mock_chunk2 = MagicMock()
|
||||
|
|
@ -416,7 +408,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message2.chunks = [mock_chunk2]
|
||||
|
||||
# Act - Second call with same collection
|
||||
await processor.store_document_embeddings(mock_message2)
|
||||
await processor.store_document_embeddings('cache_user', mock_message2)
|
||||
|
||||
# Assert
|
||||
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
|
||||
|
|
@ -452,7 +444,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with chunks of different dimensions
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'dim_user'
|
||||
mock_message.metadata.collection = 'dim_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
|
|
@ -466,7 +457,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('dim_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should check existence of DIFFERENT collections for each dimension
|
||||
|
|
@ -526,7 +517,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with URI-style chunk_id
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'uri_user'
|
||||
mock_message.metadata.collection = 'uri_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
|
|
@ -536,7 +526,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
await processor.store_document_embeddings('uri_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify the chunk_id was stored correctly
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create test entities with embeddings
|
||||
|
|
@ -80,7 +79,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings for a single entity"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -89,7 +87,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify insert was called once with the full vector
|
||||
processor.vecstore.insert.assert_called_once()
|
||||
|
|
@ -102,14 +100,14 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
|
||||
"""Test storing graph embeddings for multiple entities"""
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
await processor.store_graph_embeddings('test_workspace', mock_message)
|
||||
|
||||
# Verify insert was called once per entity with user/collection parameters
|
||||
# Verify insert was called once per entity with workspace/collection parameters
|
||||
expected_calls = [
|
||||
# Entity 1 - single vector
|
||||
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'),
|
||||
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_workspace', 'test_collection'),
|
||||
# Entity 2 - single vector
|
||||
([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'),
|
||||
([0.7, 0.8, 0.9], 'literal entity', 'test_workspace', 'test_collection'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 2
|
||||
|
|
@ -125,7 +123,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with empty entity value (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -134,7 +131,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no insert was called for empty entity
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
|
@ -144,7 +141,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with None entity value (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -153,7 +149,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no insert was called for None entity
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
|
@ -163,7 +159,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with mix of valid and invalid entities"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
valid_entity = EntityEmbeddings(
|
||||
|
|
@ -183,7 +178,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.entities = [valid_entity, empty_entity, none_entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify only valid entity was inserted with user/collection/chunk_id parameters
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
|
|
@ -196,11 +191,10 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with empty entities list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.entities = []
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no insert was called
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
|
@ -210,7 +204,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings for entity with no vectors"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -219,7 +212,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.entities = [entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no insert was called (no vectors to insert)
|
||||
processor.vecstore.insert.assert_not_called()
|
||||
|
|
@ -229,7 +222,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with different vector dimensions"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Each entity has a single vector of different dimensions
|
||||
|
|
@ -247,7 +239,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.entities = [entity1, entity2, entity3]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify all vectors were inserted regardless of dimension
|
||||
expected_calls = [
|
||||
|
|
@ -267,7 +259,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings for both URI and literal entities"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
uri_entity = EntityEmbeddings(
|
||||
|
|
@ -280,7 +271,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.entities = [uri_entity, literal_entity]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify both entities were inserted
|
||||
expected_calls = [
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create test entity embeddings (each entity has a single vector)
|
||||
|
|
@ -124,7 +123,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings for a single entity"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -139,7 +137,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1']):
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify index name and operations (with dimension suffix)
|
||||
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
|
||||
|
|
@ -189,7 +187,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test that writing to non-existent index creates it lazily"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -204,7 +201,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify index was created with correct dimension
|
||||
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
|
||||
|
|
@ -221,7 +218,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with empty entity value (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -233,7 +229,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no upsert was called for empty entity
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
|
@ -243,7 +239,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with None entity value (should be skipped)"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -255,7 +250,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no upsert was called for None entity
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
|
@ -265,7 +260,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with different vector dimensions"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Each entity has a single vector of different dimensions
|
||||
|
|
@ -288,7 +282,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify different indexes were used for different dimensions
|
||||
index_calls = processor.pinecone.Index.call_args_list
|
||||
|
|
@ -307,14 +301,13 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings with empty entities list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.entities = []
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no operations were performed
|
||||
processor.pinecone.Index.assert_not_called()
|
||||
|
|
@ -325,7 +318,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test storing graph embeddings for entity with no vectors"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -337,7 +329,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify no upsert was called (no vectors to insert)
|
||||
mock_index.upsert.assert_not_called()
|
||||
|
|
@ -347,7 +339,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test that lazy creation happens when index doesn't exist"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -362,7 +353,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify index was created
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
|
|
@ -372,7 +363,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
"""Test that lazy creation works correctly"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
|
|
@ -387,7 +377,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_graph_embeddings(message)
|
||||
await processor.store_graph_embeddings('test_user', message)
|
||||
|
||||
# Verify index was created and used
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -64,7 +64,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with entities and vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
|
|
@ -75,7 +74,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.entities = [mock_entity]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
await processor.store_graph_embeddings('test_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify collection existence was checked (with dimension suffix)
|
||||
|
|
@ -118,7 +117,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with multiple entities
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_entity1 = MagicMock()
|
||||
|
|
@ -134,7 +132,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.entities = [mock_entity1, mock_entity2]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
await processor.store_graph_embeddings('multi_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called twice (once per entity)
|
||||
|
|
@ -179,7 +177,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with three entities
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'vector_user'
|
||||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_entity1 = MagicMock()
|
||||
|
|
@ -200,7 +197,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.entities = [mock_entity1, mock_entity2, mock_entity3]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
await processor.store_graph_embeddings('vector_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called 3 times (once per entity)
|
||||
|
|
@ -238,7 +235,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with empty entity value
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_entity_empty = MagicMock()
|
||||
|
|
@ -253,7 +249,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.entities = [mock_entity_empty, mock_entity_none]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
await processor.store_graph_embeddings('empty_user', mock_message)
|
||||
|
||||
# Assert
|
||||
# Should not call upsert for empty entities
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Tests for Memgraph user/collection isolation in storage service
|
||||
Tests for Memgraph workspace/collection isolation in storage service.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
@ -8,47 +8,45 @@ from unittest.mock import MagicMock, patch
|
|||
from trustgraph.storage.triples.memgraph.write import Processor
|
||||
|
||||
|
||||
class TestMemgraphUserCollectionIsolation:
|
||||
"""Test cases for Memgraph storage service with user/collection isolation"""
|
||||
class TestMemgraphWorkspaceCollectionIsolation:
|
||||
"""Test cases for Memgraph storage service with workspace/collection isolation"""
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
def test_storage_creates_indexes_with_user_collection(self, mock_graph_db):
|
||||
"""Test that storage creates both legacy and user/collection indexes"""
|
||||
def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test that storage creates both legacy and workspace/collection indexes"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# Verify all indexes were attempted (4 legacy + 4 user/collection = 8 total)
|
||||
|
||||
# 4 legacy + 4 workspace/collection = 8 total
|
||||
assert mock_session.run.call_count == 8
|
||||
|
||||
# Check some specific index creation calls
|
||||
|
||||
expected_calls = [
|
||||
"CREATE INDEX ON :Node",
|
||||
"CREATE INDEX ON :Node(uri)",
|
||||
"CREATE INDEX ON :Literal",
|
||||
"CREATE INDEX ON :Literal(value)",
|
||||
"CREATE INDEX ON :Node(user)",
|
||||
"CREATE INDEX ON :Node(workspace)",
|
||||
"CREATE INDEX ON :Node(collection)",
|
||||
"CREATE INDEX ON :Literal(user)",
|
||||
"CREATE INDEX ON :Literal(workspace)",
|
||||
"CREATE INDEX ON :Literal(collection)"
|
||||
]
|
||||
|
||||
|
||||
for expected_call in expected_calls:
|
||||
mock_session.run.assert_any_call(expected_call)
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_with_user_collection(self, mock_graph_db):
|
||||
"""Test that store_triples includes user/collection in all operations"""
|
||||
async def test_store_triples_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test that store_triples includes workspace/collection in all operations"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
|
|
@ -58,45 +56,39 @@ class TestMemgraphUserCollectionIsolation:
|
|||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# Create mock triple with URI object
|
||||
from trustgraph.schema import IRI
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "http://example.com/object"
|
||||
triple.o.is_uri = True
|
||||
triple.s.type = IRI
|
||||
triple.s.iri = "http://example.com/subject"
|
||||
triple.p.type = IRI
|
||||
triple.p.iri = "http://example.com/predicate"
|
||||
triple.o.type = IRI
|
||||
triple.o.iri = "http://example.com/object"
|
||||
|
||||
# Create mock message with metadata
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples("test_workspace", mock_message)
|
||||
|
||||
# Verify user/collection parameters were passed to all operations
|
||||
# Should have: create_node (subject), create_node (object), relate_node = 3 calls
|
||||
# create_node (subject), create_node (object), relate_node = 3 calls
|
||||
assert mock_driver.execute_query.call_count == 3
|
||||
|
||||
# Check that user and collection were included in all calls
|
||||
for call in mock_driver.execute_query.call_args_list:
|
||||
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
|
||||
assert 'user' in call_kwargs
|
||||
assert 'collection' in call_kwargs
|
||||
assert call_kwargs['user'] == "test_user"
|
||||
assert call_kwargs['collection'] == "test_collection"
|
||||
for c in mock_driver.execute_query.call_args_list:
|
||||
kwargs = c.kwargs
|
||||
assert kwargs['workspace'] == "test_workspace"
|
||||
assert kwargs['collection'] == "test_collection"
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_with_default_user_collection(self, mock_graph_db):
|
||||
"""Test that defaults are used when user/collection not provided in metadata"""
|
||||
async def test_store_triples_with_default_collection(self, mock_graph_db):
|
||||
"""Test that default collection is used when not provided in metadata"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
|
|
@ -106,157 +98,151 @@ class TestMemgraphUserCollectionIsolation:
|
|||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# Create mock triple
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.s.type = IRI
|
||||
triple.s.iri = "http://example.com/subject"
|
||||
triple.p.type = IRI
|
||||
triple.p.iri = "http://example.com/predicate"
|
||||
triple.o.type = LITERAL
|
||||
triple.o.value = "literal_value"
|
||||
triple.o.is_uri = False
|
||||
|
||||
# Create mock message without user/collection metadata
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
mock_message.metadata.user = None
|
||||
mock_message.metadata.collection = None
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples("default", mock_message)
|
||||
|
||||
# Verify defaults were used
|
||||
for call in mock_driver.execute_query.call_args_list:
|
||||
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
|
||||
assert call_kwargs['user'] == "default"
|
||||
assert call_kwargs['collection'] == "default"
|
||||
for c in mock_driver.execute_query.call_args_list:
|
||||
kwargs = c.kwargs
|
||||
assert kwargs['workspace'] == "default"
|
||||
assert kwargs['collection'] == "default"
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
def test_create_node_includes_user_collection(self, mock_graph_db):
|
||||
"""Test that create_node includes user/collection properties"""
|
||||
def test_create_node_includes_workspace_collection(self, mock_graph_db):
|
||||
"""Test that create_node includes workspace/collection properties"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
processor.create_node("http://example.com/node", "test_user", "test_collection")
|
||||
|
||||
|
||||
processor.create_node("http://example.com/node", "test_workspace", "test_collection")
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
uri="http://example.com/node",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_="memgraph"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
def test_create_literal_includes_user_collection(self, mock_graph_db):
|
||||
"""Test that create_literal includes user/collection properties"""
|
||||
def test_create_literal_includes_workspace_collection(self, mock_graph_db):
|
||||
"""Test that create_literal includes workspace/collection properties"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
processor.create_literal("test_value", "test_user", "test_collection")
|
||||
|
||||
|
||||
processor.create_literal("test_value", "test_workspace", "test_collection")
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
value="test_value",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_="memgraph"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
def test_relate_node_includes_user_collection(self, mock_graph_db):
|
||||
"""Test that relate_node includes user/collection properties"""
|
||||
def test_relate_node_includes_workspace_collection(self, mock_graph_db):
|
||||
"""Test that relate_node includes workspace/collection properties"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 0
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
processor.relate_node(
|
||||
"http://example.com/subject",
|
||||
"http://example.com/predicate",
|
||||
"http://example.com/predicate",
|
||||
"http://example.com/object",
|
||||
"test_user",
|
||||
"test_workspace",
|
||||
"test_collection"
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
src="http://example.com/subject",
|
||||
dest="http://example.com/object",
|
||||
uri="http://example.com/predicate",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_="memgraph"
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
def test_relate_literal_includes_user_collection(self, mock_graph_db):
|
||||
"""Test that relate_literal includes user/collection properties"""
|
||||
def test_relate_literal_includes_workspace_collection(self, mock_graph_db):
|
||||
"""Test that relate_literal includes workspace/collection properties"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 0
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
processor.relate_literal(
|
||||
"http://example.com/subject",
|
||||
"http://example.com/predicate",
|
||||
"literal_value",
|
||||
"test_user",
|
||||
"test_workspace",
|
||||
"test_collection"
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
src="http://example.com/subject",
|
||||
dest="literal_value",
|
||||
uri="http://example.com/predicate",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_="memgraph"
|
||||
)
|
||||
|
|
@ -264,20 +250,15 @@ class TestMemgraphUserCollectionIsolation:
|
|||
def test_add_args_includes_memgraph_parameters(self):
|
||||
"""Test that add_args properly configures Memgraph-specific parameters"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
|
||||
with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once()
|
||||
|
||||
# Verify our specific arguments were added with Memgraph defaults
|
||||
|
||||
args = parser.parse_args([])
|
||||
|
||||
|
||||
assert hasattr(args, 'graph_host')
|
||||
assert args.graph_host == 'bolt://memgraph:7687'
|
||||
assert hasattr(args, 'username')
|
||||
|
|
@ -288,19 +269,18 @@ class TestMemgraphUserCollectionIsolation:
|
|||
assert args.database == 'memgraph'
|
||||
|
||||
|
||||
class TestMemgraphUserCollectionRegression:
|
||||
"""Regression tests to ensure user/collection isolation prevents data leakage"""
|
||||
class TestMemgraphWorkspaceCollectionRegression:
|
||||
"""Regression tests to ensure workspace/collection isolation prevents data leakage"""
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_regression_no_cross_user_data_access(self, mock_graph_db):
|
||||
"""Regression test: Ensure users cannot access each other's data"""
|
||||
async def test_regression_no_cross_workspace_data_access(self, mock_graph_db):
|
||||
"""Regression test: Ensure workspaces cannot access each other's data"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
|
|
@ -310,60 +290,55 @@ class TestMemgraphUserCollectionRegression:
|
|||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# Store data for user1
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "user1_data"
|
||||
triple.o.is_uri = False
|
||||
triple.s.type = IRI
|
||||
triple.s.iri = "http://example.com/subject"
|
||||
triple.p.type = IRI
|
||||
triple.p.iri = "http://example.com/predicate"
|
||||
triple.o.type = LITERAL
|
||||
triple.o.value = "ws1_data"
|
||||
|
||||
message_user1 = MagicMock()
|
||||
message_user1.triples = [triple]
|
||||
message_user1.metadata.user = "user1"
|
||||
message_user1.metadata.collection = "collection1"
|
||||
message_ws1 = MagicMock()
|
||||
message_ws1.triples = [triple]
|
||||
message_ws1.metadata.collection = "collection1"
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message_user1)
|
||||
await processor.store_triples("workspace1", message_ws1)
|
||||
|
||||
# Verify that all storage operations included user1/collection1 parameters
|
||||
for call in mock_driver.execute_query.call_args_list:
|
||||
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
|
||||
if 'user' in call_kwargs:
|
||||
assert call_kwargs['user'] == "user1"
|
||||
assert call_kwargs['collection'] == "collection1"
|
||||
for c in mock_driver.execute_query.call_args_list:
|
||||
kwargs = c.kwargs
|
||||
if 'workspace' in kwargs:
|
||||
assert kwargs['workspace'] == "workspace1"
|
||||
assert kwargs['collection'] == "collection1"
|
||||
|
||||
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_regression_same_uri_different_users(self, mock_graph_db):
|
||||
"""Regression test: Same URI can exist for different users without conflict"""
|
||||
async def test_regression_same_uri_different_workspaces(self, mock_graph_db):
|
||||
"""Regression test: Same URI can exist in different workspaces without conflict"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock execute_query response
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# Same URI for different users should create separate nodes
|
||||
processor.create_node("http://example.com/same-uri", "user1", "collection1")
|
||||
processor.create_node("http://example.com/same-uri", "user2", "collection2")
|
||||
|
||||
# Verify both calls were made with different user/collection parameters
|
||||
calls = mock_driver.execute_query.call_args_list[-2:] # Get last 2 calls
|
||||
|
||||
call1_kwargs = calls[0].kwargs if hasattr(calls[0], 'kwargs') else calls[0][1]
|
||||
call2_kwargs = calls[1].kwargs if hasattr(calls[1], 'kwargs') else calls[1][1]
|
||||
|
||||
assert call1_kwargs['user'] == "user1" and call1_kwargs['collection'] == "collection1"
|
||||
assert call2_kwargs['user'] == "user2" and call2_kwargs['collection'] == "collection2"
|
||||
|
||||
# Both should have the same URI but different user/collection
|
||||
assert call1_kwargs['uri'] == call2_kwargs['uri'] == "http://example.com/same-uri"
|
||||
|
||||
processor.create_node("http://example.com/same-uri", "workspace1", "collection1")
|
||||
processor.create_node("http://example.com/same-uri", "workspace2", "collection2")
|
||||
|
||||
calls = mock_driver.execute_query.call_args_list[-2:]
|
||||
|
||||
k1 = calls[0].kwargs
|
||||
k2 = calls[1].kwargs
|
||||
|
||||
assert k1['workspace'] == "workspace1" and k1['collection'] == "collection1"
|
||||
assert k2['workspace'] == "workspace2" and k2['collection'] == "collection2"
|
||||
|
||||
assert k1['uri'] == k2['uri'] == "http://example.com/same-uri"
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Tests for Neo4j user/collection isolation in triples storage and query
|
||||
Tests for Neo4j workspace/collection isolation in triples storage and query.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
@ -11,468 +11,406 @@ from trustgraph.schema import Triples, Triple, Term, Metadata, IRI, LITERAL
|
|||
from trustgraph.schema import TriplesQueryRequest
|
||||
|
||||
|
||||
class TestNeo4jUserCollectionIsolation:
|
||||
"""Test cases for Neo4j user/collection isolation functionality"""
|
||||
class TestNeo4jWorkspaceCollectionIsolation:
|
||||
"""Test cases for Neo4j workspace/collection isolation functionality"""
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
def test_storage_creates_indexes_with_user_collection(self, mock_graph_db):
|
||||
"""Test that storage service creates compound indexes for user/collection"""
|
||||
def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test that storage service creates compound indexes for workspace/collection"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
processor = StorageProcessor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Verify both legacy and new compound indexes are created
|
||||
|
||||
expected_indexes = [
|
||||
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
|
||||
"CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
|
||||
"CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)",
|
||||
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)",
|
||||
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)",
|
||||
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)",
|
||||
"CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)",
|
||||
"CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)",
|
||||
"CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)",
|
||||
"CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)"
|
||||
]
|
||||
|
||||
# Check that all expected indexes were created
|
||||
|
||||
for expected_query in expected_indexes:
|
||||
mock_session.run.assert_any_call(expected_query)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_with_user_collection(self, mock_graph_db):
|
||||
"""Test that triples are stored with user/collection properties"""
|
||||
async def test_store_triples_with_workspace_collection(self, mock_graph_db):
|
||||
"""Test that triples are stored with workspace/collection properties"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
processor = StorageProcessor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create test message with user/collection metadata
|
||||
metadata = Metadata(
|
||||
id="test-id",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
||||
metadata = Metadata(id="test-id", collection="test_collection")
|
||||
|
||||
triple = Triple(
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="literal_value")
|
||||
)
|
||||
|
||||
message = Triples(
|
||||
metadata=metadata,
|
||||
triples=[triple]
|
||||
)
|
||||
|
||||
# Mock execute_query to return summaries
|
||||
|
||||
message = Triples(metadata=metadata, triples=[triple])
|
||||
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify nodes and relationships were created with user/collection properties
|
||||
await processor.store_triples("test_workspace", message)
|
||||
|
||||
expected_calls = [
|
||||
call(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
uri="http://example.com/subject",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
),
|
||||
call(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
value="literal_value",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
),
|
||||
call(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
src="http://example.com/subject",
|
||||
dest="literal_value",
|
||||
uri="http://example.com/predicate",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
for expected_call in expected_calls:
|
||||
mock_driver.execute_query.assert_any_call(*expected_call.args, **expected_call.kwargs)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_triples_with_default_user_collection(self, mock_graph_db):
|
||||
"""Test that default user/collection are used when not provided"""
|
||||
async def test_store_triples_with_default_collection(self, mock_graph_db):
|
||||
"""Test that default collection is used when not provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
processor = StorageProcessor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create test message without user/collection
|
||||
|
||||
metadata = Metadata(id="test-id")
|
||||
|
||||
|
||||
triple = Triple(
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=IRI, iri="http://example.com/object")
|
||||
)
|
||||
|
||||
message = Triples(
|
||||
metadata=metadata,
|
||||
triples=[triple]
|
||||
)
|
||||
|
||||
# Mock execute_query
|
||||
|
||||
message = Triples(metadata=metadata, triples=[triple])
|
||||
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify defaults were used
|
||||
await processor.store_triples("default", message)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
uri="http://example.com/subject",
|
||||
user="default",
|
||||
workspace="default",
|
||||
collection="default",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_triples_filters_by_user_collection(self, mock_graph_db):
|
||||
"""Test that query service filters results by user/collection"""
|
||||
async def test_query_triples_filters_by_workspace_collection(self, mock_graph_db):
|
||||
"""Test that query service filters results by workspace/collection"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
||||
|
||||
processor = QueryProcessor(taskgroup=MagicMock())
|
||||
|
||||
# Create test query
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=None
|
||||
)
|
||||
|
||||
# Mock query results
|
||||
|
||||
mock_records = [
|
||||
MagicMock(data=lambda: {"dest": "http://example.com/object1"}),
|
||||
MagicMock(data=lambda: {"dest": "literal_value"})
|
||||
]
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = (mock_records, MagicMock(), MagicMock())
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
# Verify queries include user/collection filters
|
||||
|
||||
await processor.query_triples("test_workspace", query)
|
||||
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN dest.value as dest"
|
||||
)
|
||||
|
||||
expected_node_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN dest.uri as dest"
|
||||
)
|
||||
|
||||
# Check that queries were executed with user/collection parameters
|
||||
|
||||
calls = mock_driver.execute_query.call_args_list
|
||||
assert any(
|
||||
expected_literal_query in str(call) and
|
||||
"user='test_user'" in str(call) and
|
||||
"collection='test_collection'" in str(call)
|
||||
for call in calls
|
||||
expected_literal_query in str(c) and
|
||||
"workspace='test_workspace'" in str(c) and
|
||||
"collection='test_collection'" in str(c)
|
||||
for c in calls
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_triples_with_default_user_collection(self, mock_graph_db):
|
||||
"""Test that query service uses defaults when user/collection not provided"""
|
||||
async def test_query_triples_with_default_collection(self, mock_graph_db):
|
||||
"""Test that query service uses default collection when not provided"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
||||
|
||||
processor = QueryProcessor(taskgroup=MagicMock())
|
||||
|
||||
# Create test query without user/collection
|
||||
query = TriplesQueryRequest(
|
||||
s=None,
|
||||
p=None,
|
||||
o=None
|
||||
)
|
||||
|
||||
# Mock empty results
|
||||
|
||||
query = TriplesQueryRequest(s=None, p=None, o=None)
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
# Verify defaults were used in queries
|
||||
|
||||
await processor.query_triples("default", query)
|
||||
|
||||
calls = mock_driver.execute_query.call_args_list
|
||||
assert any(
|
||||
"user='default'" in str(call) and "collection='default'" in str(call)
|
||||
for call in calls
|
||||
"workspace='default'" in str(c) and "collection='default'" in str(c)
|
||||
for c in calls
|
||||
)
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_isolation_between_users(self, mock_graph_db):
|
||||
"""Test that data from different users is properly isolated"""
|
||||
async def test_data_isolation_between_workspaces(self, mock_graph_db):
|
||||
"""Test that data from different workspaces is properly isolated"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
processor = StorageProcessor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create messages for different users
|
||||
message_user1 = Triples(
|
||||
metadata=Metadata(user="user1", collection="coll1"),
|
||||
|
||||
message_ws1 = Triples(
|
||||
metadata=Metadata(collection="coll1"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="http://example.com/user1/subject"),
|
||||
s=Term(type=IRI, iri="http://example.com/ws1/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="user1_data")
|
||||
o=Term(type=LITERAL, value="ws1_data")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
message_user2 = Triples(
|
||||
metadata=Metadata(user="user2", collection="coll2"),
|
||||
|
||||
message_ws2 = Triples(
|
||||
metadata=Metadata(collection="coll2"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="http://example.com/user2/subject"),
|
||||
s=Term(type=IRI, iri="http://example.com/ws2/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="user2_data")
|
||||
o=Term(type=LITERAL, value="ws2_data")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Mock execute_query
|
||||
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
# Store data for both users
|
||||
await processor.store_triples(message_user1)
|
||||
await processor.store_triples(message_user2)
|
||||
|
||||
# Verify user1 data was stored with user1/coll1
|
||||
await processor.store_triples("workspace1", message_ws1)
|
||||
await processor.store_triples("workspace2", message_ws2)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
value="user1_data",
|
||||
user="user1",
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
value="ws1_data",
|
||||
workspace="workspace1",
|
||||
collection="coll1",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
# Verify user2 data was stored with user2/coll2
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
value="user2_data",
|
||||
user="user2",
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
value="ws2_data",
|
||||
workspace="workspace2",
|
||||
collection="coll2",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_wildcard_query_respects_user_collection(self, mock_graph_db):
|
||||
"""Test that wildcard queries still filter by user/collection"""
|
||||
async def test_wildcard_query_respects_workspace_collection(self, mock_graph_db):
|
||||
"""Test that wildcard queries still filter by workspace/collection"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
||||
|
||||
processor = QueryProcessor(taskgroup=MagicMock())
|
||||
|
||||
# Create wildcard query (all nulls) with user/collection
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
o=None
|
||||
s=None, p=None, o=None,
|
||||
)
|
||||
|
||||
# Mock results
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
# Verify wildcard queries include user/collection filters
|
||||
|
||||
await processor.query_triples("test_workspace", query)
|
||||
|
||||
wildcard_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
|
||||
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
|
||||
"(dest:Literal {workspace: $workspace, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest"
|
||||
)
|
||||
|
||||
|
||||
calls = mock_driver.execute_query.call_args_list
|
||||
assert any(
|
||||
wildcard_query in str(call) and
|
||||
"user='test_user'" in str(call) and
|
||||
"collection='test_collection'" in str(call)
|
||||
for call in calls
|
||||
wildcard_query in str(c) and
|
||||
"workspace='test_workspace'" in str(c) and
|
||||
"collection='test_collection'" in str(c)
|
||||
for c in calls
|
||||
)
|
||||
|
||||
def test_add_args_includes_neo4j_parameters(self):
|
||||
"""Test that add_args includes Neo4j-specific parameters"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
|
||||
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'):
|
||||
StorageProcessor.add_args(parser)
|
||||
|
||||
|
||||
args = parser.parse_args([])
|
||||
|
||||
|
||||
assert hasattr(args, 'graph_host')
|
||||
assert hasattr(args, 'username')
|
||||
assert hasattr(args, 'password')
|
||||
assert hasattr(args, 'database')
|
||||
|
||||
# Check defaults
|
||||
|
||||
assert args.graph_host == 'bolt://neo4j:7687'
|
||||
assert args.username == 'neo4j'
|
||||
assert args.password == 'password'
|
||||
assert args.database == 'neo4j'
|
||||
|
||||
|
||||
class TestNeo4jUserCollectionRegression:
|
||||
"""Regression tests to ensure user/collection isolation prevents data leaks"""
|
||||
|
||||
class TestNeo4jWorkspaceCollectionRegression:
|
||||
"""Regression tests to ensure workspace/collection isolation prevents data leaks"""
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_regression_no_cross_user_data_access(self, mock_graph_db):
|
||||
@pytest.mark.asyncio
|
||||
async def test_regression_no_cross_workspace_data_access(self, mock_graph_db):
|
||||
"""
|
||||
Regression test: Ensure user1 cannot access user2's data
|
||||
|
||||
This test guards against the bug where all users shared the same
|
||||
Neo4j graph space, causing data contamination between users.
|
||||
Regression test: Ensure workspace1 cannot access workspace2's data.
|
||||
|
||||
Guards against a bug where all data shared the same Neo4j graph
|
||||
space, causing data contamination between workspaces.
|
||||
"""
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
|
||||
|
||||
processor = QueryProcessor(taskgroup=MagicMock())
|
||||
|
||||
# User1 queries for all triples
|
||||
query_user1 = TriplesQueryRequest(
|
||||
user="user1",
|
||||
|
||||
query_ws1 = TriplesQueryRequest(
|
||||
collection="collection1",
|
||||
s=None, p=None, o=None
|
||||
)
|
||||
|
||||
# Mock that the database has data but none matching user1/collection1
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
result = await processor.query_triples(query_user1)
|
||||
|
||||
# Verify empty results (user1 cannot see other users' data)
|
||||
|
||||
result = await processor.query_triples("workspace1", query_ws1)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
# Verify the query included user/collection filters
|
||||
|
||||
calls = mock_driver.execute_query.call_args_list
|
||||
for call in calls:
|
||||
query_str = str(call)
|
||||
for c in calls:
|
||||
query_str = str(c)
|
||||
if "MATCH" in query_str:
|
||||
assert "user: $user" in query_str or "user='user1'" in query_str
|
||||
assert "workspace: $workspace" in query_str or "workspace='workspace1'" in query_str
|
||||
assert "collection: $collection" in query_str or "collection='collection1'" in query_str
|
||||
|
||||
|
||||
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
async def test_regression_same_uri_different_users(self, mock_graph_db):
|
||||
async def test_regression_same_uri_different_workspaces(self, mock_graph_db):
|
||||
"""
|
||||
Regression test: Same URI in different user contexts should create separate nodes
|
||||
|
||||
This ensures that http://example.com/entity for user1 is completely separate
|
||||
from http://example.com/entity for user2.
|
||||
Regression test: Same URI in different workspace contexts should create separate nodes.
|
||||
|
||||
Ensures http://example.com/entity in workspace1 is completely
|
||||
separate from the same URI in workspace2.
|
||||
"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_driver = MagicMock()
|
||||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
processor = StorageProcessor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Same URI for different users
|
||||
|
||||
shared_uri = "http://example.com/shared_entity"
|
||||
|
||||
message_user1 = Triples(
|
||||
metadata=Metadata(user="user1", collection="coll1"),
|
||||
|
||||
message_ws1 = Triples(
|
||||
metadata=Metadata(collection="coll1"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri=shared_uri),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="user1_value")
|
||||
o=Term(type=LITERAL, value="ws1_value")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
message_user2 = Triples(
|
||||
metadata=Metadata(user="user2", collection="coll2"),
|
||||
|
||||
message_ws2 = Triples(
|
||||
metadata=Metadata(collection="coll2"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri=shared_uri),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="user2_value")
|
||||
o=Term(type=LITERAL, value="ws2_value")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Mock execute_query
|
||||
|
||||
mock_summary = MagicMock()
|
||||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message_user1)
|
||||
await processor.store_triples(message_user2)
|
||||
|
||||
# Verify two separate nodes were created with same URI but different user/collection
|
||||
user1_node_call = call(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
await processor.store_triples("workspace1", message_ws1)
|
||||
await processor.store_triples("workspace2", message_ws2)
|
||||
|
||||
ws1_node_call = call(
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
uri=shared_uri,
|
||||
user="user1",
|
||||
workspace="workspace1",
|
||||
collection="coll1",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
user2_node_call = call(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
|
||||
ws2_node_call = call(
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
uri=shared_uri,
|
||||
user="user2",
|
||||
workspace="workspace2",
|
||||
collection="coll2",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_has_calls([user1_node_call, user2_node_call], any_order=True)
|
||||
|
||||
mock_driver.execute_query.assert_has_calls([ws1_node_call, ws2_node_call], any_order=True)
|
||||
|
|
@ -1,3 +1,12 @@
|
|||
|
||||
def _flow_mock(workspace):
|
||||
"""Build a mock flow object that is callable and exposes .workspace."""
|
||||
from unittest.mock import MagicMock
|
||||
f = MagicMock()
|
||||
f.workspace = workspace
|
||||
return f
|
||||
|
||||
|
||||
"""
|
||||
Unit tests for trustgraph.storage.row_embeddings.qdrant.write
|
||||
Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant.
|
||||
|
|
@ -92,13 +101,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
collection_name = processor.get_collection_name(
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
schema_name="customer_data",
|
||||
dimension=384
|
||||
)
|
||||
|
||||
assert collection_name == "rows_test_user_test_collection_customer_data_384"
|
||||
assert collection_name == "rows_test_workspace_test_collection_customer_data_384"
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_ensure_collection_creates_new(self, mock_qdrant_client):
|
||||
|
|
@ -185,11 +194,10 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
processor.known_collections[('test_workspace', 'test_collection')] = {}
|
||||
|
||||
# Create embeddings message
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
metadata.collection = 'test_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
|
|
@ -210,14 +218,14 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = embeddings_msg
|
||||
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
|
||||
|
||||
# Verify upsert was called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
# Verify upsert parameters
|
||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||
assert upsert_call_args[1]['collection_name'] == 'rows_test_user_test_collection_customers_3'
|
||||
assert upsert_call_args[1]['collection_name'] == 'rows_test_workspace_test_collection_customers_3'
|
||||
|
||||
point = upsert_call_args[1]['points'][0]
|
||||
assert point.vector == [0.1, 0.2, 0.3]
|
||||
|
|
@ -243,10 +251,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
processor.known_collections[('test_workspace', 'test_collection')] = {}
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
metadata.collection = 'test_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
|
|
@ -267,7 +274,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = embeddings_msg
|
||||
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
|
||||
|
||||
# Should be called once for the single embedding
|
||||
assert mock_qdrant_instance.upsert.call_count == 1
|
||||
|
|
@ -287,10 +294,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
processor.known_collections[('test_workspace', 'test_collection')] = {}
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
metadata.collection = 'test_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
|
|
@ -311,7 +317,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = embeddings_msg
|
||||
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
|
||||
|
||||
# Should not call upsert for empty vectors
|
||||
mock_qdrant_instance.upsert.assert_not_called()
|
||||
|
|
@ -334,7 +340,6 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
# No collections registered
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'unknown_user'
|
||||
metadata.collection = 'unknown_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
|
|
@ -354,7 +359,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = embeddings_msg
|
||||
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
|
||||
|
||||
# Should not call upsert for unknown collection
|
||||
mock_qdrant_instance.upsert.assert_not_called()
|
||||
|
|
@ -368,11 +373,11 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Mock collections list
|
||||
mock_coll1 = MagicMock()
|
||||
mock_coll1.name = 'rows_test_user_test_collection_schema1_384'
|
||||
mock_coll1.name = 'rows_test_workspace_test_collection_schema1_384'
|
||||
mock_coll2 = MagicMock()
|
||||
mock_coll2.name = 'rows_test_user_test_collection_schema2_384'
|
||||
mock_coll2.name = 'rows_test_workspace_test_collection_schema2_384'
|
||||
mock_coll3 = MagicMock()
|
||||
mock_coll3.name = 'rows_other_user_other_collection_schema_384'
|
||||
mock_coll3.name = 'rows_other_workspace_other_collection_schema_384'
|
||||
|
||||
mock_collections = MagicMock()
|
||||
mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3]
|
||||
|
|
@ -386,15 +391,15 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.created_collections.add('rows_test_user_test_collection_schema1_384')
|
||||
processor.created_collections.add('rows_test_workspace_test_collection_schema1_384')
|
||||
|
||||
await processor.delete_collection('test_user', 'test_collection')
|
||||
await processor.delete_collection('test_workspace', 'test_collection')
|
||||
|
||||
# Should delete only the matching collections
|
||||
assert mock_qdrant_instance.delete_collection.call_count == 2
|
||||
|
||||
# Verify the cached collection was removed
|
||||
assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections
|
||||
assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_delete_collection_schema(self, mock_qdrant_client):
|
||||
|
|
@ -404,9 +409,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_qdrant_instance = MagicMock()
|
||||
|
||||
mock_coll1 = MagicMock()
|
||||
mock_coll1.name = 'rows_test_user_test_collection_customers_384'
|
||||
mock_coll1.name = 'rows_test_workspace_test_collection_customers_384'
|
||||
mock_coll2 = MagicMock()
|
||||
mock_coll2.name = 'rows_test_user_test_collection_orders_384'
|
||||
mock_coll2.name = 'rows_test_workspace_test_collection_orders_384'
|
||||
|
||||
mock_collections = MagicMock()
|
||||
mock_collections.collections = [mock_coll1, mock_coll2]
|
||||
|
|
@ -422,13 +427,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
await processor.delete_collection_schema(
|
||||
'test_user', 'test_collection', 'customers'
|
||||
'test_workspace', 'test_collection', 'customers'
|
||||
)
|
||||
|
||||
# Should only delete the customers schema collection
|
||||
mock_qdrant_instance.delete_collection.assert_called_once()
|
||||
call_args = mock_qdrant_instance.delete_collection.call_args[0]
|
||||
assert call_args[0] == 'rows_test_user_test_collection_customers_384'
|
||||
assert call_args[0] == 'rows_test_workspace_test_collection_customers_384'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -17,6 +17,17 @@ from trustgraph.storage.rows.cassandra.write import Processor
|
|||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
|
||||
|
||||
class _MockFlowDefault:
|
||||
"""Mock Flow with default workspace for testing."""
|
||||
workspace = "default"
|
||||
name = "default"
|
||||
id = "test-processor"
|
||||
|
||||
|
||||
mock_flow_default = _MockFlowDefault()
|
||||
|
||||
class TestRowsCassandraStorageLogic:
|
||||
"""Test business logic for unified table implementation"""
|
||||
|
||||
|
|
@ -145,11 +156,11 @@ class TestRowsCassandraStorageLogic:
|
|||
}
|
||||
|
||||
# Process configuration
|
||||
await processor.on_schema_config(config, version=1)
|
||||
await processor.on_schema_config("default", config, version=1)
|
||||
|
||||
# Verify schema was loaded
|
||||
assert "customer_records" in processor.schemas
|
||||
schema = processor.schemas["customer_records"]
|
||||
assert "customer_records" in processor.schemas["default"]
|
||||
schema = processor.schemas["default"]["customer_records"]
|
||||
assert schema.name == "customer_records"
|
||||
assert len(schema.fields) == 3
|
||||
|
||||
|
|
@ -165,16 +176,18 @@ class TestRowsCassandraStorageLogic:
|
|||
"""Test that row processing stores data as map<text, text>"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
description="Test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="value", type="string", size=100)
|
||||
]
|
||||
)
|
||||
"default": {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
description="Test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="value", type="string", size=100)
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.tables_initialized = {"default"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
|
|
@ -191,7 +204,6 @@ class TestRowsCassandraStorageLogic:
|
|||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="test-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
),
|
||||
schema_name="test_schema",
|
||||
|
|
@ -205,7 +217,7 @@ class TestRowsCassandraStorageLogic:
|
|||
msg.value.return_value = test_obj
|
||||
|
||||
# Process object
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify insert was executed
|
||||
mock_async_execute.assert_called()
|
||||
|
|
@ -214,7 +226,7 @@ class TestRowsCassandraStorageLogic:
|
|||
values = insert_call[0][2]
|
||||
|
||||
# Verify using unified rows table
|
||||
assert "INSERT INTO test_user.rows" in insert_cql
|
||||
assert "INSERT INTO default.rows" in insert_cql
|
||||
|
||||
# Values should be: (collection, schema_name, index_name, index_value, data, source)
|
||||
assert values[0] == "test_collection" # collection
|
||||
|
|
@ -230,16 +242,18 @@ class TestRowsCassandraStorageLogic:
|
|||
"""Test that row is written once per indexed field"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"multi_index_schema": RowSchema(
|
||||
name="multi_index_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="status", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
"default": {
|
||||
"multi_index_schema": RowSchema(
|
||||
name="multi_index_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="status", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.tables_initialized = {"default"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
|
|
@ -255,7 +269,6 @@ class TestRowsCassandraStorageLogic:
|
|||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="test-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
),
|
||||
schema_name="multi_index_schema",
|
||||
|
|
@ -267,7 +280,7 @@ class TestRowsCassandraStorageLogic:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Should have 3 inserts (one per indexed field: id, category, status)
|
||||
assert mock_async_execute.call_count == 3
|
||||
|
|
@ -290,15 +303,17 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
"""Test processing of batch ExtractedObjects"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"batch_schema": RowSchema(
|
||||
name="batch_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string")
|
||||
]
|
||||
)
|
||||
"default": {
|
||||
"batch_schema": RowSchema(
|
||||
name="batch_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string")
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.tables_initialized = {"default"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
|
|
@ -315,7 +330,6 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="batch-001",
|
||||
user="test_user",
|
||||
collection="batch_collection",
|
||||
),
|
||||
schema_name="batch_schema",
|
||||
|
|
@ -331,7 +345,7 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Should have 3 inserts (one per row, one index per row since only primary key)
|
||||
assert mock_async_execute.call_count == 3
|
||||
|
|
@ -349,12 +363,14 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
"""Test processing of empty batch ExtractedObjects"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"empty_schema": RowSchema(
|
||||
name="empty_schema",
|
||||
fields=[Field(name="id", type="string", primary=True)]
|
||||
)
|
||||
"default": {
|
||||
"empty_schema": RowSchema(
|
||||
name="empty_schema",
|
||||
fields=[Field(name="id", type="string", primary=True)]
|
||||
)
|
||||
}
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.tables_initialized = {"default"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
|
|
@ -369,7 +385,6 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
empty_batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="empty-001",
|
||||
user="test_user",
|
||||
collection="empty_collection",
|
||||
),
|
||||
schema_name="empty_schema",
|
||||
|
|
@ -381,7 +396,7 @@ class TestRowsCassandraStorageBatchLogic:
|
|||
msg = MagicMock()
|
||||
msg.value.return_value = empty_batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
await processor.on_object(msg, None, mock_flow_default)
|
||||
|
||||
# Verify no insert calls for empty batch
|
||||
processor.session.execute.assert_not_called()
|
||||
|
|
@ -446,19 +461,21 @@ class TestPartitionRegistration:
|
|||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.schemas = {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
"default": {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
|
||||
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema")
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema", "default")
|
||||
|
||||
# Should have 2 inserts (one per index: id, category)
|
||||
assert processor.session.execute.call_count == 2
|
||||
|
|
@ -473,7 +490,7 @@ class TestPartitionRegistration:
|
|||
processor.session = MagicMock()
|
||||
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
|
||||
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema")
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema", "default")
|
||||
|
||||
# Should not execute any CQL since already registered
|
||||
processor.session.execute.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -102,11 +102,10 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify KnowledgeGraph was called with auth parameters
|
||||
mock_kg_class.assert_called_once_with(
|
||||
|
|
@ -129,11 +128,10 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user2'
|
||||
mock_message.metadata.collection = 'collection2'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user2', mock_message)
|
||||
|
||||
# Verify KnowledgeGraph was called without auth parameters
|
||||
mock_kg_class.assert_called_once_with(
|
||||
|
|
@ -154,16 +152,15 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
# First call should create TrustGraph
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
assert mock_kg_class.call_count == 1
|
||||
|
||||
# Second call with same table should reuse TrustGraph
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
assert mock_kg_class.call_count == 1 # Should not increase
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -205,11 +202,10 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = [triple1, triple2]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
|
||||
assert mock_tg_instance.insert.call_count == 2
|
||||
|
|
@ -234,11 +230,10 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
# Create mock message with empty triples
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify no triples were inserted
|
||||
mock_tg_instance.insert.assert_not_called()
|
||||
|
|
@ -255,12 +250,11 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify sleep was called before re-raising
|
||||
mock_sleep.assert_called_once_with(1)
|
||||
|
|
@ -361,21 +355,19 @@ class TestCassandraStorageProcessor:
|
|||
|
||||
# First message with table1
|
||||
mock_message1 = MagicMock()
|
||||
mock_message1.metadata.user = 'user1'
|
||||
mock_message1.metadata.collection = 'collection1'
|
||||
mock_message1.triples = []
|
||||
|
||||
await processor.store_triples(mock_message1)
|
||||
await processor.store_triples('user1', mock_message1)
|
||||
assert processor.table == 'user1'
|
||||
assert processor.tg == mock_tg_instance1
|
||||
|
||||
# Second message with different table
|
||||
mock_message2 = MagicMock()
|
||||
mock_message2.metadata.user = 'user2'
|
||||
mock_message2.metadata.collection = 'collection2'
|
||||
mock_message2.triples = []
|
||||
|
||||
await processor.store_triples(mock_message2)
|
||||
await processor.store_triples('user2', mock_message2)
|
||||
assert processor.table == 'user2'
|
||||
assert processor.tg == mock_tg_instance2
|
||||
|
||||
|
|
@ -407,11 +399,10 @@ class TestCassandraStorageProcessor:
|
|||
triple.g = None
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
mock_message.triples = [triple]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('test_workspace', mock_message)
|
||||
|
||||
# Verify the triple was inserted with special characters preserved
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
|
|
@ -440,12 +431,11 @@ class TestCassandraStorageProcessor:
|
|||
mock_kg_class.side_effect = Exception("Connection failed")
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'new_user'
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('new_user', mock_message)
|
||||
|
||||
# Table should remain unchanged since self.table = table happens after try/except
|
||||
assert processor.table == ('old_user', 'old_collection')
|
||||
|
|
@ -468,11 +458,10 @@ class TestCassandraPerformanceOptimizations:
|
|||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify KnowledgeGraph instance uses legacy mode
|
||||
assert mock_tg_instance is not None
|
||||
|
|
@ -489,11 +478,10 @@ class TestCassandraPerformanceOptimizations:
|
|||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify KnowledgeGraph instance is in optimized mode
|
||||
assert mock_tg_instance is not None
|
||||
|
|
@ -523,11 +511,10 @@ class TestCassandraPerformanceOptimizations:
|
|||
triple.g = None
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = [triple]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples('user1', mock_message)
|
||||
|
||||
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ class TestFalkorDBStorageProcessor:
|
|||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create a test triple
|
||||
|
|
@ -89,13 +88,13 @@ class TestFalkorDBStorageProcessor:
|
|||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.create_node(test_uri, 'test_user', 'test_collection')
|
||||
processor.create_node(test_uri, 'test_workspace', 'test_collection')
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
params={
|
||||
"uri": test_uri,
|
||||
"user": 'test_user',
|
||||
"workspace": 'test_workspace',
|
||||
"collection": 'test_collection',
|
||||
},
|
||||
)
|
||||
|
|
@ -109,13 +108,13 @@ class TestFalkorDBStorageProcessor:
|
|||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.create_literal(test_value, 'test_user', 'test_collection')
|
||||
processor.create_literal(test_value, 'test_workspace', 'test_collection')
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
params={
|
||||
"value": test_value,
|
||||
"user": 'test_user',
|
||||
"workspace": 'test_workspace',
|
||||
"collection": 'test_collection',
|
||||
},
|
||||
)
|
||||
|
|
@ -132,17 +131,17 @@ class TestFalkorDBStorageProcessor:
|
|||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.relate_node(src_uri, pred_uri, dest_uri, 'test_user', 'test_collection')
|
||||
processor.relate_node(src_uri, pred_uri, dest_uri, 'test_workspace', 'test_collection')
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
params={
|
||||
"src": src_uri,
|
||||
"dest": dest_uri,
|
||||
"uri": pred_uri,
|
||||
"user": 'test_user',
|
||||
"workspace": 'test_workspace',
|
||||
"collection": 'test_collection',
|
||||
},
|
||||
)
|
||||
|
|
@ -159,17 +158,17 @@ class TestFalkorDBStorageProcessor:
|
|||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.relate_literal(src_uri, pred_uri, literal_value, 'test_user', 'test_collection')
|
||||
processor.relate_literal(src_uri, pred_uri, literal_value, 'test_workspace', 'test_collection')
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
params={
|
||||
"src": src_uri,
|
||||
"dest": literal_value,
|
||||
"uri": pred_uri,
|
||||
"user": 'test_user',
|
||||
"workspace": 'test_workspace',
|
||||
"collection": 'test_collection',
|
||||
},
|
||||
)
|
||||
|
|
@ -179,7 +178,6 @@ class TestFalkorDBStorageProcessor:
|
|||
"""Test storing triple with URI object"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple = Triple(
|
||||
|
|
@ -200,21 +198,21 @@ class TestFalkorDBStorageProcessor:
|
|||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
await processor.store_triples('test_workspace', message)
|
||||
|
||||
# Verify queries were called in the correct order
|
||||
expected_calls = [
|
||||
# Create subject node
|
||||
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
|
||||
{"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}),
|
||||
(("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
|
||||
{"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection"}}),
|
||||
# Create object node
|
||||
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
|
||||
{"params": {"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection"}}),
|
||||
(("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
|
||||
{"params": {"uri": "http://example.com/object", "workspace": "test_workspace", "collection": "test_collection"}}),
|
||||
# Create relationship
|
||||
(("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",),
|
||||
{"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}),
|
||||
(("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",),
|
||||
{"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}),
|
||||
]
|
||||
|
||||
assert processor.io.query.call_count == 3
|
||||
|
|
@ -237,21 +235,21 @@ class TestFalkorDBStorageProcessor:
|
|||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
await processor.store_triples('test_workspace', mock_message)
|
||||
|
||||
# Verify queries were called in the correct order
|
||||
expected_calls = [
|
||||
# Create subject node
|
||||
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
|
||||
{"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}),
|
||||
(("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
|
||||
{"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection"}}),
|
||||
# Create literal object
|
||||
(("MERGE (n:Literal {value: $value, user: $user, collection: $collection})",),
|
||||
{"params": {"value": "literal object", "user": "test_user", "collection": "test_collection"}}),
|
||||
(("MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",),
|
||||
{"params": {"value": "literal object", "workspace": "test_workspace", "collection": "test_collection"}}),
|
||||
# Create relationship
|
||||
(("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",),
|
||||
{"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}),
|
||||
(("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",),
|
||||
{"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}),
|
||||
]
|
||||
|
||||
assert processor.io.query.call_count == 3
|
||||
|
|
@ -265,7 +263,6 @@ class TestFalkorDBStorageProcessor:
|
|||
"""Test storing multiple triples"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple1 = Triple(
|
||||
|
|
@ -291,7 +288,7 @@ class TestFalkorDBStorageProcessor:
|
|||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
await processor.store_triples('test_workspace', message)
|
||||
|
||||
# Verify total number of queries (3 per triple)
|
||||
assert processor.io.query.call_count == 6
|
||||
|
|
@ -313,7 +310,6 @@ class TestFalkorDBStorageProcessor:
|
|||
"""Test storing empty triples list"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.triples = []
|
||||
|
||||
|
|
@ -323,7 +319,7 @@ class TestFalkorDBStorageProcessor:
|
|||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
await processor.store_triples('test_workspace', message)
|
||||
|
||||
# Verify no queries were made
|
||||
processor.io.query.assert_not_called()
|
||||
|
|
@ -333,7 +329,6 @@ class TestFalkorDBStorageProcessor:
|
|||
"""Test storing triples with mixed URI and literal objects"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple1 = Triple(
|
||||
|
|
@ -359,7 +354,7 @@ class TestFalkorDBStorageProcessor:
|
|||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
await processor.store_triples('test_workspace', message)
|
||||
|
||||
# Verify total number of queries (3 per triple)
|
||||
assert processor.io.query.call_count == 6
|
||||
|
|
@ -450,13 +445,13 @@ class TestFalkorDBStorageProcessor:
|
|||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.create_node(test_uri, 'test_user', 'test_collection')
|
||||
processor.create_node(test_uri, 'test_workspace', 'test_collection')
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
params={
|
||||
"uri": test_uri,
|
||||
"user": 'test_user',
|
||||
"workspace": 'test_workspace',
|
||||
"collection": 'test_collection',
|
||||
},
|
||||
)
|
||||
|
|
@ -470,13 +465,13 @@ class TestFalkorDBStorageProcessor:
|
|||
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
processor.create_literal(test_value, 'test_user', 'test_collection')
|
||||
processor.create_literal(test_value, 'test_workspace', 'test_collection')
|
||||
|
||||
processor.io.query.assert_called_once_with(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
params={
|
||||
"value": test_value,
|
||||
"user": 'test_user',
|
||||
"workspace": 'test_workspace',
|
||||
"collection": 'test_collection',
|
||||
},
|
||||
)
|
||||
|
|
@ -17,7 +17,6 @@ class TestMemgraphStorageProcessor:
|
|||
"""Create a mock message for testing"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create a test triple
|
||||
|
|
@ -43,7 +42,7 @@ class TestMemgraphStorageProcessor:
|
|||
taskgroup=MagicMock(),
|
||||
id='test-memgraph-storage',
|
||||
graph_host='bolt://localhost:7687',
|
||||
username='test_user',
|
||||
username='test_workspace',
|
||||
password='test_pass',
|
||||
database='test_db'
|
||||
)
|
||||
|
|
@ -105,9 +104,9 @@ class TestMemgraphStorageProcessor:
|
|||
"CREATE INDEX ON :Node(uri)",
|
||||
"CREATE INDEX ON :Literal",
|
||||
"CREATE INDEX ON :Literal(value)",
|
||||
"CREATE INDEX ON :Node(user)",
|
||||
"CREATE INDEX ON :Node(workspace)",
|
||||
"CREATE INDEX ON :Node(collection)",
|
||||
"CREATE INDEX ON :Literal(user)",
|
||||
"CREATE INDEX ON :Literal(workspace)",
|
||||
"CREATE INDEX ON :Literal(collection)"
|
||||
]
|
||||
|
||||
|
|
@ -145,12 +144,12 @@ class TestMemgraphStorageProcessor:
|
|||
|
||||
processor.io.execute_query.return_value = mock_result
|
||||
|
||||
processor.create_node(test_uri, "test_user", "test_collection")
|
||||
processor.create_node(test_uri, "test_workspace", "test_collection")
|
||||
|
||||
processor.io.execute_query.assert_called_once_with(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
uri=test_uri,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_=processor.db
|
||||
)
|
||||
|
|
@ -166,12 +165,12 @@ class TestMemgraphStorageProcessor:
|
|||
|
||||
processor.io.execute_query.return_value = mock_result
|
||||
|
||||
processor.create_literal(test_value, "test_user", "test_collection")
|
||||
processor.create_literal(test_value, "test_workspace", "test_collection")
|
||||
|
||||
processor.io.execute_query.assert_called_once_with(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
value=test_value,
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_=processor.db
|
||||
)
|
||||
|
|
@ -190,14 +189,14 @@ class TestMemgraphStorageProcessor:
|
|||
|
||||
processor.io.execute_query.return_value = mock_result
|
||||
|
||||
processor.relate_node(src_uri, pred_uri, dest_uri, "test_user", "test_collection")
|
||||
processor.relate_node(src_uri, pred_uri, dest_uri, "test_workspace", "test_collection")
|
||||
|
||||
processor.io.execute_query.assert_called_once_with(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
src=src_uri, dest=dest_uri, uri=pred_uri,
|
||||
user="test_user", collection="test_collection",
|
||||
workspace="test_workspace", collection="test_collection",
|
||||
database_=processor.db
|
||||
)
|
||||
|
||||
|
|
@ -215,14 +214,14 @@ class TestMemgraphStorageProcessor:
|
|||
|
||||
processor.io.execute_query.return_value = mock_result
|
||||
|
||||
processor.relate_literal(src_uri, pred_uri, literal_value, "test_user", "test_collection")
|
||||
processor.relate_literal(src_uri, pred_uri, literal_value, "test_workspace", "test_collection")
|
||||
|
||||
processor.io.execute_query.assert_called_once_with(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
src=src_uri, dest=literal_value, uri=pred_uri,
|
||||
user="test_user", collection="test_collection",
|
||||
workspace="test_workspace", collection="test_collection",
|
||||
database_=processor.db
|
||||
)
|
||||
|
||||
|
|
@ -236,22 +235,22 @@ class TestMemgraphStorageProcessor:
|
|||
o=Term(type=IRI, iri='http://example.com/object')
|
||||
)
|
||||
|
||||
processor.create_triple(mock_tx, triple, "test_user", "test_collection")
|
||||
processor.create_triple(mock_tx, triple, "test_workspace", "test_collection")
|
||||
|
||||
# Verify transaction calls
|
||||
expected_calls = [
|
||||
# Create subject node
|
||||
("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
{'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}),
|
||||
("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
{'uri': 'http://example.com/subject', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
|
||||
# Create object node
|
||||
("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
{'uri': 'http://example.com/object', 'user': 'test_user', 'collection': 'test_collection'}),
|
||||
("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
{'uri': 'http://example.com/object', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
|
||||
# Create relationship
|
||||
("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
{'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate',
|
||||
'user': 'test_user', 'collection': 'test_collection'})
|
||||
'workspace': 'test_workspace', 'collection': 'test_collection'})
|
||||
]
|
||||
|
||||
assert mock_tx.run.call_count == 3
|
||||
|
|
@ -270,22 +269,22 @@ class TestMemgraphStorageProcessor:
|
|||
o=Term(type=LITERAL, value='literal object')
|
||||
)
|
||||
|
||||
processor.create_triple(mock_tx, triple, "test_user", "test_collection")
|
||||
processor.create_triple(mock_tx, triple, "test_workspace", "test_collection")
|
||||
|
||||
# Verify transaction calls
|
||||
expected_calls = [
|
||||
# Create subject node
|
||||
("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
{'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}),
|
||||
("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
{'uri': 'http://example.com/subject', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
|
||||
# Create literal object
|
||||
("MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
{'value': 'literal object', 'user': 'test_user', 'collection': 'test_collection'}),
|
||||
("MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
{'value': 'literal object', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
|
||||
# Create relationship
|
||||
("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
{'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate',
|
||||
'user': 'test_user', 'collection': 'test_collection'})
|
||||
'workspace': 'test_workspace', 'collection': 'test_collection'})
|
||||
]
|
||||
|
||||
assert mock_tx.run.call_count == 3
|
||||
|
|
@ -314,8 +313,8 @@ class TestMemgraphStorageProcessor:
|
|||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
await processor.store_triples('test_workspace', mock_message)
|
||||
|
||||
# Verify execute_query was called for create_node, create_literal, and relate_literal
|
||||
# (since mock_message has a literal object)
|
||||
assert processor.io.execute_query.call_count == 3
|
||||
|
|
@ -323,7 +322,7 @@ class TestMemgraphStorageProcessor:
|
|||
# Verify user/collection parameters were included
|
||||
for call in processor.io.execute_query.call_args_list:
|
||||
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
|
||||
assert 'user' in call_kwargs
|
||||
assert 'workspace' in call_kwargs
|
||||
assert 'collection' in call_kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -343,7 +342,6 @@ class TestMemgraphStorageProcessor:
|
|||
# Create message with multiple triples
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple1 = Triple(
|
||||
|
|
@ -364,7 +362,7 @@ class TestMemgraphStorageProcessor:
|
|||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
await processor.store_triples('test_workspace', message)
|
||||
|
||||
# Verify execute_query was called:
|
||||
# Triple1: create_node(s) + create_literal(o) + relate_literal = 3 calls
|
||||
|
|
@ -375,7 +373,7 @@ class TestMemgraphStorageProcessor:
|
|||
# Verify user/collection parameters were included in all calls
|
||||
for call in processor.io.execute_query.call_args_list:
|
||||
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
|
||||
assert call_kwargs['user'] == 'test_user'
|
||||
assert call_kwargs['workspace'] == 'test_workspace'
|
||||
assert call_kwargs['collection'] == 'test_collection'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -389,7 +387,6 @@ class TestMemgraphStorageProcessor:
|
|||
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
message.triples = []
|
||||
|
||||
|
|
@ -399,7 +396,7 @@ class TestMemgraphStorageProcessor:
|
|||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
await processor.store_triples('test_workspace', message)
|
||||
|
||||
# Verify no session calls were made (no triples to process)
|
||||
processor.io.session.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -68,9 +68,9 @@ class TestNeo4jStorageProcessor:
|
|||
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
|
||||
"CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
|
||||
"CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)",
|
||||
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)",
|
||||
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)",
|
||||
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)",
|
||||
"CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)",
|
||||
"CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)",
|
||||
"CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)",
|
||||
"CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)"
|
||||
]
|
||||
|
||||
|
|
@ -116,12 +116,12 @@ class TestNeo4jStorageProcessor:
|
|||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Test create_node
|
||||
processor.create_node("http://example.com/node", "test_user", "test_collection")
|
||||
processor.create_node("http://example.com/node", "test_workspace", "test_collection")
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
uri="http://example.com/node",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
|
@ -146,12 +146,12 @@ class TestNeo4jStorageProcessor:
|
|||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Test create_literal
|
||||
processor.create_literal("literal value", "test_user", "test_collection")
|
||||
processor.create_literal("literal value", "test_workspace", "test_collection")
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
value="literal value",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
|
@ -180,18 +180,18 @@ class TestNeo4jStorageProcessor:
|
|||
"http://example.com/subject",
|
||||
"http://example.com/predicate",
|
||||
"http://example.com/object",
|
||||
"test_user",
|
||||
"test_workspace",
|
||||
"test_collection"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
src="http://example.com/subject",
|
||||
dest="http://example.com/object",
|
||||
uri="http://example.com/predicate",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
|
@ -220,18 +220,18 @@ class TestNeo4jStorageProcessor:
|
|||
"http://example.com/subject",
|
||||
"http://example.com/predicate",
|
||||
"literal value",
|
||||
"test_user",
|
||||
"test_workspace",
|
||||
"test_collection"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_called_with(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
src="http://example.com/subject",
|
||||
dest="literal value",
|
||||
uri="http://example.com/predicate",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
|
@ -268,36 +268,35 @@ class TestNeo4jStorageProcessor:
|
|||
# Create mock message with metadata
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples("test_workspace", mock_message)
|
||||
|
||||
# Verify create_node was called for subject and object
|
||||
# Verify relate_node was called
|
||||
expected_calls = [
|
||||
# Subject node creation
|
||||
(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
{"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
{"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
|
||||
),
|
||||
# Object node creation
|
||||
(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
{"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
{"uri": "http://example.com/object", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
|
||||
),
|
||||
# Relationship creation
|
||||
(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
{
|
||||
"src": "http://example.com/subject",
|
||||
"dest": "http://example.com/object",
|
||||
"uri": "http://example.com/predicate",
|
||||
"user": "test_user",
|
||||
"workspace": "test_workspace",
|
||||
"collection": "test_collection",
|
||||
"database_": "neo4j"
|
||||
}
|
||||
|
|
@ -340,12 +339,11 @@ class TestNeo4jStorageProcessor:
|
|||
# Create mock message with metadata
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples("test_workspace", mock_message)
|
||||
|
||||
# Verify create_node was called for subject
|
||||
# Verify create_literal was called for object
|
||||
|
|
@ -353,24 +351,24 @@ class TestNeo4jStorageProcessor:
|
|||
expected_calls = [
|
||||
# Subject node creation
|
||||
(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
{"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
{"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
|
||||
),
|
||||
# Literal creation
|
||||
(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
{"value": "literal value", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
{"value": "literal value", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
|
||||
),
|
||||
# Relationship creation
|
||||
(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
{
|
||||
"src": "http://example.com/subject",
|
||||
"dest": "literal value",
|
||||
"uri": "http://example.com/predicate",
|
||||
"user": "test_user",
|
||||
"workspace": "test_workspace",
|
||||
"collection": "test_collection",
|
||||
"database_": "neo4j"
|
||||
}
|
||||
|
|
@ -421,12 +419,11 @@ class TestNeo4jStorageProcessor:
|
|||
# Create mock message with metadata
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple1, triple2]
|
||||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples("test_workspace", mock_message)
|
||||
|
||||
# Should have processed both triples
|
||||
# Triple1: 2 nodes + 1 relationship = 3 calls
|
||||
|
|
@ -449,12 +446,11 @@ class TestNeo4jStorageProcessor:
|
|||
# Create mock message with empty triples and metadata
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = []
|
||||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples("test_workspace", mock_message)
|
||||
|
||||
# Should not have made any execute_query calls beyond index creation
|
||||
# Only index creation calls should have been made during initialization
|
||||
|
|
@ -568,38 +564,37 @@ class TestNeo4jStorageProcessor:
|
|||
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
await processor.store_triples("test_workspace", mock_message)
|
||||
|
||||
# Verify the triple was processed with special characters preserved
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
|
||||
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
|
||||
uri="http://example.com/subject with spaces",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
|
||||
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
|
||||
value='literal with "quotes" and unicode: ñáéíóú',
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
|
||||
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
|
||||
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
|
||||
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
|
||||
src="http://example.com/subject with spaces",
|
||||
dest='literal with "quotes" and unicode: ñáéíóú',
|
||||
uri="http://example.com/predicate:with/symbols",
|
||||
user="test_user",
|
||||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
database_="neo4j"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -24,11 +24,10 @@ def _make_processor(qdrant_client=None):
|
|||
return proc
|
||||
|
||||
|
||||
def _make_request(vector=None, user="test-user", collection="test-col",
|
||||
def _make_request(vector=None, collection="test-col",
|
||||
schema_name="customers", limit=10, index_name=None):
|
||||
return RowEmbeddingsRequest(
|
||||
vector=vector or [0.1, 0.2, 0.3],
|
||||
user=user,
|
||||
collection=collection,
|
||||
schema_name=schema_name,
|
||||
limit=limit,
|
||||
|
|
@ -36,6 +35,14 @@ def _make_request(vector=None, user="test-user", collection="test-col",
|
|||
)
|
||||
|
||||
|
||||
def _make_flow(workspace="test-workspace", pub=None):
|
||||
"""Make a mock flow object that is callable and has .workspace."""
|
||||
flow = MagicMock()
|
||||
flow.return_value = pub if pub is not None else AsyncMock()
|
||||
flow.workspace = workspace
|
||||
return flow
|
||||
|
||||
|
||||
def _make_search_point(index_name, index_value, text, score):
|
||||
point = MagicMock()
|
||||
point.payload = {
|
||||
|
|
@ -85,34 +92,33 @@ class TestFindCollection:
|
|||
def test_finds_matching_collection(self):
|
||||
proc = _make_processor()
|
||||
mock_coll = MagicMock()
|
||||
mock_coll.name = "rows_test_user_test_col_customers_384"
|
||||
mock_coll.name = "rows_test_workspace_test_col_customers_384"
|
||||
|
||||
mock_collections = MagicMock()
|
||||
mock_collections.collections = [mock_coll]
|
||||
proc.qdrant.get_collections.return_value = mock_collections
|
||||
|
||||
result = proc.find_collection("test-user", "test-col", "customers")
|
||||
result = proc.find_collection("test-workspace", "test-col", "customers")
|
||||
|
||||
# Prefix: rows_test_user_test_col_customers_
|
||||
assert result == "rows_test_user_test_col_customers_384"
|
||||
assert result == "rows_test_workspace_test_col_customers_384"
|
||||
|
||||
def test_returns_none_when_no_match(self):
|
||||
proc = _make_processor()
|
||||
mock_coll = MagicMock()
|
||||
mock_coll.name = "rows_other_user_other_col_schema_768"
|
||||
mock_coll.name = "rows_other_workspace_other_col_schema_768"
|
||||
|
||||
mock_collections = MagicMock()
|
||||
mock_collections.collections = [mock_coll]
|
||||
proc.qdrant.get_collections.return_value = mock_collections
|
||||
|
||||
result = proc.find_collection("test-user", "test-col", "customers")
|
||||
result = proc.find_collection("test-workspace", "test-col", "customers")
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_on_error(self):
|
||||
proc = _make_processor()
|
||||
proc.qdrant.get_collections.side_effect = Exception("connection error")
|
||||
|
||||
result = proc.find_collection("user", "col", "schema")
|
||||
result = proc.find_collection("workspace", "col", "schema")
|
||||
assert result is None
|
||||
|
||||
|
||||
|
|
@ -127,7 +133,7 @@ class TestQueryRowEmbeddings:
|
|||
proc = _make_processor()
|
||||
request = _make_request(vector=[])
|
||||
|
||||
result = await proc.query_row_embeddings(request)
|
||||
result = await proc.query_row_embeddings("test-workspace", request)
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -136,13 +142,13 @@ class TestQueryRowEmbeddings:
|
|||
proc.find_collection = MagicMock(return_value=None)
|
||||
request = _make_request()
|
||||
|
||||
result = await proc.query_row_embeddings(request)
|
||||
result = await proc.query_row_embeddings("test-workspace", request)
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_query_returns_matches(self):
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
|
||||
points = [
|
||||
_make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
|
||||
|
|
@ -153,7 +159,7 @@ class TestQueryRowEmbeddings:
|
|||
proc.qdrant.query_points.return_value = mock_result
|
||||
|
||||
request = _make_request()
|
||||
result = await proc.query_row_embeddings(request)
|
||||
result = await proc.query_row_embeddings("test-workspace", request)
|
||||
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], RowIndexMatch)
|
||||
|
|
@ -166,14 +172,14 @@ class TestQueryRowEmbeddings:
|
|||
async def test_index_name_filter_applied(self):
|
||||
"""When index_name is specified, a Qdrant filter should be used."""
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.points = []
|
||||
proc.qdrant.query_points.return_value = mock_result
|
||||
|
||||
request = _make_request(index_name="address")
|
||||
await proc.query_row_embeddings(request)
|
||||
await proc.query_row_embeddings("test-workspace", request)
|
||||
|
||||
call_kwargs = proc.qdrant.query_points.call_args[1]
|
||||
assert call_kwargs["query_filter"] is not None
|
||||
|
|
@ -182,14 +188,14 @@ class TestQueryRowEmbeddings:
|
|||
async def test_no_index_name_no_filter(self):
|
||||
"""When index_name is empty, no filter should be applied."""
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.points = []
|
||||
proc.qdrant.query_points.return_value = mock_result
|
||||
|
||||
request = _make_request(index_name="")
|
||||
await proc.query_row_embeddings(request)
|
||||
await proc.query_row_embeddings("test-workspace", request)
|
||||
|
||||
call_kwargs = proc.qdrant.query_points.call_args[1]
|
||||
assert call_kwargs["query_filter"] is None
|
||||
|
|
@ -198,7 +204,7 @@ class TestQueryRowEmbeddings:
|
|||
async def test_missing_payload_fields_default(self):
|
||||
"""Points with missing payload fields should use defaults."""
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
|
||||
point = MagicMock()
|
||||
point.payload = {} # Empty payload
|
||||
|
|
@ -209,7 +215,7 @@ class TestQueryRowEmbeddings:
|
|||
proc.qdrant.query_points.return_value = mock_result
|
||||
|
||||
request = _make_request()
|
||||
result = await proc.query_row_embeddings(request)
|
||||
result = await proc.query_row_embeddings("test-workspace", request)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].index_name == ""
|
||||
|
|
@ -219,13 +225,13 @@ class TestQueryRowEmbeddings:
|
|||
@pytest.mark.asyncio
|
||||
async def test_qdrant_error_propagates(self):
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.qdrant.query_points.side_effect = Exception("qdrant down")
|
||||
|
||||
request = _make_request()
|
||||
|
||||
with pytest.raises(Exception, match="qdrant down"):
|
||||
await proc.query_row_embeddings(request)
|
||||
await proc.query_row_embeddings("test-workspace", request)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -243,7 +249,7 @@ class TestOnMessage:
|
|||
])
|
||||
|
||||
mock_pub = AsyncMock()
|
||||
flow = lambda name: mock_pub
|
||||
flow = _make_flow(pub=mock_pub)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = _make_request()
|
||||
|
|
@ -264,7 +270,7 @@ class TestOnMessage:
|
|||
)
|
||||
|
||||
mock_pub = AsyncMock()
|
||||
flow = lambda name: mock_pub
|
||||
flow = _make_flow(pub=mock_pub)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = _make_request()
|
||||
|
|
@ -284,7 +290,7 @@ class TestOnMessage:
|
|||
proc.query_row_embeddings = AsyncMock(return_value=[])
|
||||
|
||||
mock_pub = AsyncMock()
|
||||
flow = lambda name: mock_pub
|
||||
flow = _make_flow(pub=mock_pub)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = _make_request()
|
||||
|
|
|
|||
|
|
@ -45,12 +45,9 @@ class TestGetGraphEmbeddings:
|
|||
with `vector=` (singular) — the schema field name. A previous
|
||||
version used `vectors=` and TypeError'd at runtime.
|
||||
"""
|
||||
# Arrange — fake row matching the get_triples_stmt result shape:
|
||||
# row[0..2] are unused by the method, row[3] is the entities blob
|
||||
fake_row = (
|
||||
None, None, None,
|
||||
[
|
||||
# ((value, is_uri), vector)
|
||||
(("http://example.org/alice", True), [0.1, 0.2, 0.3]),
|
||||
(("http://example.org/bob", True), [0.4, 0.5, 0.6]),
|
||||
(("a literal entity", False), [0.7, 0.8, 0.9]),
|
||||
|
|
@ -67,14 +64,8 @@ class TestGetGraphEmbeddings:
|
|||
async def receiver(msg):
|
||||
received.append(msg)
|
||||
|
||||
# Act
|
||||
await store.get_graph_embeddings(
|
||||
user="alice",
|
||||
document_id="doc-1",
|
||||
receiver=receiver,
|
||||
)
|
||||
await store.get_graph_embeddings("alice", "doc-1", receiver)
|
||||
|
||||
# Assert
|
||||
mock_async_execute.assert_called_once_with(
|
||||
store.cassandra,
|
||||
store.get_graph_embeddings_stmt,
|
||||
|
|
@ -86,7 +77,6 @@ class TestGetGraphEmbeddings:
|
|||
assert isinstance(ge, GraphEmbeddings)
|
||||
assert isinstance(ge.metadata, Metadata)
|
||||
assert ge.metadata.id == "doc-1"
|
||||
assert ge.metadata.user == "alice"
|
||||
|
||||
assert len(ge.entities) == 3
|
||||
assert all(isinstance(e, EntityEmbeddings) for e in ge.entities)
|
||||
|
|
@ -122,7 +112,7 @@ class TestGetGraphEmbeddings:
|
|||
async def receiver(msg):
|
||||
received.append(msg)
|
||||
|
||||
await store.get_graph_embeddings("u", "d", receiver)
|
||||
await store.get_graph_embeddings("w", "d", receiver)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].entities == []
|
||||
|
|
@ -149,7 +139,7 @@ class TestGetGraphEmbeddings:
|
|||
async def receiver(msg):
|
||||
received.append(msg)
|
||||
|
||||
await store.get_graph_embeddings("u", "d", receiver)
|
||||
await store.get_graph_embeddings("w", "d", receiver)
|
||||
|
||||
assert len(received) == 2
|
||||
assert received[0].entities[0].entity.iri == "http://example.org/a"
|
||||
|
|
@ -194,7 +184,6 @@ class TestGetTriples:
|
|||
assert isinstance(triples_msg, Triples)
|
||||
assert isinstance(triples_msg.metadata, Metadata)
|
||||
assert triples_msg.metadata.id == "doc-1"
|
||||
assert triples_msg.metadata.user == "alice"
|
||||
|
||||
assert len(triples_msg.triples) == 1
|
||||
t = triples_msg.triples[0]
|
||||
|
|
|
|||
|
|
@ -30,7 +30,6 @@ def sample():
|
|||
metadata=Metadata(
|
||||
id="doc-1",
|
||||
root="",
|
||||
user="alice",
|
||||
collection="testcoll",
|
||||
),
|
||||
chunks=[
|
||||
|
|
@ -56,7 +55,6 @@ class TestDocumentEmbeddingsTranslator:
|
|||
assert isinstance(decoded, DocumentEmbeddings)
|
||||
assert isinstance(decoded.metadata, Metadata)
|
||||
assert decoded.metadata.id == "doc-1"
|
||||
assert decoded.metadata.user == "alice"
|
||||
assert decoded.metadata.collection == "testcoll"
|
||||
|
||||
assert len(decoded.chunks) == 2
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ def translator():
|
|||
def graph_embeddings_request():
|
||||
return KnowledgeRequest(
|
||||
operation="put-kg-core",
|
||||
user="alice",
|
||||
workspace="alice",
|
||||
id="doc-1",
|
||||
flow="default",
|
||||
collection="testcoll",
|
||||
|
|
@ -49,7 +49,6 @@ def graph_embeddings_request():
|
|||
metadata=Metadata(
|
||||
id="doc-1",
|
||||
root="",
|
||||
user="alice",
|
||||
collection="testcoll",
|
||||
),
|
||||
entities=[
|
||||
|
|
@ -70,7 +69,6 @@ def graph_embeddings_request():
|
|||
def triples_request():
|
||||
return KnowledgeRequest(
|
||||
operation="put-kg-core",
|
||||
user="alice",
|
||||
id="doc-1",
|
||||
flow="default",
|
||||
collection="testcoll",
|
||||
|
|
@ -78,7 +76,6 @@ def triples_request():
|
|||
metadata=Metadata(
|
||||
id="doc-1",
|
||||
root="",
|
||||
user="alice",
|
||||
collection="testcoll",
|
||||
),
|
||||
triples=[
|
||||
|
|
@ -113,7 +110,7 @@ class TestKnowledgeRequestTranslatorGraphEmbeddings:
|
|||
|
||||
assert isinstance(decoded, KnowledgeRequest)
|
||||
assert decoded.operation == "put-kg-core"
|
||||
assert decoded.user == "alice"
|
||||
assert decoded.workspace == "alice"
|
||||
assert decoded.id == "doc-1"
|
||||
assert decoded.flow == "default"
|
||||
assert decoded.collection == "testcoll"
|
||||
|
|
@ -123,7 +120,6 @@ class TestKnowledgeRequestTranslatorGraphEmbeddings:
|
|||
assert isinstance(ge, GraphEmbeddings)
|
||||
assert isinstance(ge.metadata, Metadata)
|
||||
assert ge.metadata.id == "doc-1"
|
||||
assert ge.metadata.user == "alice"
|
||||
assert ge.metadata.collection == "testcoll"
|
||||
|
||||
assert len(ge.entities) == 2
|
||||
|
|
@ -143,7 +139,6 @@ class TestKnowledgeRequestTranslatorTriples:
|
|||
assert decoded.triples is not None
|
||||
assert isinstance(decoded.triples.metadata, Metadata)
|
||||
assert decoded.triples.metadata.id == "doc-1"
|
||||
assert decoded.triples.metadata.user == "alice"
|
||||
assert decoded.triples.metadata.collection == "testcoll"
|
||||
|
||||
assert len(decoded.triples.triples) == 1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue