Merge remote-tracking branch 'origin/master' into ts-port

This commit is contained in:
elpresidank 2026-04-26 20:07:57 -05:00
commit f8252ecd54
1038 changed files with 253274 additions and 8466 deletions

View file

@ -37,6 +37,9 @@ class TestAgentServiceNonStreaming:
# Setup mock agent manager
mock_agent_instance = AsyncMock()
mock_agent_manager_class.return_value = mock_agent_instance
mock_agent_instance.tools = {}
mock_agent_instance.additional_context = ""
processor.agents["default"] = mock_agent_instance
# Mock react to call think and observe callbacks
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
@ -50,7 +53,6 @@ class TestAgentServiceNonStreaming:
msg = MagicMock()
msg.value.return_value = AgentRequest(
question="What is 2 + 2?",
user="trustgraph",
streaming=False # Non-streaming mode
)
msg.properties.return_value = {"id": "test-id"}
@ -58,6 +60,7 @@ class TestAgentServiceNonStreaming:
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
mock_producer = AsyncMock()
@ -78,10 +81,10 @@ class TestAgentServiceNonStreaming:
# Filter out explain events — those are always sent now
content_responses = [
r for r in sent_responses if r.chunk_type != "explain"
r for r in sent_responses if r.message_type != "explain"
]
explain_responses = [
r for r in sent_responses if r.chunk_type == "explain"
r for r in sent_responses if r.message_type == "explain"
]
# Should have explain events for session, iteration, observation, and final
@ -93,7 +96,7 @@ class TestAgentServiceNonStreaming:
# Check thought message
thought_response = content_responses[0]
assert isinstance(thought_response, AgentResponse)
assert thought_response.chunk_type == "thought"
assert thought_response.message_type == "thought"
assert thought_response.content == "I need to solve this."
assert thought_response.end_of_message is True, "Thought message must have end_of_message=True"
assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False"
@ -101,7 +104,7 @@ class TestAgentServiceNonStreaming:
# Check observation message
observation_response = content_responses[1]
assert isinstance(observation_response, AgentResponse)
assert observation_response.chunk_type == "observation"
assert observation_response.message_type == "observation"
assert observation_response.content == "The answer is 4."
assert observation_response.end_of_message is True, "Observation message must have end_of_message=True"
assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False"
@ -129,6 +132,9 @@ class TestAgentServiceNonStreaming:
# Setup mock agent manager
mock_agent_instance = AsyncMock()
mock_agent_manager_class.return_value = mock_agent_instance
mock_agent_instance.tools = {}
mock_agent_instance.additional_context = ""
processor.agents["default"] = mock_agent_instance
# Mock react to return Final directly
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
@ -140,7 +146,6 @@ class TestAgentServiceNonStreaming:
msg = MagicMock()
msg.value.return_value = AgentRequest(
question="What is 2 + 2?",
user="trustgraph",
streaming=False # Non-streaming mode
)
msg.properties.return_value = {"id": "test-id"}
@ -148,6 +153,7 @@ class TestAgentServiceNonStreaming:
# Setup flow mock
consumer = MagicMock()
flow = MagicMock()
flow.workspace = "default"
mock_producer = AsyncMock()
@ -168,10 +174,10 @@ class TestAgentServiceNonStreaming:
# Filter out explain events — those are always sent now
content_responses = [
r for r in sent_responses if r.chunk_type != "explain"
r for r in sent_responses if r.message_type != "explain"
]
explain_responses = [
r for r in sent_responses if r.chunk_type == "explain"
r for r in sent_responses if r.message_type == "explain"
]
# Should have explain events for session and final
@ -183,7 +189,7 @@ class TestAgentServiceNonStreaming:
# Check final answer message
answer_response = content_responses[0]
assert isinstance(answer_response, AgentResponse)
assert answer_response.chunk_type == "answer"
assert answer_response.message_type == "answer"
assert answer_response.content == "4"
assert answer_response.end_of_message is True, "Final answer must have end_of_message=True"
assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True"

View file

@ -11,13 +11,12 @@ from trustgraph.schema import AgentRequest, AgentStep
from trustgraph.agent.orchestrator.aggregator import Aggregator
def _make_request(question="Test question", user="testuser",
def _make_request(question="Test question",
collection="default", streaming=False,
session_id="parent-session", task_type="research",
framing="test framing", conversation_id="conv-1"):
return AgentRequest(
question=question,
user=user,
collection=collection,
streaming=streaming,
session_id=session_id,
@ -127,7 +126,6 @@ class TestBuildSynthesisRequest:
req = agg.build_synthesis_request(
"corr-1",
original_question="Original question",
user="testuser",
collection="default",
)
@ -148,7 +146,7 @@ class TestBuildSynthesisRequest:
agg.record_completion("corr-1", "goal-b", "answer-b")
req = agg.build_synthesis_request(
"corr-1", "question", "user", "default",
"corr-1", "question", "default",
)
# Last history step should be the synthesis step
@ -168,7 +166,7 @@ class TestBuildSynthesisRequest:
agg.record_completion("corr-1", "goal-a", "answer-a")
agg.build_synthesis_request(
"corr-1", "question", "user", "default",
"corr-1", "question", "default",
)
# Entry should be removed
@ -178,7 +176,7 @@ class TestBuildSynthesisRequest:
agg = Aggregator()
with pytest.raises(RuntimeError, match="No results"):
agg.build_synthesis_request(
"unknown", "question", "user", "default",
"unknown", "question", "default",
)

View file

@ -29,7 +29,7 @@ class TestThinkCallbackMessageId:
assert len(responses) == 1
assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "thought"
assert responses[0].message_type == "thought"
@pytest.mark.asyncio
async def test_non_streaming_think_has_message_id(self, pattern):
@ -58,7 +58,7 @@ class TestObserveCallbackMessageId:
await observe("result", is_final=True)
assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "observation"
assert responses[0].message_type == "observation"
class TestAnswerCallbackMessageId:
@ -74,7 +74,7 @@ class TestAnswerCallbackMessageId:
await answer("the answer")
assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "answer"
assert responses[0].message_type == "answer"
@pytest.mark.asyncio
async def test_no_message_id_default(self, pattern):

View file

@ -15,7 +15,6 @@ from trustgraph.agent.orchestrator.aggregator import Aggregator
def _make_request(**kwargs):
defaults = dict(
question="Test question",
user="testuser",
collection="default",
)
defaults.update(kwargs)
@ -130,7 +129,6 @@ class TestAggregatorIntegration:
synth = agg.build_synthesis_request(
"corr-1",
original_question="Original question",
user="testuser",
collection="default",
)
@ -160,7 +158,7 @@ class TestAggregatorIntegration:
agg.record_completion("corr-1", "goal", "answer")
synth = agg.build_synthesis_request(
"corr-1", "question", "user", "default",
"corr-1", "question", "default",
)
# correlation_id must be empty so it's not intercepted

View file

@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.orchestrator.meta_router import (
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
)
from trustgraph.base import PromptResult
def _make_config(patterns=None, task_types=None):
@ -28,7 +29,9 @@ def _make_config(patterns=None, task_types=None):
def _make_context(prompt_response):
"""Build a mock context that returns a mock prompt client."""
client = AsyncMock()
client.prompt = AsyncMock(return_value=prompt_response)
client.prompt = AsyncMock(
return_value=PromptResult(response_type="text", text=prompt_response)
)
def context(service_name):
return client
@ -274,8 +277,8 @@ class TestRoute:
nonlocal call_count
call_count += 1
if call_count == 1:
return "research" # task type
return "plan-then-execute" # pattern
return PromptResult(response_type="text", text="research")
return PromptResult(response_type="text", text="plan-then-execute")
client.prompt = mock_prompt
context = lambda name: client

View file

@ -18,6 +18,7 @@ from dataclasses import dataclass, field
from trustgraph.schema import (
AgentRequest, AgentResponse, AgentStep, PlanStep,
)
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -68,7 +69,7 @@ def collect_explain_events(respond_mock):
events = []
for call in respond_mock.call_args_list:
resp = call[0][0]
if isinstance(resp, AgentResponse) and resp.chunk_type == "explain":
if isinstance(resp, AgentResponse) and resp.message_type == "explain":
events.append({
"explain_id": resp.explain_id,
"explain_graph": resp.explain_graph,
@ -125,7 +126,6 @@ def make_base_request(**kwargs):
state="",
group=[],
history=[],
user="testuser",
collection="default",
streaming=False,
session_id="test-session-123",
@ -183,7 +183,7 @@ class TestReactPatternProvenance:
)
async def mock_react(question, history, think, observe, answer,
context, streaming, on_action):
context, streaming, on_action, **kwargs):
# Simulate the on_action callback before returning Final
if on_action:
await on_action(Action(
@ -267,7 +267,7 @@ class TestReactPatternProvenance:
MockAM.return_value = mock_am
async def mock_react(question, history, think, observe, answer,
context, streaming, on_action):
context, streaming, on_action, **kwargs):
if on_action:
await on_action(action)
return action
@ -309,7 +309,7 @@ class TestReactPatternProvenance:
MockAM.return_value = mock_am
async def mock_react(question, history, think, observe, answer,
context, streaming, on_action):
context, streaming, on_action, **kwargs):
if on_action:
await on_action(Action(
thought="done", name="final",
@ -355,10 +355,13 @@ class TestPlanPatternProvenance:
# Mock prompt client for plan creation
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = [
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
]
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=[
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
],
)
def flow_factory(name):
if name == "prompt-request":
@ -418,10 +421,13 @@ class TestPlanPatternProvenance:
# Mock prompt for step execution
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = {
"tool": "knowledge-query",
"arguments": {"question": "quantum computing"},
}
mock_prompt_client.prompt.return_value = PromptResult(
response_type="json",
object={
"tool": "knowledge-query",
"arguments": {"question": "quantum computing"},
},
)
def flow_factory(name):
if name == "prompt-request":
@ -475,7 +481,7 @@ class TestPlanPatternProvenance:
# Mock prompt for synthesis
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = "The synthesised answer."
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The synthesised answer.")
def flow_factory(name):
if name == "prompt-request":
@ -542,10 +548,13 @@ class TestSupervisorPatternProvenance:
# Mock prompt for decomposition
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = [
"What is quantum computing?",
"What are qubits?",
]
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=[
"What is quantum computing?",
"What are qubits?",
],
)
def flow_factory(name):
if name == "prompt-request":
@ -590,7 +599,7 @@ class TestSupervisorPatternProvenance:
# Mock prompt for synthesis
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = "The combined answer."
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The combined answer.")
def flow_factory(name):
if name == "prompt-request":
@ -639,7 +648,10 @@ class TestSupervisorPatternProvenance:
flow = make_mock_flow()
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=["Goal A", "Goal B", "Goal C"],
)
def flow_factory(name):
if name == "prompt-request":

View file

@ -20,7 +20,7 @@ class TestParseChunkMessageId:
def test_thought_message_id(self, client):
resp = {
"chunk_type": "thought",
"message_type": "thought",
"content": "thinking...",
"end_of_message": False,
"message_id": "urn:trustgraph:agent:sess/i1/thought",
@ -31,7 +31,7 @@ class TestParseChunkMessageId:
def test_observation_message_id(self, client):
resp = {
"chunk_type": "observation",
"message_type": "observation",
"content": "result",
"end_of_message": True,
"message_id": "urn:trustgraph:agent:sess/i1/observation",
@ -42,7 +42,7 @@ class TestParseChunkMessageId:
def test_answer_message_id(self, client):
resp = {
"chunk_type": "answer",
"message_type": "answer",
"content": "the answer",
"end_of_message": False,
"end_of_dialog": False,
@ -54,7 +54,7 @@ class TestParseChunkMessageId:
def test_thought_missing_message_id(self, client):
resp = {
"chunk_type": "thought",
"message_type": "thought",
"content": "thinking...",
"end_of_message": False,
}
@ -64,7 +64,7 @@ class TestParseChunkMessageId:
def test_answer_missing_message_id(self, client):
resp = {
"chunk_type": "answer",
"message_type": "answer",
"content": "answer",
"end_of_message": True,
"end_of_dialog": True,

View file

@ -21,7 +21,6 @@ class MockProcessor:
def _make_request(**kwargs):
defaults = dict(
question="Test question",
user="testuser",
collection="default",
)
defaults.update(kwargs)

View file

@ -9,6 +9,7 @@ tool usage patterns.
import pytest
from unittest.mock import Mock, AsyncMock
import asyncio
import inspect
from collections import defaultdict
@ -133,7 +134,7 @@ class TestToolCoordinationLogic:
resolved_params[key] = value
# Execute tool
if asyncio.iscoroutinefunction(tool_function):
if inspect.iscoroutinefunction(tool_function):
result = await tool_function(**resolved_params)
else:
result = tool_function(**resolved_params)
@ -227,7 +228,7 @@ class TestToolCoordinationLogic:
# Simulate async execution with delay
await asyncio.sleep(0.001) # Small delay to simulate work
if asyncio.iscoroutinefunction(tool_function):
if inspect.iscoroutinefunction(tool_function):
result = await tool_function(**parameters)
else:
result = tool_function(**parameters)
@ -337,7 +338,7 @@ class TestToolCoordinationLogic:
if attempt > 0:
await asyncio.sleep(0.001 * (self.backoff_factor ** attempt))
if asyncio.iscoroutinefunction(tool_function):
if inspect.iscoroutinefunction(tool_function):
result = await tool_function(**parameters)
else:
result = tool_function(**parameters)

View file

@ -167,39 +167,28 @@ class TestToolServiceRequest:
"""Test cases for tool service request format"""
def test_request_format(self):
"""Test that request is properly formatted with user, config, and arguments"""
# Arrange
user = "alice"
"""Test that request is properly formatted with config and arguments"""
config_values = {"style": "pun", "collection": "jokes"}
arguments = {"topic": "programming"}
# Act - simulate request building
request = {
"user": user,
"config": json.dumps(config_values),
"arguments": json.dumps(arguments)
}
# Assert
assert request["user"] == "alice"
assert json.loads(request["config"]) == {"style": "pun", "collection": "jokes"}
assert json.loads(request["arguments"]) == {"topic": "programming"}
def test_request_with_empty_config(self):
"""Test request when no config values are provided"""
# Arrange
user = "bob"
config_values = {}
arguments = {"query": "test"}
# Act
request = {
"user": user,
"config": json.dumps(config_values) if config_values else "{}",
"arguments": json.dumps(arguments) if arguments else "{}"
}
# Assert
assert request["config"] == "{}"
assert json.loads(request["arguments"]) == {"query": "test"}
@ -386,18 +375,13 @@ class TestJokeServiceLogic:
assert map_topic_to_category("random topic") == "default"
assert map_topic_to_category("") == "default"
def test_joke_response_personalization(self):
"""Test that joke responses include user personalization"""
# Arrange
user = "alice"
def test_joke_response_format(self):
"""Test that joke response is formatted as expected"""
style = "pun"
joke = "Why do programmers prefer dark mode? Because light attracts bugs!"
# Act
response = f"Hey {user}! Here's a {style} for you:\n\n{joke}"
response = f"Here's a {style} for you:\n\n{joke}"
# Assert
assert "Hey alice!" in response
assert "pun" in response
assert joke in response
@ -439,20 +423,14 @@ class TestDynamicToolServiceBase:
def test_request_parsing(self):
"""Test parsing of incoming request"""
# Arrange
request_data = {
"user": "alice",
"config": '{"style": "pun"}',
"arguments": '{"topic": "programming"}'
}
# Act
user = request_data.get("user", "trustgraph")
config = json.loads(request_data["config"]) if request_data["config"] else {}
arguments = json.loads(request_data["arguments"]) if request_data["arguments"] else {}
# Assert
assert user == "alice"
assert config == {"style": "pun"}
assert arguments == {"topic": "programming"}

View file

@ -1,6 +1,6 @@
"""
Tests for tool service lifecycle, invoke contract, streaming responses,
multi-tenancy, and error propagation.
and error propagation.
Tests the actual DynamicToolService, ToolService, and ToolServiceClient
classes rather than plain dicts.
@ -31,7 +31,7 @@ class TestDynamicToolServiceInvokeContract:
svc = DynamicToolService.__new__(DynamicToolService)
with pytest.raises(NotImplementedError):
await svc.invoke("user", {}, {})
await svc.invoke({}, {})
@pytest.mark.asyncio
async def test_on_request_calls_invoke_with_parsed_args(self):
@ -44,8 +44,8 @@ class TestDynamicToolServiceInvokeContract:
calls = []
async def tracking_invoke(user, config, arguments):
calls.append({"user": user, "config": config, "arguments": arguments})
async def tracking_invoke(config, arguments):
calls.append({"config": config, "arguments": arguments})
return "ok"
svc.invoke = tracking_invoke
@ -56,7 +56,6 @@ class TestDynamicToolServiceInvokeContract:
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(
user="alice",
config='{"style": "pun"}',
arguments='{"topic": "cats"}',
)
@ -65,39 +64,9 @@ class TestDynamicToolServiceInvokeContract:
await svc.on_request(msg, MagicMock(), None)
assert len(calls) == 1
assert calls[0]["user"] == "alice"
assert calls[0]["config"] == {"style": "pun"}
assert calls[0]["arguments"] == {"topic": "cats"}
@pytest.mark.asyncio
async def test_on_request_empty_user_defaults_to_trustgraph(self):
"""Empty user field should default to 'trustgraph'."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
received_user = None
async def capture_invoke(user, config, arguments):
nonlocal received_user
received_user = user
return "ok"
svc.invoke = capture_invoke
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="", config="", arguments="")
msg.properties.return_value = {"id": "req-2"}
await svc.on_request(msg, MagicMock(), None)
assert received_user == "trustgraph"
@pytest.mark.asyncio
async def test_on_request_string_response_sent_directly(self):
"""String return from invoke → response field is the string."""
@ -107,7 +76,7 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc"
svc.producer = AsyncMock()
async def string_invoke(user, config, arguments):
async def string_invoke(config, arguments):
return "hello world"
svc.invoke = string_invoke
@ -116,7 +85,7 @@ class TestDynamicToolServiceInvokeContract:
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r1"}
await svc.on_request(msg, MagicMock(), None)
@ -136,7 +105,7 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc"
svc.producer = AsyncMock()
async def dict_invoke(user, config, arguments):
async def dict_invoke(config, arguments):
return {"result": 42}
svc.invoke = dict_invoke
@ -145,7 +114,7 @@ class TestDynamicToolServiceInvokeContract:
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r2"}
await svc.on_request(msg, MagicMock(), None)
@ -162,13 +131,13 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc"
svc.producer = AsyncMock()
async def failing_invoke(user, config, arguments):
async def failing_invoke(config, arguments):
raise ValueError("bad input")
svc.invoke = failing_invoke
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r3"}
await svc.on_request(msg, MagicMock(), None)
@ -188,13 +157,13 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc"
svc.producer = AsyncMock()
async def rate_limited_invoke(user, config, arguments):
async def rate_limited_invoke(config, arguments):
raise TooManyRequests("rate limited")
svc.invoke = rate_limited_invoke
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "r4"}
with pytest.raises(TooManyRequests):
@ -209,7 +178,7 @@ class TestDynamicToolServiceInvokeContract:
svc.id = "test-svc"
svc.producer = AsyncMock()
async def ok_invoke(user, config, arguments):
async def ok_invoke(config, arguments):
return "ok"
svc.invoke = ok_invoke
@ -218,7 +187,7 @@ class TestDynamicToolServiceInvokeContract:
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.value.return_value = ToolServiceRequest(config="{}", arguments="{}")
msg.properties.return_value = {"id": "unique-42"}
await svc.on_request(msg, MagicMock(), None)
@ -241,7 +210,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def mock_invoke(name, params):
async def mock_invoke(workspace, name, params):
return "tool result"
svc.invoke_tool = mock_invoke
@ -260,6 +229,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}')
@ -280,7 +250,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def mock_invoke(name, params):
async def mock_invoke(workspace, name, params):
return {"data": [1, 2, 3]}
svc.invoke_tool = mock_invoke
@ -298,6 +268,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
@ -317,7 +288,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def failing_invoke(name, params):
async def failing_invoke(workspace, name, params):
raise RuntimeError("tool broke")
svc.invoke_tool = failing_invoke
@ -330,6 +301,7 @@ class TestToolServiceOnRequest:
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
flow_callable.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
@ -350,7 +322,7 @@ class TestToolServiceOnRequest:
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def rate_limited(name, params):
async def rate_limited(workspace, name, params):
raise TooManyRequests("slow down")
svc.invoke_tool = rate_limited
@ -362,6 +334,7 @@ class TestToolServiceOnRequest:
flow = MagicMock()
flow.producer = {"response": AsyncMock()}
flow.name = "test-flow"
flow.workspace = "default"
with pytest.raises(TooManyRequests):
await svc.on_request(msg, MagicMock(), flow)
@ -376,7 +349,8 @@ class TestToolServiceOnRequest:
received = {}
async def capture_invoke(name, params):
async def capture_invoke(workspace, name, params):
received["workspace"] = workspace
received["name"] = name
received["params"] = params
return "ok"
@ -390,6 +364,7 @@ class TestToolServiceOnRequest:
flow = lambda name: mock_pub
flow.producer = {"response": mock_pub}
flow.name = "f"
flow.workspace = "default"
msg = MagicMock()
msg.value.return_value = ToolRequest(
@ -421,7 +396,6 @@ class TestToolServiceClientCall:
))
result = await client.call(
user="alice",
config={"style": "pun"},
arguments={"topic": "cats"},
)
@ -430,7 +404,6 @@ class TestToolServiceClientCall:
req = client.request.call_args[0][0]
assert isinstance(req, ToolServiceRequest)
assert req.user == "alice"
assert json.loads(req.config) == {"style": "pun"}
assert json.loads(req.arguments) == {"topic": "cats"}
@ -446,7 +419,7 @@ class TestToolServiceClientCall:
))
with pytest.raises(RuntimeError, match="service down"):
await client.call(user="u", config={}, arguments={})
await client.call(config={}, arguments={})
@pytest.mark.asyncio
async def test_call_empty_config_sends_empty_json(self):
@ -458,7 +431,7 @@ class TestToolServiceClientCall:
error=None, response="ok",
))
await client.call(user="u", config=None, arguments=None)
await client.call(config=None, arguments=None)
req = client.request.call_args[0][0]
assert req.config == "{}"
@ -474,7 +447,7 @@ class TestToolServiceClientCall:
error=None, response="ok",
))
await client.call(user="u", config={}, arguments={}, timeout=30)
await client.call(config={}, arguments={}, timeout=30)
_, kwargs = client.request.call_args
assert kwargs["timeout"] == 30
@ -509,7 +482,7 @@ class TestToolServiceClientStreaming:
received.append(text)
result = await client.call_streaming(
user="u", config={}, arguments={}, callback=callback,
config={}, arguments={}, callback=callback,
)
assert result == "chunk1chunk2"
@ -534,7 +507,7 @@ class TestToolServiceClientStreaming:
with pytest.raises(RuntimeError, match="stream failed"):
await client.call_streaming(
user="u", config={}, arguments={},
config={}, arguments={},
callback=AsyncMock(),
)
@ -564,61 +537,9 @@ class TestToolServiceClientStreaming:
received.append(text)
result = await client.call_streaming(
user="u", config={}, arguments={}, callback=callback,
config={}, arguments={}, callback=callback,
)
# Empty response is falsy, so callback shouldn't be called for it
assert result == "data"
assert received == ["data"]
# ---------------------------------------------------------------------------
# Multi-tenancy
# ---------------------------------------------------------------------------
class TestMultiTenancy:
@pytest.mark.asyncio
async def test_user_propagated_to_invoke(self):
"""User from request should reach the invoke method."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test"
svc.producer = AsyncMock()
users_seen = []
async def tracking(user, config, arguments):
users_seen.append(user)
return "ok"
svc.invoke = tracking
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
for u in ["tenant-a", "tenant-b", "tenant-c"]:
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(
user=u, config="{}", arguments="{}",
)
msg.properties.return_value = {"id": f"req-{u}"}
await svc.on_request(msg, MagicMock(), None)
assert users_seen == ["tenant-a", "tenant-b", "tenant-c"]
@pytest.mark.asyncio
async def test_client_sends_user_in_request(self):
"""ToolServiceClient.call should include user in request."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
client.request = AsyncMock(return_value=ToolServiceResponse(
error=None, response="ok",
))
await client.call(user="isolated-tenant", config={}, arguments={})
req = client.request.call_args[0][0]
assert req.user == "isolated-tenant"

View file

@ -1,17 +1,14 @@
"""
Tests for AsyncProcessor config notify pattern:
- register_config_handler with types filtering
- on_config_notify version comparison and type matching
- fetch_config with short-lived client
- fetch_and_apply_config retry logic
- on_config_notify version comparison, type/workspace matching
- fetch_and_apply_config retry logic over per-workspace fetches
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, Mock
from trustgraph.schema import Term, IRI, LITERAL
# Patch heavy dependencies before importing AsyncProcessor
@pytest.fixture
def processor():
"""Create an AsyncProcessor with mocked dependencies."""
@ -68,6 +65,13 @@ class TestRegisterConfigHandler:
assert len(processor.config_handlers) == 2
def _notify_msg(version, changes):
"""Build a Mock config-notify message with given version and changes dict."""
msg = Mock()
msg.value.return_value = Mock(version=version, changes=changes)
return msg
class TestOnConfigNotify:
@pytest.mark.asyncio
@ -77,9 +81,7 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=3, types=["prompt"])
msg = _notify_msg(3, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@ -91,9 +93,7 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=5, types=["prompt"])
msg = _notify_msg(5, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@ -105,9 +105,7 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=2, types=["schema"])
msg = _notify_msg(2, {"schema": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@ -121,40 +119,36 @@ class TestOnConfigNotify:
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
# Mock fetch_config
mock_config = {"prompt": {"key": "value"}}
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
return_value=(mock_config, 2)
return_value={"key": "value"},
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
msg = _notify_msg(2, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
handler.assert_called_once_with(mock_config, 2)
handler.assert_called_once_with(
"default", {"prompt": {"key": "value"}}, 2
)
assert processor.config_version == 2
@pytest.mark.asyncio
async def test_handler_without_types_always_called(self, processor):
async def test_handler_without_types_ignored_on_notify(self, processor):
"""Handlers registered without types never fire on notifications."""
processor.config_version = 1
handler = AsyncMock()
processor.register_config_handler(handler) # No types = all
processor.register_config_handler(handler) # No types
mock_config = {"anything": {}}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 2)
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["whatever"])
msg = _notify_msg(2, {"whatever": ["default"]})
await processor.on_config_notify(msg, None, None)
await processor.on_config_notify(msg, None, None)
handler.assert_called_once_with(mock_config, 2)
handler.assert_not_called()
# Version still advances past the notify
assert processor.config_version == 2
@pytest.mark.asyncio
async def test_mixed_handlers_type_filtering(self, processor):
@ -168,156 +162,149 @@ class TestOnConfigNotify:
processor.register_config_handler(schema_handler, types=["schema"])
processor.register_config_handler(all_handler)
mock_config = {"prompt": {}}
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
return_value=(mock_config, 2)
return_value={},
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
msg = _notify_msg(2, {"prompt": ["default"]})
await processor.on_config_notify(msg, None, None)
prompt_handler.assert_called_once()
prompt_handler.assert_called_once_with(
"default", {"prompt": {}}, 2
)
schema_handler.assert_not_called()
all_handler.assert_called_once()
all_handler.assert_not_called()
@pytest.mark.asyncio
async def test_empty_types_invokes_all(self, processor):
"""Empty types list (startup signal) should invoke all handlers."""
async def test_multi_workspace_notify_invokes_handler_per_ws(
self, processor
):
"""Notify affecting multiple workspaces invokes handler once per workspace."""
processor.config_version = 1
h1 = AsyncMock()
h2 = AsyncMock()
processor.register_config_handler(h1, types=["prompt"])
processor.register_config_handler(h2, types=["schema"])
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
mock_config = {}
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
return_value=(mock_config, 2)
return_value={},
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=[])
msg = _notify_msg(2, {"prompt": ["ws1", "ws2"]})
await processor.on_config_notify(msg, None, None)
h1.assert_called_once()
h2.assert_called_once()
assert handler.call_count == 2
called_workspaces = {c.args[0] for c in handler.call_args_list}
assert called_workspaces == {"ws1", "ws2"}
@pytest.mark.asyncio
async def test_fetch_failure_handled(self, processor):
processor.config_version = 1
handler = AsyncMock()
processor.register_config_handler(handler)
processor.register_config_handler(handler, types=["prompt"])
mock_client = AsyncMock()
with patch.object(
processor, 'fetch_config',
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_workspace',
new_callable=AsyncMock,
side_effect=RuntimeError("Connection failed")
side_effect=RuntimeError("Connection failed"),
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
msg = _notify_msg(2, {"prompt": ["default"]})
# Should not raise
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
class TestFetchConfig:
@pytest.mark.asyncio
async def test_fetch_returns_config_and_version(self, processor):
mock_resp = Mock()
mock_resp.error = None
mock_resp.config = {"prompt": {"key": "val"}}
mock_resp.version = 42
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
config, version = await processor.fetch_config()
assert config == {"prompt": {"key": "val"}}
assert version == 42
mock_client.stop.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_raises_on_error_response(self, processor):
mock_resp = Mock()
mock_resp.error = Mock(message="not found")
mock_resp.config = {}
mock_resp.version = 0
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
with pytest.raises(RuntimeError, match="Config error"):
await processor.fetch_config()
mock_client.stop.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_stops_client_on_exception(self, processor):
mock_client = AsyncMock()
mock_client.request.side_effect = TimeoutError("timeout")
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
with pytest.raises(TimeoutError):
await processor.fetch_config()
mock_client.stop.assert_called_once()
class TestFetchAndApplyConfig:
@pytest.mark.asyncio
async def test_applies_config_to_all_handlers(self, processor):
h1 = AsyncMock()
h2 = AsyncMock()
processor.register_config_handler(h1, types=["prompt"])
processor.register_config_handler(h2, types=["schema"])
async def test_applies_config_per_workspace(self, processor):
"""Startup fetch invokes handler once per workspace affected."""
h = AsyncMock()
processor.register_config_handler(h, types=["prompt"])
mock_client = AsyncMock()
async def fake_fetch_all(client, config_type):
return {
"ws1": {"k": "v1"},
"ws2": {"k": "v2"},
}, 10
mock_config = {"prompt": {}, "schema": {}}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 10)
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
):
await processor.fetch_and_apply_config()
# On startup, all handlers are invoked regardless of type
h1.assert_called_once_with(mock_config, 10)
h2.assert_called_once_with(mock_config, 10)
assert h.call_count == 2
call_map = {c.args[0]: c.args[1] for c in h.call_args_list}
assert call_map["ws1"] == {"prompt": {"k": "v1"}}
assert call_map["ws2"] == {"prompt": {"k": "v2"}}
assert processor.config_version == 10
@pytest.mark.asyncio
async def test_retries_on_failure(self, processor):
call_count = 0
mock_config = {"prompt": {}}
async def test_handler_without_types_skipped_at_startup(self, processor):
"""Handlers registered without types fetch nothing at startup."""
typed = AsyncMock()
untyped = AsyncMock()
processor.register_config_handler(typed, types=["prompt"])
processor.register_config_handler(untyped)
async def mock_fetch():
mock_client = AsyncMock()
async def fake_fetch_all(client, config_type):
return {"default": {}}, 1
with patch.object(
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
):
await processor.fetch_and_apply_config()
typed.assert_called_once()
untyped.assert_not_called()
@pytest.mark.asyncio
async def test_retries_on_failure(self, processor):
h = AsyncMock()
processor.register_config_handler(h, types=["prompt"])
call_count = 0
async def fake_fetch_all(client, config_type):
nonlocal call_count
call_count += 1
if call_count < 3:
raise RuntimeError("not ready")
return mock_config, 5
return {"default": {"k": "v"}}, 5
with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \
patch('asyncio.sleep', new_callable=AsyncMock):
mock_client = AsyncMock()
with patch.object(
processor, '_create_config_client', return_value=mock_client
), patch.object(
processor, '_fetch_type_all_workspaces',
new=fake_fetch_all,
), patch('asyncio.sleep', new_callable=AsyncMock):
await processor.fetch_and_apply_config()
assert call_count == 3
assert processor.config_version == 5
h.assert_called_once_with(
"default", {"prompt": {"k": "v"}}, 5
)

View file

@ -33,7 +33,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
result = await client.query(
vector=vector,
limit=10,
user="test_user",
collection="test_collection",
timeout=30
)
@ -45,7 +44,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
assert isinstance(call_args, DocumentEmbeddingsRequest)
assert call_args.vector == vector
assert call_args.limit == 10
assert call_args.user == "test_user"
assert call_args.collection == "test_collection"
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
@ -104,7 +102,6 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client.request.assert_called_once()
call_args = client.request.call_args[0][0]
assert call_args.limit == 20 # Default limit
assert call_args.user == "trustgraph" # Default user
assert call_args.collection == "default" # Default collection
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')

View file

@ -0,0 +1,81 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock
from trustgraph.base.flow import Flow
from trustgraph.base.parameter_spec import Parameter, ParameterSpec
from trustgraph.base.spec import Spec
def test_parameter_spec_is_a_spec_and_adds_parameter_value():
spec = ParameterSpec("temperature")
flow = MagicMock(parameter={})
processor = MagicMock()
spec.add(flow, processor, {"parameters": {"temperature": 0.7}})
assert isinstance(spec, Spec)
assert "temperature" in flow.parameter
assert isinstance(flow.parameter["temperature"], Parameter)
assert flow.parameter["temperature"].value == 0.7
def test_parameter_spec_defaults_missing_values_to_none():
spec = ParameterSpec("model")
flow = MagicMock(parameter={})
spec.add(flow, MagicMock(), {})
assert flow.parameter["model"].value is None
def test_parameter_start_and_stop_are_awaitable():
parameter = Parameter("value")
assert asyncio.run(parameter.start()) is None
assert asyncio.run(parameter.stop()) is None
def test_flow_initialization_calls_registered_specs():
spec_one = MagicMock()
spec_two = MagicMock()
processor = MagicMock(specifications=[spec_one, spec_two])
flow = Flow("processor-1", "flow-a", "default", processor, {"answer": 42})
assert flow.id == "processor-1"
assert flow.name == "flow-a"
assert flow.workspace == "default"
assert flow.producer == {}
assert flow.consumer == {}
assert flow.parameter == {}
spec_one.add.assert_called_once_with(flow, processor, {"answer": 42})
spec_two.add.assert_called_once_with(flow, processor, {"answer": 42})
def test_flow_start_and_stop_visit_all_consumers():
consumer_one = AsyncMock()
consumer_two = AsyncMock()
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
flow.consumer = {"one": consumer_one, "two": consumer_two}
asyncio.run(flow.start())
asyncio.run(flow.stop())
consumer_one.start.assert_called_once_with()
consumer_two.start.assert_called_once_with()
consumer_one.stop.assert_called_once_with()
consumer_two.stop.assert_called_once_with()
def test_flow_call_returns_values_in_priority_order():
flow = Flow("processor-1", "flow-a", "default", MagicMock(specifications=[]), {})
flow.producer["shared"] = "producer-value"
flow.consumer["consumer-only"] = "consumer-value"
flow.consumer["shared"] = "consumer-value"
flow.parameter["parameter-only"] = Parameter("parameter-value")
flow.parameter["shared"] = Parameter("parameter-value")
assert flow("shared") == "producer-value"
assert flow("consumer-only") == "consumer-value"
assert flow("parameter-only") == "parameter-value"
assert flow("missing") is None

View file

@ -172,10 +172,10 @@ class TestFlowParameterSpecs(IsolatedAsyncioTestCase):
flow_defn = {'config': 'test-config'}
# Act
await processor.start_flow(flow_name, flow_defn)
await processor.start_flow("default", flow_name, flow_defn)
# Assert - Flow should be created with access to processor specifications
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
mock_flow_class.assert_called_once_with('test-processor', flow_name, "default", processor, flow_defn)
# The flow should have access to the processor's specifications
# (The exact mechanism depends on Flow implementation)

View file

@ -1,58 +1,50 @@
"""
Unit tests for trustgraph.base.flow_processor
Starting small with a single test to verify basic functionality
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.base.flow_processor import FlowProcessor
# Patches needed to let AsyncProcessor.__init__ run without real
# infrastructure while still setting self.id correctly.
ASYNC_PROCESSOR_PATCHES = [
patch('trustgraph.base.async_processor.get_pubsub', return_value=MagicMock()),
patch('trustgraph.base.async_processor.ProcessorMetrics', return_value=MagicMock()),
patch('trustgraph.base.async_processor.Consumer', return_value=MagicMock()),
]
def with_async_processor_patches(func):
"""Apply all AsyncProcessor dependency patches to a test."""
for p in reversed(ASYNC_PROCESSOR_PATCHES):
func = p(func)
return func
class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
"""Test FlowProcessor base class functionality"""
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_flow_processor_initialization_basic(self, mock_register_config, mock_async_init):
@with_async_processor_patches
async def test_flow_processor_initialization_basic(self, *mocks):
"""Test basic FlowProcessor initialization"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
config = {
'id': 'test-flow-processor',
'taskgroup': AsyncMock()
}
# Act
processor = FlowProcessor(**config)
# Assert
# Verify AsyncProcessor.__init__ was called
mock_async_init.assert_called_once()
# Verify register_config_handler was called with the correct handler
mock_register_config.assert_called_once_with(
processor.on_configure_flows, types=["active-flow"]
)
# Verify FlowProcessor-specific initialization
assert hasattr(processor, 'flows')
assert processor.id == 'test-flow-processor'
assert processor.flows == {}
assert hasattr(processor, 'specifications')
assert processor.specifications == []
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_register_specification(self, mock_register_config, mock_async_init):
@with_async_processor_patches
async def test_register_specification(self, *mocks):
"""Test registering a specification"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
config = {
'id': 'test-flow-processor',
'taskgroup': AsyncMock()
@ -62,288 +54,210 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
mock_spec = MagicMock()
mock_spec.name = 'test-spec'
# Act
processor.register_specification(mock_spec)
# Assert
assert len(processor.specifications) == 1
assert processor.specifications[0] == mock_spec
@patch('trustgraph.base.flow_processor.Flow')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_start_flow(self, mock_register_config, mock_async_init, mock_flow_class):
@with_async_processor_patches
async def test_start_flow(self, *mocks):
"""Test starting a flow"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
mock_flow_class = mocks[-1]
config = {
'id': 'test-flow-processor',
'id': 'test-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
processor.id = 'test-processor' # Set id for Flow creation
mock_flow = AsyncMock()
mock_flow_class.return_value = mock_flow
flow_name = 'test-flow'
flow_defn = {'config': 'test-config'}
# Act
await processor.start_flow(flow_name, flow_defn)
await processor.start_flow("default", flow_name, flow_defn)
# Assert
assert flow_name in processor.flows
# Verify Flow was created with correct parameters
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
# Verify the flow's start method was called
assert ("default", flow_name) in processor.flows
mock_flow_class.assert_called_once_with(
'test-processor', flow_name, "default", processor, flow_defn
)
mock_flow.start.assert_called_once()
@patch('trustgraph.base.flow_processor.Flow')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_stop_flow(self, mock_register_config, mock_async_init, mock_flow_class):
@with_async_processor_patches
async def test_stop_flow(self, *mocks):
"""Test stopping a flow"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
mock_flow_class = mocks[-1]
config = {
'id': 'test-flow-processor',
'id': 'test-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
processor.id = 'test-processor'
mock_flow = AsyncMock()
mock_flow_class.return_value = mock_flow
flow_name = 'test-flow'
flow_defn = {'config': 'test-config'}
await processor.start_flow("default", flow_name, {'config': 'test-config'})
# Start a flow first
await processor.start_flow(flow_name, flow_defn)
# Act
await processor.stop_flow(flow_name)
await processor.stop_flow("default", flow_name)
# Assert
assert flow_name not in processor.flows
assert ("default", flow_name) not in processor.flows
mock_flow.stop.assert_called_once()
@patch('trustgraph.base.flow_processor.Flow')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_stop_flow_not_exists(self, mock_register_config, mock_async_init, mock_flow_class):
@with_async_processor_patches
async def test_stop_flow_not_exists(self, *mocks):
"""Test stopping a flow that doesn't exist"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
config = {
'id': 'test-flow-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Act - should not raise an exception
await processor.stop_flow('non-existent-flow')
# Assert - flows dict should still be empty
await processor.stop_flow("default", 'non-existent-flow')
assert processor.flows == {}
@patch('trustgraph.base.flow_processor.Flow')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_on_configure_flows_basic(self, mock_register_config, mock_async_init, mock_flow_class):
@with_async_processor_patches
async def test_on_configure_flows_basic(self, *mocks):
"""Test basic flow configuration handling"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
mock_flow_class = mocks[-1]
config = {
'id': 'test-flow-processor',
'id': 'test-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
processor.id = 'test-processor'
mock_flow = AsyncMock()
mock_flow_class.return_value = mock_flow
# Configuration with flows for this processor
flow_config = {
'test-flow': {'config': 'test-config'}
}
config_data = {
'active-flow': {
'test-processor': '{"test-flow": {"config": "test-config"}}'
'processor:test-processor': {
'test-flow': '{"config": "test-config"}'
}
}
# Act
await processor.on_configure_flows(config_data, version=1)
# Assert
assert 'test-flow' in processor.flows
mock_flow_class.assert_called_once_with('test-processor', 'test-flow', processor, {'config': 'test-config'})
await processor.on_configure_flows("default", config_data, version=1)
assert ("default", 'test-flow') in processor.flows
mock_flow_class.assert_called_once_with(
'test-processor', 'test-flow', "default", processor,
{'config': 'test-config'}
)
mock_flow.start.assert_called_once()
@patch('trustgraph.base.flow_processor.Flow')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_on_configure_flows_no_config(self, mock_register_config, mock_async_init, mock_flow_class):
@with_async_processor_patches
async def test_on_configure_flows_no_config(self, *mocks):
"""Test flow configuration handling when no config exists for this processor"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
config = {
'id': 'test-flow-processor',
'id': 'test-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
processor.id = 'test-processor'
# Configuration without flows for this processor
config_data = {
'active-flow': {
'other-processor': '{"other-flow": {"config": "other-config"}}'
'processor:other-processor': {
'other-flow': '{"config": "other-config"}'
}
}
# Act
await processor.on_configure_flows(config_data, version=1)
# Assert
await processor.on_configure_flows("default", config_data, version=1)
assert processor.flows == {}
mock_flow_class.assert_not_called()
@patch('trustgraph.base.flow_processor.Flow')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_on_configure_flows_invalid_config(self, mock_register_config, mock_async_init, mock_flow_class):
@with_async_processor_patches
async def test_on_configure_flows_invalid_config(self, *mocks):
"""Test flow configuration handling with invalid config format"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
config = {
'id': 'test-flow-processor',
'id': 'test-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
processor.id = 'test-processor'
# Configuration without active-flow key
config_data = {
'other-data': 'some-value'
}
# Act
await processor.on_configure_flows(config_data, version=1)
# Assert
await processor.on_configure_flows("default", config_data, version=1)
assert processor.flows == {}
mock_flow_class.assert_not_called()
@patch('trustgraph.base.flow_processor.Flow')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_on_configure_flows_start_and_stop(self, mock_register_config, mock_async_init, mock_flow_class):
@with_async_processor_patches
async def test_on_configure_flows_start_and_stop(self, *mocks):
"""Test flow configuration handling with starting and stopping flows"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
mock_flow_class = mocks[-1]
config = {
'id': 'test-flow-processor',
'id': 'test-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
processor.id = 'test-processor'
mock_flow1 = AsyncMock()
mock_flow2 = AsyncMock()
mock_flow_class.side_effect = [mock_flow1, mock_flow2]
# First configuration - start flow1
config_data1 = {
'active-flow': {
'test-processor': '{"flow1": {"config": "config1"}}'
'processor:test-processor': {
'flow1': '{"config": "config1"}'
}
}
await processor.on_configure_flows(config_data1, version=1)
await processor.on_configure_flows("default", config_data1, version=1)
# Second configuration - stop flow1, start flow2
config_data2 = {
'active-flow': {
'test-processor': '{"flow2": {"config": "config2"}}'
'processor:test-processor': {
'flow2': '{"config": "config2"}'
}
}
# Act
await processor.on_configure_flows(config_data2, version=2)
# Assert
# flow1 should be stopped and removed
assert 'flow1' not in processor.flows
await processor.on_configure_flows("default", config_data2, version=2)
assert ("default", 'flow1') not in processor.flows
mock_flow1.stop.assert_called_once()
# flow2 should be started and added
assert 'flow2' in processor.flows
assert ("default", 'flow2') in processor.flows
mock_flow2.start.assert_called_once()
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
@with_async_processor_patches
@patch('trustgraph.base.async_processor.AsyncProcessor.start')
async def test_start_calls_parent(self, mock_parent_start, mock_register_config, mock_async_init):
async def test_start_calls_parent(self, mock_parent_start, *mocks):
"""Test that start() calls parent start method"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
mock_parent_start.return_value = None
config = {
'id': 'test-flow-processor',
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Act
await processor.start()
# Assert
mock_parent_start.assert_called_once()
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
async def test_add_args_calls_parent(self, mock_register_config, mock_async_init):
async def test_add_args_calls_parent(self):
"""Test that add_args() calls parent add_args method"""
# Arrange
mock_async_init.return_value = None
mock_register_config.return_value = None
mock_parser = MagicMock()
# Act
with patch('trustgraph.base.async_processor.AsyncProcessor.add_args') as mock_parent_add_args:
FlowProcessor.add_args(mock_parser)
# Assert
mock_parent_add_args.assert_called_once_with(mock_parser)
if __name__ == '__main__':
pytest.main([__file__])
pytest.main([__file__])

View file

@ -0,0 +1,40 @@
from trustgraph.i18n import get_language_pack, get_translator, normalize_language
def test_normalize_language_handles_regions_and_accept_language():
assert normalize_language(None) == "en"
assert normalize_language("") == "en"
assert normalize_language("es-ES") == "es"
assert normalize_language("pt-BR") == "pt"
assert normalize_language("zh") == "zh-cn"
assert normalize_language("es-ES,es;q=0.9,en;q=0.8") == "es"
assert normalize_language("unknown") == "en"
def test_language_pack_loads_from_resources():
pack = get_language_pack("en")
assert isinstance(pack, dict)
# Key should exist and map to a non-empty string.
title = pack.get("cli.verify_system_status.title")
assert isinstance(title, str)
assert title.strip() != ""
def test_translator_formats_placeholders():
tr = get_translator("en")
out = tr.t(
"cli.verify_system_status.checking_attempt",
name="Pulsar",
attempt=2,
)
assert "Pulsar" in out
assert "2" in out
def test_translator_falls_back_to_key_for_unknown_keys():
tr = get_translator("en")
assert tr.t("missing.key") == "missing.key"

View file

@ -0,0 +1,130 @@
import argparse
import logging
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock
from trustgraph.base.logging import add_logging_args, setup_logging
def test_add_logging_args_uses_environment_defaults(monkeypatch):
monkeypatch.setenv("LOKI_URL", "http://example.test/loki")
monkeypatch.setenv("LOKI_USERNAME", "user")
monkeypatch.setenv("LOKI_PASSWORD", "pass")
parser = argparse.ArgumentParser()
add_logging_args(parser)
args = parser.parse_args([])
assert args.log_level == "INFO"
assert args.loki_enabled is True
assert args.loki_url == "http://example.test/loki"
assert args.loki_username == "user"
assert args.loki_password == "pass"
def test_add_logging_args_supports_disabling_loki():
parser = argparse.ArgumentParser()
add_logging_args(parser)
args = parser.parse_args(["--no-loki-enabled"])
assert args.loki_enabled is False
def test_setup_logging_without_loki_configures_console(monkeypatch):
basic_config = MagicMock()
logger = MagicMock()
monkeypatch.setattr(logging, "basicConfig", basic_config)
monkeypatch.setattr(logging, "getLogger", lambda name=None: logger)
setup_logging({"log_level": "debug", "loki_enabled": False, "id": "processor-1"})
kwargs = basic_config.call_args.kwargs
assert kwargs["level"] == logging.DEBUG
assert kwargs["force"] is True
assert "%(processor_id)s" in kwargs["format"]
assert len(kwargs["handlers"]) == 1
logger.info.assert_called_once_with("Logging configured with level: debug")
def test_setup_logging_with_loki_enables_queue_listener(monkeypatch):
basic_config = MagicMock()
root_logger = MagicMock()
module_logger = MagicMock()
urllib3_logger = MagicMock()
connectionpool_logger = MagicMock()
queue_handler = MagicMock()
queue_listener = MagicMock()
loki_handler = MagicMock()
noisy_logger = MagicMock()
logger_map = {
None: root_logger,
"trustgraph.base.logging": module_logger,
"urllib3": urllib3_logger,
"urllib3.connectionpool": connectionpool_logger,
"pika": noisy_logger,
"cassandra": noisy_logger,
}
monkeypatch.setattr(logging, "basicConfig", basic_config)
monkeypatch.setattr(logging, "getLogger", lambda name=None: logger_map[name])
monkeypatch.setattr(
logging.handlers,
"QueueHandler",
MagicMock(return_value=queue_handler),
)
monkeypatch.setattr(
logging.handlers,
"QueueListener",
MagicMock(return_value=queue_listener),
)
monkeypatch.setitem(
sys.modules,
"logging_loki",
SimpleNamespace(LokiHandler=MagicMock(return_value=loki_handler)),
)
setup_logging(
{
"log_level": "INFO",
"loki_enabled": True,
"loki_url": "http://loki.test/push",
"loki_username": "user",
"loki_password": "pass",
"id": "processor-1",
}
)
assert root_logger.loki_queue_listener is queue_listener
queue_listener.start.assert_called_once_with()
urllib3_logger.setLevel.assert_called_once_with(logging.WARNING)
connectionpool_logger.setLevel.assert_called_once_with(logging.WARNING)
module_logger.info.assert_any_call("Logging configured with level: INFO")
module_logger.info.assert_any_call("Loki logging enabled: http://loki.test/push")
def test_setup_logging_falls_back_when_loki_module_missing(monkeypatch, capsys):
basic_config = MagicMock()
logger = MagicMock()
monkeypatch.setattr(logging, "basicConfig", basic_config)
monkeypatch.setattr(logging, "getLogger", lambda name=None: logger)
monkeypatch.delitem(sys.modules, "logging_loki", raising=False)
real_import = __import__
def fake_import(name, *args, **kwargs):
if name == "logging_loki":
raise ImportError("missing")
return real_import(name, *args, **kwargs)
monkeypatch.setattr("builtins.__import__", fake_import)
setup_logging({"log_level": "INFO", "loki_enabled": True, "id": "processor-1"})
output = capsys.readouterr().out
assert "python-logging-loki not installed" in output
logger.warning.assert_called_once_with("Loki logging requested but not available")

View file

@ -0,0 +1,143 @@
from unittest.mock import MagicMock
import pytest
from trustgraph.base import metrics
@pytest.fixture(autouse=True)
def reset_metric_singletons():
"""Temporarily remove metric singletons so each test can inject mocks.
Saves any existing class-level metrics and restores them after the test
so that later tests in the same process still find the hasattr() guard
intact deleting without restoring causes every subsequent Processor()
construction to re-register the same Prometheus metric name, which raises
ValueError: Duplicated timeseries.
"""
classes_and_attrs = {
metrics.ConsumerMetrics: [
"state_metric",
"request_metric",
"processing_metric",
"rate_limit_metric",
],
metrics.ProducerMetrics: ["producer_metric"],
metrics.ProcessorMetrics: ["processor_metric"],
metrics.SubscriberMetrics: [
"state_metric",
"received_metric",
"dropped_metric",
],
}
saved = {}
for cls, attrs in classes_and_attrs.items():
for attr in attrs:
if hasattr(cls, attr):
saved[(cls, attr)] = getattr(cls, attr)
delattr(cls, attr)
yield
# Remove anything the test may have set, then restore originals
for cls, attrs in classes_and_attrs.items():
for attr in attrs:
if hasattr(cls, attr):
delattr(cls, attr)
for (cls, attr), value in saved.items():
setattr(cls, attr, value)
def test_consumer_metrics_reuses_singletons_and_records_events(monkeypatch):
enum_factory = MagicMock()
histogram_factory = MagicMock()
counter_factory = MagicMock()
state_labels = MagicMock()
request_labels = MagicMock()
processing_labels = MagicMock()
rate_limit_labels = MagicMock()
timer = MagicMock()
enum_factory.return_value.labels.return_value = state_labels
histogram_factory.return_value.labels.return_value = request_labels
request_labels.time.return_value = timer
counter_factory.side_effect = [
MagicMock(labels=MagicMock(return_value=processing_labels)),
MagicMock(labels=MagicMock(return_value=rate_limit_labels)),
]
monkeypatch.setattr(metrics, "Enum", enum_factory)
monkeypatch.setattr(metrics, "Histogram", histogram_factory)
monkeypatch.setattr(metrics, "Counter", counter_factory)
first = metrics.ConsumerMetrics("proc", "flow", "name")
second = metrics.ConsumerMetrics("proc-2", "flow-2", "name-2")
assert enum_factory.call_count == 1
assert histogram_factory.call_count == 1
assert counter_factory.call_count == 2
first.process("ok")
first.rate_limit()
first.state("running")
assert first.record_time() is timer
processing_labels.inc.assert_called_once_with()
rate_limit_labels.inc.assert_called_once_with()
state_labels.state.assert_called_once_with("running")
def test_producer_metrics_increments_counter_once(monkeypatch):
counter_factory = MagicMock()
labels = MagicMock()
counter_factory.return_value.labels.return_value = labels
monkeypatch.setattr(metrics, "Counter", counter_factory)
producer_metrics = metrics.ProducerMetrics("proc", "flow", "output")
producer_metrics.inc()
counter_factory.assert_called_once()
labels.inc.assert_called_once_with()
def test_processor_metrics_reports_info(monkeypatch):
info_factory = MagicMock()
labels = MagicMock()
info_factory.return_value.labels.return_value = labels
monkeypatch.setattr(metrics, "Info", info_factory)
processor_metrics = metrics.ProcessorMetrics("proc")
processor_metrics.info({"kind": "test"})
info_factory.assert_called_once()
labels.info.assert_called_once_with({"kind": "test"})
def test_subscriber_metrics_tracks_received_state_and_dropped(monkeypatch):
enum_factory = MagicMock()
counter_factory = MagicMock()
state_labels = MagicMock()
received_labels = MagicMock()
dropped_labels = MagicMock()
enum_factory.return_value.labels.return_value = state_labels
counter_factory.side_effect = [
MagicMock(labels=MagicMock(return_value=received_labels)),
MagicMock(labels=MagicMock(return_value=dropped_labels)),
]
monkeypatch.setattr(metrics, "Enum", enum_factory)
monkeypatch.setattr(metrics, "Counter", counter_factory)
subscriber_metrics = metrics.SubscriberMetrics("proc", "flow", "input")
subscriber_metrics.received()
subscriber_metrics.state("running")
subscriber_metrics.dropped("ignored")
received_labels.inc.assert_called_once_with()
dropped_labels.inc.assert_called_once_with()
state_labels.state.assert_called_once_with("running")

View file

@ -236,6 +236,10 @@ async def test_subscriber_graceful_shutdown():
with patch.object(subscriber, 'run') as mock_run:
# Mock run that simulates graceful shutdown
async def mock_run_graceful():
# Honor the readiness contract: real run() signals _ready
# after binding the consumer, so start() can unblock. Mocks
# of run() must do the same or start() hangs forever.
subscriber._ready.set_result(None)
# Process messages while running, then drain
while subscriber.running or subscriber.draining:
if subscriber.draining:
@ -337,6 +341,8 @@ async def test_subscriber_pending_acks_cleanup():
with patch.object(subscriber, 'run') as mock_run:
# Mock run that simulates cleanup of pending acks
async def mock_run_cleanup():
# Honor the readiness contract — see test_subscriber_graceful_shutdown.
subscriber._ready.set_result(None)
while subscriber.running or subscriber.draining:
await asyncio.sleep(0.05)
if subscriber.draining:
@ -406,4 +412,4 @@ async def test_subscriber_multiple_subscribers():
msg1 = await queue1.get()
msg_all = await queue_all.get()
assert msg1 == {"data": "broadcast"}
assert msg_all == {"data": "broadcast"}
assert msg_all == {"data": "broadcast"}

View file

@ -0,0 +1,189 @@
"""
Regression tests for Subscriber.start() readiness barrier.
Background: prior to the eager-connect fix, Subscriber.start() created
the run() task and returned immediately. The underlying backend consumer
was lazily connected on its first receive() call, which left a setup
race for request/response clients using ephemeral per-subscriber response
queues (RabbitMQ auto-delete exclusive queues): the request would be
published before the response queue was bound, and the broker would
silently drop the reply. fetch_config(), document-embeddings, and
api-gateway all hit this with "Failed to fetch config on notify" /
"Request timeout exception" symptoms.
These tests pin the readiness contract:
await subscriber.start()
# at this point, consumer.ensure_connected() MUST have run
so that any future change which removes the eager bind, or moves it
back to lazy initialisation, fails CI loudly.
"""
import asyncio
import pytest
from unittest.mock import MagicMock
from trustgraph.base.subscriber import Subscriber
def _make_backend(ensure_connected_side_effect=None,
receive_side_effect=None):
"""Build a fake backend whose consumer records ensure_connected /
receive calls. ensure_connected_side_effect lets a test inject a
delay or exception."""
backend = MagicMock()
consumer = MagicMock()
consumer.ensure_connected = MagicMock(
side_effect=ensure_connected_side_effect,
)
# By default receive raises a timeout-style exception that the
# subscriber loop is supposed to swallow as a "no message yet" — this
# keeps the subscriber idling cleanly while the test inspects state.
if receive_side_effect is None:
receive_side_effect = TimeoutError("No message received within timeout")
consumer.receive = MagicMock(side_effect=receive_side_effect)
consumer.acknowledge = MagicMock()
consumer.negative_acknowledge = MagicMock()
consumer.pause_message_listener = MagicMock()
consumer.unsubscribe = MagicMock()
consumer.close = MagicMock()
backend.create_consumer.return_value = consumer
return backend, consumer
def _make_subscriber(backend):
return Subscriber(
backend=backend,
topic="response:tg:config",
subscription="test-sub",
consumer_name="test-consumer",
schema=dict,
max_size=10,
drain_timeout=1.0,
backpressure_strategy="block",
)
class TestSubscriberReadiness:
@pytest.mark.asyncio
async def test_start_calls_ensure_connected_before_returning(self):
"""The barrier: ensure_connected must have been invoked at least
once by the time start() returns."""
backend, consumer = _make_backend()
subscriber = _make_subscriber(backend)
await subscriber.start()
try:
consumer.ensure_connected.assert_called_once()
finally:
await subscriber.stop()
@pytest.mark.asyncio
async def test_start_blocks_until_ensure_connected_completes(self):
"""If ensure_connected is slow, start() must wait for it. This is
the actual race-condition guard it would have failed against
the buggy version where start() returned before run() had even
scheduled the consumer creation."""
connect_started = asyncio.Event()
release_connect = asyncio.Event()
# ensure_connected runs in the executor thread, so we need a
# threading-safe gate. Use a simple busy-wait on a flag set by
# the asyncio side via call_soon_threadsafe — but the simpler
# path is to give it a sleep and observe ordering.
import threading
gate = threading.Event()
def slow_connect():
connect_started.set() # safe: only mutates the Event flag
gate.wait(timeout=2.0)
backend, consumer = _make_backend(
ensure_connected_side_effect=slow_connect,
)
subscriber = _make_subscriber(backend)
start_task = asyncio.create_task(subscriber.start())
# Wait until ensure_connected has begun executing.
await asyncio.wait_for(connect_started.wait(), timeout=2.0)
# ensure_connected is in flight — start() must NOT have returned.
assert not start_task.done(), (
"start() returned before ensure_connected() completed — "
"the readiness barrier is broken and the request/response "
"race condition is back."
)
# Release the gate; start() should now complete promptly.
gate.set()
await asyncio.wait_for(start_task, timeout=2.0)
consumer.ensure_connected.assert_called_once()
await subscriber.stop()
@pytest.mark.asyncio
async def test_start_propagates_consumer_creation_failure(self):
"""If create_consumer() raises, start() must surface the error
rather than hang on the readiness future. The old code path
retried indefinitely inside run() and never let start() unblock."""
backend = MagicMock()
backend.create_consumer.side_effect = RuntimeError("broker down")
subscriber = _make_subscriber(backend)
with pytest.raises(RuntimeError, match="broker down"):
await asyncio.wait_for(subscriber.start(), timeout=2.0)
@pytest.mark.asyncio
async def test_start_propagates_ensure_connected_failure(self):
"""Same contract for an ensure_connected() that raises (e.g. the
broker is up but the queue declare/bind fails)."""
backend, consumer = _make_backend(
ensure_connected_side_effect=RuntimeError("queue declare failed"),
)
subscriber = _make_subscriber(backend)
with pytest.raises(RuntimeError, match="queue declare failed"):
await asyncio.wait_for(subscriber.start(), timeout=2.0)
@pytest.mark.asyncio
async def test_ensure_connected_runs_before_subscriber_running_log(self):
"""Subtle ordering: ensure_connected MUST happen before the
receive loop, so that any reply is captured. We assert this by
checking ensure_connected was called before any receive call."""
call_order = []
def record_ensure():
call_order.append("ensure_connected")
def record_receive(*args, **kwargs):
call_order.append("receive")
raise TimeoutError("No message received within timeout")
backend, consumer = _make_backend(
ensure_connected_side_effect=record_ensure,
receive_side_effect=record_receive,
)
subscriber = _make_subscriber(backend)
await subscriber.start()
# Give the receive loop a tick to run at least once.
await asyncio.sleep(0.05)
await subscriber.stop()
# ensure_connected must come first; receive may not have happened
# yet on a fast machine, but if it did, it must come after.
assert call_order, "neither ensure_connected nor receive was called"
assert call_order[0] == "ensure_connected"

View file

@ -28,7 +28,6 @@ def sample_text_document():
"""Sample document with moderate length text."""
metadata = Metadata(
id="test-doc-1",
user="test-user",
collection="test-collection"
)
text = "The quick brown fox jumps over the lazy dog. " * 20
@ -43,7 +42,6 @@ def long_text_document():
"""Long document for testing multiple chunks."""
metadata = Metadata(
id="test-doc-long",
user="test-user",
collection="test-collection"
)
# Create a long text that will definitely be chunked
@ -59,7 +57,6 @@ def unicode_text_document():
"""Document with various unicode characters."""
metadata = Metadata(
id="test-doc-unicode",
user="test-user",
collection="test-collection"
)
text = """
@ -84,7 +81,6 @@ def empty_text_document():
"""Empty document for edge case testing."""
metadata = Metadata(
id="test-doc-empty",
user="test-user",
collection="test-collection"
)
return TextDocument(

View file

@ -70,11 +70,12 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
# Flow exposes parameter lookup via __call__: flow("chunk-size")
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 2000, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -105,10 +106,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 200 # Override chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -139,10 +140,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 1500, # Override chunk size
"chunk-overlap": 150 # Override chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -184,7 +185,6 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata(
id="test-doc-123",
user="test-user",
collection="test-collection"
)
mock_text_doc.text = b"This is test document content"
@ -195,15 +195,15 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_consumer = MagicMock()
mock_producer = AsyncMock()
mock_triples_producer = AsyncMock()
# Flow.__call__ resolves parameters and producers/consumers from the
# same dict — merge both kinds here.
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 1500,
"chunk-overlap": 150,
}.get(param)
mock_flow.side_effect = lambda name: {
"output": mock_producer,
"triples": mock_triples_producer,
}.get(name)
}.get(key)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
@ -241,7 +241,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.return_value = None # No overrides
mock_flow.side_effect = lambda key: None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(

View file

@ -70,11 +70,12 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
# Flow exposes parameter lookup via __call__: flow("chunk-size")
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 400, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -105,10 +106,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 25 # Override chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -139,10 +140,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 350, # Override chunk size
"chunk-overlap": 30 # Override chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -184,7 +185,6 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata(
id="test-doc-456",
user="test-user",
collection="test-collection"
)
mock_text_doc.text = b"This is test document content for token chunking"
@ -195,15 +195,15 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_consumer = MagicMock()
mock_producer = AsyncMock()
mock_triples_producer = AsyncMock()
# Flow.__call__ resolves parameters and producers/consumers from the
# same dict — merge both kinds here.
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 400,
"chunk-overlap": 40,
}.get(param)
mock_flow.side_effect = lambda name: {
"output": mock_producer,
"triples": mock_triples_producer,
}.get(name)
}.get(key)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
@ -245,7 +245,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.return_value = None # No overrides
mock_flow.side_effect = lambda key: None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(

View file

@ -109,7 +109,8 @@ class TestListConfigItems:
url='http://custom.com',
config_type='prompt',
format_type='json',
token=None
token=None,
workspace='default'
)
def test_list_main_uses_defaults(self):
@ -128,7 +129,8 @@ class TestListConfigItems:
url='http://localhost:8088/',
config_type='prompt',
format_type='text',
token=None
token=None,
workspace='default'
)
@ -196,7 +198,8 @@ class TestGetConfigItem:
config_type='prompt',
key='template-1',
format_type='json',
token=None
token=None,
workspace='default'
)
@ -253,7 +256,8 @@ class TestPutConfigItem:
config_type='prompt',
key='new-template',
value='Custom prompt: {input}',
token=None
token=None,
workspace='default'
)
def test_put_main_with_stdin_arg(self):
@ -278,7 +282,8 @@ class TestPutConfigItem:
config_type='prompt',
key='stdin-template',
value=stdin_content,
token=None
token=None,
workspace='default'
)
def test_put_main_mutually_exclusive_args(self):
@ -334,7 +339,8 @@ class TestDeleteConfigItem:
url='http://custom.com',
config_type='prompt',
key='old-template',
token=None
token=None,
workspace='default'
)

View file

@ -48,7 +48,7 @@ def knowledge_loader():
return KnowledgeLoader(
files=["test.ttl"],
flow="test-flow",
user="test-user",
workspace="test-user",
collection="test-collection",
document_id="test-doc-123",
url="http://test.example.com/",
@ -64,7 +64,7 @@ class TestKnowledgeLoader:
loader = KnowledgeLoader(
files=["file1.ttl", "file2.ttl"],
flow="my-flow",
user="user1",
workspace="user1",
collection="col1",
document_id="doc1",
url="http://example.com/",
@ -73,7 +73,7 @@ class TestKnowledgeLoader:
assert loader.files == ["file1.ttl", "file2.ttl"]
assert loader.flow == "my-flow"
assert loader.user == "user1"
assert loader.workspace == "user1"
assert loader.collection == "col1"
assert loader.document_id == "doc1"
assert loader.url == "http://example.com/"
@ -126,7 +126,7 @@ ex:mary ex:knows ex:bob .
loader = KnowledgeLoader(
files=[f.name],
flow="test-flow",
user="test-user",
workspace="test-user",
collection="test-collection",
document_id="test-doc",
url="http://test.example.com/"
@ -151,7 +151,7 @@ ex:mary ex:knows ex:bob .
loader = KnowledgeLoader(
files=[temp_turtle_file],
flow="test-flow",
user="test-user",
workspace="test-user",
collection="test-collection",
document_id="test-doc",
url="http://test.example.com/",
@ -163,7 +163,8 @@ ex:mary ex:knows ex:bob .
# Verify Api was created with correct parameters
mock_api_class.assert_called_once_with(
url="http://test.example.com/",
token="test-token"
token="test-token",
workspace="test-user"
)
# Verify bulk client was obtained
@ -174,7 +175,6 @@ ex:mary ex:knows ex:bob .
call_args = mock_bulk.import_triples.call_args
assert call_args[1]['flow'] == "test-flow"
assert call_args[1]['metadata']['id'] == "test-doc"
assert call_args[1]['metadata']['user'] == "test-user"
assert call_args[1]['metadata']['collection'] == "test-collection"
# Verify import_entity_contexts was called
@ -198,7 +198,7 @@ class TestCLIArgumentParsing:
'tg-load-knowledge',
'-i', 'doc-123',
'-f', 'my-flow',
'-U', 'my-user',
'-w', 'my-user',
'-C', 'my-collection',
'-u', 'http://custom.example.com/',
'-t', 'my-token',
@ -216,7 +216,7 @@ class TestCLIArgumentParsing:
token='my-token',
flow='my-flow',
files=['file1.ttl', 'file2.ttl'],
user='my-user',
workspace='my-user',
collection='my-collection'
)
@ -242,7 +242,7 @@ class TestCLIArgumentParsing:
# Verify defaults were used
call_args = mock_loader_class.call_args[1]
assert call_args['flow'] == 'default'
assert call_args['user'] == 'trustgraph'
assert call_args['workspace'] == 'default'
assert call_args['collection'] == 'default'
assert call_args['url'] == 'http://localhost:8088/'
assert call_args['token'] is None
@ -287,7 +287,7 @@ class TestErrorHandling:
loader = KnowledgeLoader(
files=[temp_turtle_file],
flow="test-flow",
user="test-user",
workspace="test-user",
collection="test-collection",
document_id="test-doc",
url="http://test.example.com/"

View file

@ -145,7 +145,8 @@ class TestSetToolStructuredQuery:
group=None,
state=None,
applicable_states=None,
token=None
token=None,
workspace='default'
)
def test_set_main_structured_query_no_arguments_needed(self):
@ -326,7 +327,8 @@ class TestSetToolRowEmbeddingsQuery:
group=None,
state=None,
applicable_states=None,
token=None
token=None,
workspace='default'
)
def test_valid_types_includes_row_embeddings_query(self):
@ -471,7 +473,7 @@ class TestShowToolsStructuredQuery:
show_main()
mock_show.assert_called_once_with(url='http://custom.com', token=None)
mock_show.assert_called_once_with(url='http://custom.com', token=None, workspace='default')
class TestShowToolsRowEmbeddingsQuery:

View file

@ -73,7 +73,6 @@ class TestSyncDocumentEmbeddingsClient:
# Act
result = client.request(
vector=vector,
user="test_user",
collection="test_collection",
limit=10,
timeout=300
@ -82,7 +81,6 @@ class TestSyncDocumentEmbeddingsClient:
# Assert
assert result == ["chunk1", "chunk2", "chunk3"]
client.call.assert_called_once_with(
user="test_user",
collection="test_collection",
vector=vector,
limit=10,
@ -108,7 +106,6 @@ class TestSyncDocumentEmbeddingsClient:
# Assert
assert result == ["test_chunk"]
client.call.assert_called_once_with(
user="trustgraph",
collection="default",
vector=vector,
limit=10,

View file

@ -31,7 +31,6 @@ def _make_query(
query = Query(
rag=rag,
user="test-user",
collection="test-collection",
verbose=False,
entity_limit=entity_limit,
@ -208,7 +207,6 @@ class TestBatchTripleQueries:
assert calls[0].kwargs["p"] is None
assert calls[0].kwargs["o"] is None
assert calls[0].kwargs["limit"] == 15
assert calls[0].kwargs["user"] == "test-user"
assert calls[0].kwargs["collection"] == "test-collection"
assert calls[0].kwargs["batch_size"] == 20

View file

@ -28,10 +28,12 @@ def mock_flow_config():
"""Mock flow configuration."""
mock_config = Mock()
mock_config.flows = {
"test-flow": {
"interfaces": {
"triples-store": "test-triples-queue",
"graph-embeddings-store": "test-ge-queue"
"test-user": {
"test-flow": {
"interfaces": {
"triples-store": {"flow": "test-triples-queue"},
"graph-embeddings-store": {"flow": "test-ge-queue"}
}
}
}
}
@ -43,7 +45,7 @@ def mock_flow_config():
def mock_request():
"""Mock knowledge load request."""
request = Mock()
request.user = "test-user"
request.workspace = "test-user"
request.id = "test-doc-id"
request.collection = "test-collection"
request.flow = "test-flow"
@ -71,7 +73,6 @@ def sample_triples():
return Triples(
metadata=Metadata(
id="test-doc-id",
user="test-user",
collection="default", # This should be overridden
),
triples=[
@ -90,7 +91,6 @@ def sample_graph_embeddings():
return GraphEmbeddings(
metadata=Metadata(
id="test-doc-id",
user="test-user",
collection="default", # This should be overridden
),
entities=[
@ -146,7 +146,6 @@ class TestKnowledgeManagerLoadCore:
mock_triples_pub.send.assert_called_once()
sent_triples = mock_triples_pub.send.call_args[0][1]
assert sent_triples.metadata.collection == "test-collection"
assert sent_triples.metadata.user == "test-user"
assert sent_triples.metadata.id == "test-doc-id"
@pytest.mark.asyncio
@ -185,7 +184,6 @@ class TestKnowledgeManagerLoadCore:
mock_ge_pub.send.assert_called_once()
sent_ge = mock_ge_pub.send.call_args[0][1]
assert sent_ge.metadata.collection == "test-collection"
assert sent_ge.metadata.user == "test-user"
assert sent_ge.metadata.id == "test-doc-id"
@pytest.mark.asyncio
@ -193,7 +191,7 @@ class TestKnowledgeManagerLoadCore:
"""Test that load_kg_core falls back to 'default' when request.collection is None."""
# Create request with None collection
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id"
mock_request.collection = None # Should fall back to "default"
mock_request.flow = "test-flow"
@ -269,7 +267,7 @@ class TestKnowledgeManagerLoadCore:
"""Test that load_kg_core validates flow configuration before processing."""
# Request with invalid flow
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id"
mock_request.collection = "test-collection"
mock_request.flow = "invalid-flow" # Not in mock_flow_config.flows
@ -297,7 +295,7 @@ class TestKnowledgeManagerLoadCore:
# Test missing ID
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_request.id = None # Missing
mock_request.collection = "test-collection"
mock_request.flow = "test-flow"
@ -323,7 +321,7 @@ class TestKnowledgeManagerOtherMethods:
async def test_get_kg_core_preserves_collection_from_store(self, knowledge_manager, sample_triples):
"""Test that get_kg_core preserves collection field from stored data."""
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id"
mock_respond = AsyncMock()
@ -354,7 +352,7 @@ class TestKnowledgeManagerOtherMethods:
async def test_list_kg_cores(self, knowledge_manager):
"""Test listing knowledge cores."""
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_respond = AsyncMock()
@ -376,7 +374,7 @@ class TestKnowledgeManagerOtherMethods:
async def test_delete_kg_core(self, knowledge_manager):
"""Test deleting knowledge cores."""
mock_request = Mock()
mock_request.user = "test-user"
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id"
mock_respond = AsyncMock()

View file

@ -237,7 +237,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
# Mock message with inline data
content = b"# Document Title\nBody text content."
mock_metadata = Metadata(id="test-doc", user="testuser",
mock_metadata = Metadata(id="test-doc",
collection="default")
mock_document = Document(
metadata=mock_metadata,
@ -294,7 +294,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
# Mock message
content = b"fake pdf"
mock_metadata = Metadata(id="test-doc", user="testuser",
mock_metadata = Metadata(id="test-doc",
collection="default")
mock_document = Document(
metadata=mock_metadata,
@ -345,7 +345,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
]
content = b"fake pdf"
mock_metadata = Metadata(id="test-doc", user="testuser",
mock_metadata = Metadata(id="test-doc",
collection="default")
mock_document = Document(
metadata=mock_metadata,

View file

@ -12,7 +12,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_basic(self):
"""Test basic collection name creation"""
result = make_safe_collection_name(
user="test_user",
workspace="test_user",
collection="test_collection",
prefix="doc"
)
@ -21,7 +21,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_with_special_characters(self):
"""Test collection name creation with special characters that need sanitization"""
result = make_safe_collection_name(
user="user@domain.com",
workspace="user@domain.com",
collection="test-collection.v2",
prefix="entity"
)
@ -30,7 +30,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_with_unicode(self):
"""Test collection name creation with Unicode characters"""
result = make_safe_collection_name(
user="测试用户",
workspace="测试用户",
collection="colección_española",
prefix="doc"
)
@ -39,7 +39,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_with_spaces(self):
"""Test collection name creation with spaces"""
result = make_safe_collection_name(
user="test user",
workspace="test user",
collection="my test collection",
prefix="entity"
)
@ -48,7 +48,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_with_multiple_consecutive_special_chars(self):
"""Test collection name creation with multiple consecutive special characters"""
result = make_safe_collection_name(
user="user@@@domain!!!",
workspace="user@@@domain!!!",
collection="test---collection...v2",
prefix="doc"
)
@ -57,7 +57,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_with_leading_trailing_underscores(self):
"""Test collection name creation with leading/trailing special characters"""
result = make_safe_collection_name(
user="__test_user__",
workspace="__test_user__",
collection="@@test_collection##",
prefix="entity"
)
@ -66,7 +66,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_empty_user(self):
"""Test collection name creation with empty user (should fallback to 'default')"""
result = make_safe_collection_name(
user="",
workspace="",
collection="test_collection",
prefix="doc"
)
@ -75,7 +75,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_empty_collection(self):
"""Test collection name creation with empty collection (should fallback to 'default')"""
result = make_safe_collection_name(
user="test_user",
workspace="test_user",
collection="",
prefix="doc"
)
@ -84,7 +84,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_both_empty(self):
"""Test collection name creation with both user and collection empty"""
result = make_safe_collection_name(
user="",
workspace="",
collection="",
prefix="doc"
)
@ -93,7 +93,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_only_special_characters(self):
"""Test collection name creation with only special characters (should fallback to 'default')"""
result = make_safe_collection_name(
user="@@@!!!",
workspace="@@@!!!",
collection="---###",
prefix="entity"
)
@ -102,7 +102,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_whitespace_only(self):
"""Test collection name creation with whitespace-only strings"""
result = make_safe_collection_name(
user=" \n\t ",
workspace=" \n\t ",
collection=" \r\n ",
prefix="doc"
)
@ -111,7 +111,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_mixed_valid_invalid_chars(self):
"""Test collection name creation with mixed valid and invalid characters"""
result = make_safe_collection_name(
user="user123@test",
workspace="user123@test",
collection="coll_2023.v1",
prefix="entity"
)
@ -147,7 +147,7 @@ class TestMilvusCollectionNaming:
long_collection = "b" * 100
result = make_safe_collection_name(
user=long_user,
workspace=long_user,
collection=long_collection,
prefix="doc"
)
@ -159,7 +159,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_numeric_values(self):
"""Test collection name creation with numeric user/collection values"""
result = make_safe_collection_name(
user="user123",
workspace="user123",
collection="collection456",
prefix="doc"
)
@ -168,7 +168,7 @@ class TestMilvusCollectionNaming:
def test_make_safe_collection_name_case_sensitivity(self):
"""Test that collection name creation preserves case"""
result = make_safe_collection_name(
user="TestUser",
workspace="TestUser",
collection="TestCollection",
prefix="Doc"
)

View file

@ -20,9 +20,8 @@ def processor():
)
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1",
user="test", collection="default"):
metadata = Metadata(id=doc_id, user=user, collection=collection)
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", collection="default"):
metadata = Metadata(id=doc_id, collection=collection)
value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id)
msg = MagicMock()
msg.value.return_value = value
@ -127,7 +126,7 @@ class TestDocumentEmbeddingsProcessor:
@pytest.mark.asyncio
async def test_metadata_preserved(self, processor):
"""Output should carry the original metadata."""
msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1")
msg = _make_chunk_message(collection="reports", doc_id="d1")
mock_request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[[0.0]]
@ -144,7 +143,6 @@ class TestDocumentEmbeddingsProcessor:
await processor.on_message(msg, MagicMock(), flow)
result = mock_output.send.call_args[0][0]
assert result.metadata.user == "alice"
assert result.metadata.collection == "reports"
assert result.metadata.id == "d1"

View file

@ -27,8 +27,8 @@ def _make_entity_context(name, context, chunk_id="chunk-1"):
return MagicMock(entity=entity, context=context, chunk_id=chunk_id)
def _make_message(entities, doc_id="doc-1", user="test", collection="default"):
metadata = Metadata(id=doc_id, user=user, collection=collection)
def _make_message(entities, doc_id="doc-1", collection="default"):
metadata = Metadata(id=doc_id, collection=collection)
value = EntityContexts(metadata=metadata, entities=entities)
msg = MagicMock()
msg.value.return_value = value
@ -151,7 +151,7 @@ class TestGraphEmbeddingsBatchProcessing:
_make_entity_context(f"E{i}", f"ctx {i}")
for i in range(5)
]
msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main")
msg = _make_message(entities, doc_id="doc-42", collection="main")
mock_embed = AsyncMock(return_value=[[0.0]] * 5)
mock_output = AsyncMock()
@ -168,7 +168,6 @@ class TestGraphEmbeddingsBatchProcessing:
for call in mock_output.send.call_args_list:
result = call[0][0]
assert result.metadata.id == "doc-42"
assert result.metadata.user == "alice"
assert result.metadata.collection == "main"
@pytest.mark.asyncio

View file

@ -214,11 +214,11 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
}
}
await processor.on_schema_config(config_data, 1)
await processor.on_schema_config("default", config_data, 1)
assert 'customers' in processor.schemas
assert processor.schemas['customers'].name == 'customers'
assert len(processor.schemas['customers'].fields) == 3
assert 'customers' in processor.schemas["default"]
assert processor.schemas["default"]['customers'].name == 'customers'
assert len(processor.schemas["default"]['customers'].fields) == 3
async def test_on_schema_config_handles_missing_type(self):
"""Test that missing schema type is handled gracefully"""
@ -236,9 +236,9 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
'other_type': {}
}
await processor.on_schema_config(config_data, 1)
await processor.on_schema_config("default", config_data, 1)
assert processor.schemas == {}
assert processor.schemas.get("default", {}) == {}
async def test_on_message_drops_unknown_collection(self):
"""Test that messages for unknown collections are dropped"""
@ -285,7 +285,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
processor.known_collections[('default', 'test_collection')] = {}
# No schemas registered
metadata = MagicMock()
@ -322,17 +322,19 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
processor.known_collections[('default', 'test_collection')] = {}
# Set up schema
processor.schemas['customers'] = RowSchema(
name='customers',
description='Customer records',
fields=[
Field(name='id', type='text', primary=True),
Field(name='name', type='text', indexed=True),
]
)
processor.schemas["default"] = {
'customers': RowSchema(
name='customers',
description='Customer records',
fields=[
Field(name='id', type='text', primary=True),
Field(name='name', type='text', indexed=True),
]
)
}
metadata = MagicMock()
metadata.user = 'test_user'
@ -372,6 +374,7 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
return MagicMock()
mock_flow = MagicMock(side_effect=flow_factory)
mock_flow.workspace = "default"
await processor.on_message(mock_msg, MagicMock(), mock_flow)

View file

@ -0,0 +1,200 @@
"""
Unit tests for extract_with_simplified_format.
Regression guard for the bug where the extractor read
``result.object`` (singular, used for response_type="json") instead of
``result.objects`` (plural, used for response_type="jsonl"). The
extract-with-ontologies prompt is JSONL, so reading the wrong field
silently dropped every extraction and left the knowledge graph
populated only by ontology schema + document provenance.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.ontology.extract import Processor
from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
from trustgraph.base import PromptResult
@pytest.fixture
def extractor():
"""Create a Processor instance without running its heavy __init__.
Matches the pattern used in test_prompt_and_extraction.py: only
the attributes the code under test touches need to be set.
"""
ex = object.__new__(Processor)
ex.URI_PREFIXES = {
"rdf:": "http://www.w3.org/1999/02/22-rdf-syntax-ns#",
"rdfs:": "http://www.w3.org/2000/01/rdf-schema#",
"owl:": "http://www.w3.org/2002/07/owl#",
"xsd:": "http://www.w3.org/2001/XMLSchema#",
}
return ex
@pytest.fixture
def food_subset():
"""A minimal food ontology subset the extracted entities reference."""
return OntologySubset(
ontology_id="food",
classes={
"Recipe": {
"uri": "http://purl.org/ontology/fo/Recipe",
"type": "owl:Class",
"labels": [{"value": "Recipe", "lang": "en-gb"}],
"comment": "A Recipe.",
},
"Food": {
"uri": "http://purl.org/ontology/fo/Food",
"type": "owl:Class",
"labels": [{"value": "Food", "lang": "en-gb"}],
"comment": "A Food.",
},
},
object_properties={
"ingredients": {
"uri": "http://purl.org/ontology/fo/ingredients",
"type": "owl:ObjectProperty",
"labels": [{"value": "ingredients", "lang": "en-gb"}],
"comment": "Relates a recipe to its ingredients.",
"domain": "Recipe",
"range": "Food",
},
},
datatype_properties={},
metadata={
"name": "Food Ontology",
"namespace": "http://purl.org/ontology/fo/",
},
)
def _flow_with_prompt_result(prompt_result):
"""Build the ``flow(name)`` callable the extractor invokes.
``extract_with_simplified_format`` calls
``flow("prompt-request").prompt(...)`` so we need ``flow`` to be
callable, return an object whose ``.prompt`` is an AsyncMock that
resolves to ``prompt_result``.
"""
prompt_service = MagicMock()
prompt_service.prompt = AsyncMock(return_value=prompt_result)
def flow(name):
assert name == "prompt-request", (
f"extractor should only invoke flow('prompt-request'), "
f"got {name!r}"
)
return prompt_service
return flow, prompt_service.prompt
class TestReadsObjectsForJsonlPrompt:
"""extract-with-ontologies is a JSONL prompt; the extractor must
read ``result.objects``, not ``result.object``."""
async def test_populated_objects_produces_triples(
self, extractor, food_subset,
):
"""Happy path: PromptResult with populated .objects -> non-empty
triples list."""
prompt_result = PromptResult(
response_type="jsonl",
objects=[
{"type": "entity", "entity": "Cornish Pasty",
"entity_type": "Recipe"},
{"type": "entity", "entity": "beef",
"entity_type": "Food"},
{"type": "relationship",
"subject": "Cornish Pasty", "subject_type": "Recipe",
"relation": "ingredients",
"object": "beef", "object_type": "Food"},
],
)
flow, prompt_mock = _flow_with_prompt_result(prompt_result)
triples = await extractor.extract_with_simplified_format(
flow, "some chunk", food_subset, {"text": "some chunk"},
)
prompt_mock.assert_awaited_once()
assert triples, (
"extract_with_simplified_format returned no triples; if "
"this fails, the extractor is probably reading .object "
"instead of .objects again"
)
async def test_none_objects_returns_empty_without_crashing(
self, extractor, food_subset,
):
"""The exact shape that hit production on v2.3: the extractor
was reading ``.object`` for a JSONL prompt, which returned
``None`` and tripped the parser's 'Unexpected response type'
path. With the fix we read ``.objects``; if that's also
``None`` we must still return ``[]`` cleanly, not crash."""
prompt_result = PromptResult(
response_type="jsonl",
objects=None,
)
flow, _ = _flow_with_prompt_result(prompt_result)
triples = await extractor.extract_with_simplified_format(
flow, "chunk", food_subset, {"text": "chunk"},
)
assert triples == []
async def test_empty_objects_returns_empty(
self, extractor, food_subset,
):
"""Valid JSONL response with zero entries should yield zero
triples, not raise."""
prompt_result = PromptResult(
response_type="jsonl",
objects=[],
)
flow, _ = _flow_with_prompt_result(prompt_result)
triples = await extractor.extract_with_simplified_format(
flow, "chunk", food_subset, {"text": "chunk"},
)
assert triples == []
async def test_ignores_object_field_for_jsonl_prompt(
self, extractor, food_subset,
):
"""If ``.object`` is somehow set but ``.objects`` is None, the
extractor must not silently fall back to ``.object``. This
guards against a well-meaning regression that "helpfully"
re-adds fallback fields.
The extractor should read only ``.objects`` for this prompt;
when that is None we expect the empty-result path.
"""
prompt_result = PromptResult(
response_type="json",
object={"not": "the field we should be reading"},
objects=None,
)
flow, _ = _flow_with_prompt_result(prompt_result)
triples = await extractor.extract_with_simplified_format(
flow, "chunk", food_subset, {"text": "chunk"},
)
assert triples == [], (
"Extractor fell back to .object for a JSONL prompt — "
"this is the regression shape we are trying to prevent"
)

View file

@ -231,6 +231,52 @@ class TestTripleValidation:
is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset)
assert is_valid == expected, f"Validation of {predicate} should be {expected}"
def test_validates_domain_correctly_with_entity_types(self, extractor, sample_ontology_subset):
"""Test domain validation correctly compares against extracted entity_types."""
subject = "my-recipe"
predicate = "produces"
object_val = "my-food"
# Proper domain for produces is Recipe
entity_types = {
"my-recipe": "Recipe",
"my-food": "Food"
}
is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types)
assert is_valid, "Valid domain should be accepted"
# Invalid domain
entity_types_invalid = {
"my-recipe": "Ingredient",
"my-food": "Food"
}
is_invalid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types_invalid)
assert not is_invalid, "Invalid domain should be rejected"
def test_validates_range_correctly_with_entity_types(self, extractor, sample_ontology_subset):
"""Test range validation correctly compares against extracted entity_types."""
subject = "my-recipe"
predicate = "produces"
object_val = "my-food"
# Proper range for produces is Food
entity_types = {
"my-recipe": "Recipe",
"my-food": "Food"
}
is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types)
assert is_valid, "Valid range should be accepted"
# Invalid range
entity_types_invalid = {
"my-recipe": "Recipe",
"my-food": "Recipe"
}
is_invalid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types_invalid)
assert not is_invalid, "Invalid range should be rejected"
class TestTripleParsing:
"""Test suite for parsing triples from LLM responses."""

View file

@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.definitions.extract import (
Processor, default_triples_batch_size, default_entity_batch_size,
)
from trustgraph.base import PromptResult
from trustgraph.schema import (
Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
)
@ -33,11 +34,10 @@ def _make_defn(entity, definition):
return {"entity": entity, "definition": definition}
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
user="user-1", collection="col-1", document_id=""):
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
chunk = Chunk(
metadata=Metadata(
id=meta_id, root=root, user=user, collection=collection,
id=meta_id, root=root, collection=collection,
),
chunk=text.encode("utf-8"),
document_id=document_id,
@ -51,8 +51,12 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock()
mock_ecs_pub = AsyncMock()
mock_prompt_client = AsyncMock()
if isinstance(prompt_result, list):
wrapped = PromptResult(response_type="jsonl", objects=prompt_result)
else:
wrapped = PromptResult(response_type="text", text=prompt_result)
mock_prompt_client.extract_definitions = AsyncMock(
return_value=prompt_result
return_value=wrapped
)
def flow(name):
@ -224,8 +228,7 @@ class TestMetadataPreservation:
defs = [_make_defn("X", "def X")]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg(
"text", meta_id="c-1", root="r-1",
user="u-1", collection="coll-1",
"text", meta_id="c-1", root="r-1", collection="coll-1",
)
await proc.on_message(msg, MagicMock(), flow)
@ -233,7 +236,6 @@ class TestMetadataPreservation:
for triples_msg in _sent_triples(triples_pub):
assert triples_msg.metadata.id == "c-1"
assert triples_msg.metadata.root == "r-1"
assert triples_msg.metadata.user == "u-1"
assert triples_msg.metadata.collection == "coll-1"
@pytest.mark.asyncio
@ -242,8 +244,7 @@ class TestMetadataPreservation:
defs = [_make_defn("X", "def X")]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg(
"text", meta_id="c-2", root="r-2",
user="u-2", collection="coll-2",
"text", meta_id="c-2", root="r-2", collection="coll-2",
)
await proc.on_message(msg, MagicMock(), flow)

View file

@ -14,6 +14,7 @@ from trustgraph.extract.kg.relationships.extract import (
from trustgraph.schema import (
Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
)
from trustgraph.base import PromptResult
# ---------------------------------------------------------------------------
@ -37,12 +38,11 @@ def _make_rel(subject, predicate, obj, object_entity=True):
}
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
user="user-1", collection="col-1", document_id=""):
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
"""Build a mock message wrapping a Chunk."""
chunk = Chunk(
metadata=Metadata(
id=meta_id, root=root, user=user, collection=collection,
id=meta_id, root=root, collection=collection,
),
chunk=text.encode("utf-8"),
document_id=document_id,
@ -58,7 +58,10 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock()
mock_prompt_client = AsyncMock()
mock_prompt_client.extract_relationships = AsyncMock(
return_value=prompt_result
return_value=PromptResult(
response_type="jsonl",
objects=prompt_result,
)
)
def flow(name):
@ -185,8 +188,7 @@ class TestMetadataPreservation:
rels = [_make_rel("X", "rel", "Y")]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg(
"text", meta_id="c-1", root="r-1",
user="u-1", collection="coll-1",
"text", meta_id="c-1", root="r-1", collection="coll-1",
)
await proc.on_message(msg, MagicMock(), flow)
@ -194,7 +196,6 @@ class TestMetadataPreservation:
for triples_msg in _sent_triples(pub):
assert triples_msg.metadata.id == "c-1"
assert triples_msg.metadata.root == "r-1"
assert triples_msg.metadata.user == "u-1"
assert triples_msg.metadata.collection == "coll-1"

View file

@ -17,6 +17,12 @@ _real_config_loader = ConfigReceiver.config_loader
ConfigReceiver.config_loader = Mock()
def _notify(version, changes):
msg = Mock()
msg.value.return_value = Mock(version=version, changes=changes)
return msg
class TestConfigReceiver:
"""Test cases for ConfigReceiver class"""
@ -47,98 +53,70 @@ class TestConfigReceiver:
assert handler2 in config_receiver.flow_handlers
@pytest.mark.asyncio
async def test_on_config_notify_new_version(self):
"""Test on_config_notify triggers fetch for newer version"""
async def test_on_config_notify_new_version_fetches_per_workspace(self):
"""Notify with newer version fetches each affected workspace."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 1
# Mock fetch_and_apply
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with newer version
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=2, types=["flow"])
async def mock_fetch(workspace, retry=False):
fetch_calls.append(workspace)
await config_receiver.on_config_notify(mock_msg, None, None)
config_receiver.fetch_and_apply_workspace = mock_fetch
assert len(fetch_calls) == 1
msg = _notify(2, {"flow": ["ws1", "ws2"]})
await config_receiver.on_config_notify(msg, None, None)
assert set(fetch_calls) == {"ws1", "ws2"}
assert config_receiver.config_version == 2
@pytest.mark.asyncio
async def test_on_config_notify_old_version_ignored(self):
"""Test on_config_notify ignores older versions"""
"""Older-version notifies are ignored."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 5
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with older version
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=3, types=["flow"])
async def mock_fetch(workspace, retry=False):
fetch_calls.append(workspace)
await config_receiver.on_config_notify(mock_msg, None, None)
config_receiver.fetch_and_apply_workspace = mock_fetch
assert len(fetch_calls) == 0
msg = _notify(3, {"flow": ["ws1"]})
await config_receiver.on_config_notify(msg, None, None)
assert fetch_calls == []
@pytest.mark.asyncio
async def test_on_config_notify_irrelevant_types_ignored(self):
"""Test on_config_notify ignores types the gateway doesn't care about"""
"""Notifies without flow changes advance version but skip fetch."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 1
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with non-flow type
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=2, types=["prompt"])
async def mock_fetch(workspace, retry=False):
fetch_calls.append(workspace)
await config_receiver.on_config_notify(mock_msg, None, None)
config_receiver.fetch_and_apply_workspace = mock_fetch
# Version should be updated but no fetch
assert len(fetch_calls) == 0
msg = _notify(2, {"prompt": ["ws1"]})
await config_receiver.on_config_notify(msg, None, None)
assert fetch_calls == []
assert config_receiver.config_version == 2
@pytest.mark.asyncio
async def test_on_config_notify_flow_type_triggers_fetch(self):
"""Test on_config_notify fetches for flow-related types"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
config_receiver.config_version = 1
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
for type_name in ["flow", "active-flow"]:
fetch_calls.clear()
config_receiver.config_version = 1
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=2, types=[type_name])
await config_receiver.on_config_notify(mock_msg, None, None)
assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}"
@pytest.mark.asyncio
async def test_on_config_notify_exception_handling(self):
"""Test on_config_notify handles exceptions gracefully"""
"""on_config_notify swallows exceptions from message decode."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Create notify message that causes an exception
mock_msg = Mock()
mock_msg.value.side_effect = Exception("Test exception")
@ -146,19 +124,18 @@ class TestConfigReceiver:
await config_receiver.on_config_notify(mock_msg, None, None)
@pytest.mark.asyncio
async def test_fetch_and_apply_with_new_flows(self):
"""Test fetch_and_apply starts new flows"""
async def test_fetch_and_apply_workspace_starts_new_flows(self):
"""fetch_and_apply_workspace starts newly-configured flows."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Mock _create_config_client to return a mock client
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow1": '{"name": "test_flow_1"}',
"flow2": '{"name": "test_flow_2"}'
"flow2": '{"name": "test_flow_2"}',
}
}
@ -167,36 +144,39 @@ class TestConfigReceiver:
config_receiver._create_config_client = Mock(return_value=mock_client)
start_flow_calls = []
async def mock_start_flow(id, flow):
start_flow_calls.append((id, flow))
async def mock_start_flow(workspace, id, flow):
start_flow_calls.append((workspace, id, flow))
config_receiver.start_flow = mock_start_flow
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert config_receiver.config_version == 5
assert "flow1" in config_receiver.flows
assert "flow2" in config_receiver.flows
assert "flow1" in config_receiver.flows["default"]
assert "flow2" in config_receiver.flows["default"]
assert len(start_flow_calls) == 2
assert all(c[0] == "default" for c in start_flow_calls)
@pytest.mark.asyncio
async def test_fetch_and_apply_with_removed_flows(self):
"""Test fetch_and_apply stops removed flows"""
async def test_fetch_and_apply_workspace_stops_removed_flows(self):
"""fetch_and_apply_workspace stops flows no longer configured."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Pre-populate with existing flows
config_receiver.flows = {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"}
"default": {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"},
}
}
# Config now only has flow1
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow1": '{"name": "test_flow_1"}'
"flow1": '{"name": "test_flow_1"}',
}
}
@ -205,20 +185,22 @@ class TestConfigReceiver:
config_receiver._create_config_client = Mock(return_value=mock_client)
stop_flow_calls = []
async def mock_stop_flow(id, flow):
stop_flow_calls.append((id, flow))
async def mock_stop_flow(workspace, id, flow):
stop_flow_calls.append((workspace, id, flow))
config_receiver.stop_flow = mock_stop_flow
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert "flow1" in config_receiver.flows
assert "flow2" not in config_receiver.flows
assert "flow1" in config_receiver.flows["default"]
assert "flow2" not in config_receiver.flows["default"]
assert len(stop_flow_calls) == 1
assert stop_flow_calls[0][0] == "flow2"
assert stop_flow_calls[0][:2] == ("default", "flow2")
@pytest.mark.asyncio
async def test_fetch_and_apply_with_no_flows(self):
"""Test fetch_and_apply with empty config"""
async def test_fetch_and_apply_workspace_with_no_flows(self):
"""Empty workspace config clears any local flow state."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
@ -231,88 +213,100 @@ class TestConfigReceiver:
mock_client.request.return_value = mock_resp
config_receiver._create_config_client = Mock(return_value=mock_client)
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert config_receiver.flows == {}
assert config_receiver.flows.get("default", {}) == {}
assert config_receiver.config_version == 1
@pytest.mark.asyncio
async def test_start_flow_with_handlers(self):
"""Test start_flow method with multiple handlers"""
"""start_flow fans out to every registered flow handler."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler1 = Mock()
handler1.start_flow = Mock()
handler1.start_flow = AsyncMock()
handler2 = Mock()
handler2.start_flow = Mock()
handler2.start_flow = AsyncMock()
config_receiver.add_handler(handler1)
config_receiver.add_handler(handler2)
flow_data = {"name": "test_flow", "steps": []}
await config_receiver.start_flow("flow1", flow_data)
await config_receiver.start_flow("default", "flow1", flow_data)
handler1.start_flow.assert_called_once_with("flow1", flow_data)
handler2.start_flow.assert_called_once_with("flow1", flow_data)
handler1.start_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
handler2.start_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@pytest.mark.asyncio
async def test_start_flow_with_handler_exception(self):
"""Test start_flow method handles handler exceptions"""
"""Handler exceptions in start_flow do not propagate."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler = Mock()
handler.start_flow = Mock(side_effect=Exception("Handler error"))
handler.start_flow = AsyncMock(side_effect=Exception("Handler error"))
config_receiver.add_handler(handler)
flow_data = {"name": "test_flow", "steps": []}
# Should not raise
await config_receiver.start_flow("flow1", flow_data)
await config_receiver.start_flow("default", "flow1", flow_data)
handler.start_flow.assert_called_once_with("flow1", flow_data)
handler.start_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@pytest.mark.asyncio
async def test_stop_flow_with_handlers(self):
"""Test stop_flow method with multiple handlers"""
"""stop_flow fans out to every registered flow handler."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler1 = Mock()
handler1.stop_flow = Mock()
handler1.stop_flow = AsyncMock()
handler2 = Mock()
handler2.stop_flow = Mock()
handler2.stop_flow = AsyncMock()
config_receiver.add_handler(handler1)
config_receiver.add_handler(handler2)
flow_data = {"name": "test_flow", "steps": []}
await config_receiver.stop_flow("flow1", flow_data)
await config_receiver.stop_flow("default", "flow1", flow_data)
handler1.stop_flow.assert_called_once_with("flow1", flow_data)
handler2.stop_flow.assert_called_once_with("flow1", flow_data)
handler1.stop_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
handler2.stop_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@pytest.mark.asyncio
async def test_stop_flow_with_handler_exception(self):
"""Test stop_flow method handles handler exceptions"""
"""Handler exceptions in stop_flow do not propagate."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler = Mock()
handler.stop_flow = Mock(side_effect=Exception("Handler error"))
handler.stop_flow = AsyncMock(side_effect=Exception("Handler error"))
config_receiver.add_handler(handler)
flow_data = {"name": "test_flow", "steps": []}
# Should not raise
await config_receiver.stop_flow("flow1", flow_data)
await config_receiver.stop_flow("default", "flow1", flow_data)
handler.stop_flow.assert_called_once_with("flow1", flow_data)
handler.stop_flow.assert_awaited_once_with(
"default", "flow1", flow_data
)
@patch('asyncio.create_task')
@pytest.mark.asyncio
@ -329,25 +323,25 @@ class TestConfigReceiver:
mock_create_task.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_and_apply_mixed_flow_operations(self):
"""Test fetch_and_apply with mixed add/remove operations"""
async def test_fetch_and_apply_workspace_mixed_flow_operations(self):
"""fetch_and_apply_workspace adds, keeps and removes flows in one pass."""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Pre-populate
config_receiver.flows = {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"}
"default": {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"},
}
}
# Config removes flow1, keeps flow2, adds flow3
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow2": '{"name": "test_flow_2"}',
"flow3": '{"name": "test_flow_3"}'
"flow3": '{"name": "test_flow_3"}',
}
}
@ -358,20 +352,22 @@ class TestConfigReceiver:
start_calls = []
stop_calls = []
async def mock_start_flow(id, flow):
start_calls.append((id, flow))
async def mock_stop_flow(id, flow):
stop_calls.append((id, flow))
async def mock_start_flow(workspace, id, flow):
start_calls.append((workspace, id, flow))
async def mock_stop_flow(workspace, id, flow):
stop_calls.append((workspace, id, flow))
config_receiver.start_flow = mock_start_flow
config_receiver.stop_flow = mock_stop_flow
await config_receiver.fetch_and_apply()
await config_receiver.fetch_and_apply_workspace("default")
assert "flow1" not in config_receiver.flows
assert "flow2" in config_receiver.flows
assert "flow3" in config_receiver.flows
ws_flows = config_receiver.flows["default"]
assert "flow1" not in ws_flows
assert "flow2" in ws_flows
assert "flow3" in ws_flows
assert len(start_calls) == 1
assert start_calls[0][0] == "flow3"
assert start_calls[0][:2] == ("default", "flow3")
assert len(stop_calls) == 1
assert stop_calls[0][0] == "flow1"
assert stop_calls[0][:2] == ("default", "flow1")

View file

@ -0,0 +1,406 @@
"""
Round-trip unit tests for the core msgpack import/export gateway endpoints.
The kg-core export endpoint receives KnowledgeResponse-shaped dicts from
the responder callback and packs them into msgpack tuples. The kg-core
import endpoint takes msgpack tuples back off the wire and rebuilds
KnowledgeRequest-shaped dicts which it then hands to KnowledgeRequestor
(whose translator decodes them into real dataclasses).
Regression coverage: the previous wire format used `"vectors"` (plural)
in the entity blobs and embedded a stale `"m"` field that referenced the
removed `Metadata.metadata` triples-list field. The export side hit a
KeyError on first message; the import side built dicts that the
KnowledgeRequestTranslator (separately fixed) couldn't decode. These
tests pin both halves of the wire protocol.
"""
import msgpack
import pytest
from unittest.mock import AsyncMock, Mock, patch
from trustgraph.gateway.dispatch.core_export import CoreExport
from trustgraph.gateway.dispatch.core_import import CoreImport
# ---------------------------------------------------------------------------
# Helpers — sample translator-shaped dicts (as KnowledgeResponseTranslator
# would emit). The vector wire key is *singular* on purpose; the export
# side previously read the wrong key and crashed.
# ---------------------------------------------------------------------------
def _ge_response_dict():
return {
"graph-embeddings": {
"metadata": {
"id": "doc-1",
"root": "",
"collection": "testcoll",
},
"entities": [
{
"entity": {"t": "i", "i": "http://example.org/alice"},
"vector": [0.1, 0.2, 0.3],
},
{
"entity": {"t": "i", "i": "http://example.org/bob"},
"vector": [0.4, 0.5, 0.6],
},
],
}
}
def _triples_response_dict():
return {
"triples": {
"metadata": {
"id": "doc-1",
"root": "",
"collection": "testcoll",
},
"triples": [
{
"s": {"t": "i", "i": "http://example.org/alice"},
"p": {"t": "i", "i": "http://example.org/knows"},
"o": {"t": "i", "i": "http://example.org/bob"},
},
],
}
}
def _make_request(id_="doc-1", workspace="alice"):
request = Mock()
request.query = {"id": id_, "workspace": workspace}
return request
def _make_data_reader(payload: bytes):
"""Mock the aiohttp StreamReader: returns payload once, then EOF."""
chunks = [payload, b""]
data = Mock()
async def fake_read(n):
return chunks.pop(0) if chunks else b""
data.read = fake_read
return data
# ---------------------------------------------------------------------------
# Export side: translator-shaped dict -> msgpack bytes
# ---------------------------------------------------------------------------
class TestCoreExportWireFormat:
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
async def test_export_packs_graph_embeddings_with_singular_vector(
self, mock_kr_class,
):
"""The export side must read `ent["vector"]` and emit `v`. The
previous bug was reading `ent["vectors"]` which KeyErrored against
the translator output."""
captured = []
async def fake_kr_process(req_dict, responder):
await responder(_ge_response_dict(), True)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
response = AsyncMock()
async def fake_write(b):
captured.append(b)
response.write = fake_write
response.write_eof = AsyncMock()
ok = AsyncMock(return_value=response)
error = AsyncMock()
exporter = CoreExport(backend=Mock())
await exporter.process(
data=Mock(),
error=error,
ok=ok,
request=_make_request(),
)
# Did not raise, did not call error()
error.assert_not_called()
assert len(captured) == 1
unpacker = msgpack.Unpacker()
unpacker.feed(captured[0])
items = list(unpacker)
assert len(items) == 1
msg_type, payload = items[0]
assert msg_type == "ge"
# Metadata envelope: only id/collection — no stale `m["m"]`.
assert payload["m"] == {"i": "doc-1", "c": "testcoll"}
# Entities: each carries the *singular* `v` and the term envelope
assert len(payload["e"]) == 2
assert payload["e"][0]["v"] == [0.1, 0.2, 0.3]
assert payload["e"][1]["v"] == [0.4, 0.5, 0.6]
assert payload["e"][0]["e"]["i"] == "http://example.org/alice"
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
async def test_export_packs_triples(self, mock_kr_class):
captured = []
async def fake_kr_process(req_dict, responder):
await responder(_triples_response_dict(), True)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
response = AsyncMock()
async def fake_write(b):
captured.append(b)
response.write = fake_write
response.write_eof = AsyncMock()
ok = AsyncMock(return_value=response)
error = AsyncMock()
exporter = CoreExport(backend=Mock())
await exporter.process(
data=Mock(), error=error, ok=ok, request=_make_request(),
)
error.assert_not_called()
assert len(captured) == 1
unpacker = msgpack.Unpacker()
unpacker.feed(captured[0])
items = list(unpacker)
assert len(items) == 1
msg_type, payload = items[0]
assert msg_type == "t"
assert payload["m"] == {"i": "doc-1", "c": "testcoll"}
assert len(payload["t"]) == 1
# ---------------------------------------------------------------------------
# Import side: msgpack bytes -> translator-shaped dict
# ---------------------------------------------------------------------------
class TestCoreImportWireFormat:
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
async def test_import_unpacks_graph_embeddings_to_singular_vector(
self, mock_kr_class,
):
"""The import side must build dicts whose entity blobs have the
singular `vector` key that's what the KnowledgeRequestTranslator
decode side reads. Previous bug emitted `vectors`."""
captured = []
async def fake_kr_process(req_dict):
captured.append(req_dict)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
# Build a msgpack tuple matching the new wire format
payload = msgpack.packb((
"ge",
{
"m": {"i": "doc-1", "c": "testcoll"},
"e": [
{
"e": {"t": "i", "i": "http://example.org/alice"},
"v": [0.1, 0.2, 0.3],
},
],
},
))
ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
error = AsyncMock()
importer = CoreImport(backend=Mock())
await importer.process(
data=_make_data_reader(payload),
error=error,
ok=ok,
request=_make_request(),
)
error.assert_not_called()
assert len(captured) == 1
req = captured[0]
assert req["operation"] == "put-kg-core"
assert req["workspace"] == "alice"
assert req["id"] == "doc-1"
ge = req["graph-embeddings"]
# Metadata envelope must NOT contain a stale `metadata` key
# referencing the removed Metadata.metadata field.
assert "metadata" not in ge["metadata"]
assert ge["metadata"] == {
"id": "doc-1",
"collection": "default",
}
# Entity blob carries the singular `vector` key
assert len(ge["entities"]) == 1
ent = ge["entities"][0]
assert ent["vector"] == [0.1, 0.2, 0.3]
assert "vectors" not in ent
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
async def test_import_unpacks_triples(self, mock_kr_class):
captured = []
async def fake_kr_process(req_dict):
captured.append(req_dict)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
payload = msgpack.packb((
"t",
{
"m": {"i": "doc-1", "c": "testcoll"},
"t": [
{
"s": {"t": "i", "i": "http://example.org/alice"},
"p": {"t": "i", "i": "http://example.org/knows"},
"o": {"t": "i", "i": "http://example.org/bob"},
},
],
},
))
ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
error = AsyncMock()
importer = CoreImport(backend=Mock())
await importer.process(
data=_make_data_reader(payload),
error=error,
ok=ok,
request=_make_request(),
)
error.assert_not_called()
assert len(captured) == 1
req = captured[0]
triples = req["triples"]
assert "metadata" not in triples["metadata"] # no stale field
assert len(triples["triples"]) == 1
# ---------------------------------------------------------------------------
# Full round-trip: export bytes feed directly into import
# ---------------------------------------------------------------------------
class TestCoreImportExportRoundTrip:
"""End-to-end: produce bytes via core_export, consume them via
core_import, and verify the dict that lands at the import-side
translator is structurally equivalent to what went in. This is the
test that catches asymmetries between the two halves."""
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
async def test_graph_embeddings_round_trip(
self, mock_export_kr_class, mock_import_kr_class,
):
# ----- export side: capture bytes -----
export_bytes = []
async def fake_export_process(req_dict, responder):
await responder(_ge_response_dict(), True)
export_kr = AsyncMock()
export_kr.start = AsyncMock()
export_kr.stop = AsyncMock()
export_kr.process = fake_export_process
mock_export_kr_class.return_value = export_kr
response = AsyncMock()
async def fake_write(b):
export_bytes.append(b)
response.write = fake_write
response.write_eof = AsyncMock()
exporter = CoreExport(backend=Mock())
await exporter.process(
data=Mock(),
error=AsyncMock(),
ok=AsyncMock(return_value=response),
request=_make_request(),
)
assert len(export_bytes) == 1
# ----- import side: feed those bytes back in -----
import_captured = []
async def fake_import_process(req_dict):
import_captured.append(req_dict)
import_kr = AsyncMock()
import_kr.start = AsyncMock()
import_kr.stop = AsyncMock()
import_kr.process = fake_import_process
mock_import_kr_class.return_value = import_kr
importer = CoreImport(backend=Mock())
await importer.process(
data=_make_data_reader(export_bytes[0]),
error=AsyncMock(),
ok=AsyncMock(return_value=AsyncMock(write_eof=AsyncMock())),
request=_make_request(),
)
# ----- verify the dict the importer would hand to the translator -----
assert len(import_captured) == 1
req = import_captured[0]
original = _ge_response_dict()["graph-embeddings"]
ge = req["graph-embeddings"]
# The import side overrides id from the URL query (intentional),
# so we only round-trip the entity payload itself.
assert ge["metadata"]["id"] == original["metadata"]["id"]
assert len(ge["entities"]) == len(original["entities"])
for got, want in zip(ge["entities"], original["entities"]):
assert got["vector"] == want["vector"]
assert got["entity"] == want["entity"]

View file

@ -72,10 +72,10 @@ class TestDispatcherManager:
flow_data = {"name": "test_flow", "steps": []}
await manager.start_flow("flow1", flow_data)
assert "flow1" in manager.flows
assert manager.flows["flow1"] == flow_data
await manager.start_flow("default", "flow1", flow_data)
assert ("default", "flow1") in manager.flows
assert manager.flows[("default", "flow1")] == flow_data
@pytest.mark.asyncio
async def test_stop_flow(self):
@ -86,11 +86,11 @@ class TestDispatcherManager:
# Pre-populate with a flow
flow_data = {"name": "test_flow", "steps": []}
manager.flows["flow1"] = flow_data
await manager.stop_flow("flow1", flow_data)
assert "flow1" not in manager.flows
manager.flows[("default", "flow1")] = flow_data
await manager.stop_flow("default", "flow1", flow_data)
assert ("default", "flow1") not in manager.flows
def test_dispatch_global_service_returns_wrapper(self):
"""Test dispatch_global_service returns DispatcherWrapper"""
@ -275,12 +275,12 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"triples-store": {"queue": "test_queue"}
"triples-store": {"flow": "test_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers, \
patch('uuid.uuid4') as mock_uuid:
mock_uuid.return_value = "test-uuid"
@ -290,7 +290,7 @@ class TestDispatcherManager:
mock_dispatcher_class.return_value = mock_dispatcher
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
mock_dispatchers.__contains__.return_value = True
params = {"flow": "test_flow", "kind": "triples"}
result = await manager.process_flow_import("ws", "running", params)
@ -298,7 +298,7 @@ class TestDispatcherManager:
backend=mock_backend,
ws="ws",
running="running",
queue={"queue": "test_queue"}
queue="test_queue"
)
mock_dispatcher.start.assert_called_once()
assert result == mock_dispatcher
@ -326,12 +326,12 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"triples-store": {"queue": "test_queue"}
"triples-store": {"flow": "test_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers:
mock_dispatchers.__contains__.return_value = False
@ -348,12 +348,12 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"triples-store": {"queue": "test_queue"}
"triples-store": {"flow": "test_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.export_dispatchers') as mock_dispatchers, \
patch('uuid.uuid4') as mock_uuid:
mock_uuid.return_value = "test-uuid"
@ -370,7 +370,7 @@ class TestDispatcherManager:
backend=mock_backend,
ws="ws",
running="running",
queue={"queue": "test_queue"},
queue="test_queue",
consumer="api-gateway-test-uuid",
subscriber="api-gateway-test-uuid"
)
@ -404,7 +404,7 @@ class TestDispatcherManager:
params = {"flow": "test_flow", "kind": "agent"}
result = await manager.process_flow_service("data", "responder", params)
manager.invoke_flow_service.assert_called_once_with("data", "responder", "test_flow", "agent")
manager.invoke_flow_service.assert_called_once_with("data", "responder", "default", "test_flow", "agent")
assert result == "flow_result"
@pytest.mark.asyncio
@ -415,14 +415,14 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Add flow to the flows dictionary
manager.flows["test_flow"] = {"services": {"agent": {}}}
manager.flows[("default", "test_flow")] = {"services": {"agent": {}}}
# Pre-populate with existing dispatcher
mock_dispatcher = Mock()
mock_dispatcher.process = AsyncMock(return_value="cached_result")
manager.dispatchers[("test_flow", "agent")] = mock_dispatcher
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
manager.dispatchers[("default", "test_flow", "agent")] = mock_dispatcher
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
mock_dispatcher.process.assert_called_once_with("data", "responder")
assert result == "cached_result"
@ -435,7 +435,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"agent": {
"request": "agent_request_queue",
@ -443,7 +443,7 @@ class TestDispatcherManager:
}
}
}
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_dispatchers:
mock_dispatcher_class = Mock()
mock_dispatcher = Mock()
@ -452,23 +452,23 @@ class TestDispatcherManager:
mock_dispatcher_class.return_value = mock_dispatcher
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
mock_dispatchers.__contains__.return_value = True
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
# Verify dispatcher was created with correct parameters
mock_dispatcher_class.assert_called_once_with(
backend=mock_backend,
request_queue="agent_request_queue",
response_queue="agent_response_queue",
timeout=120,
consumer="api-gateway-test_flow-agent-request",
subscriber="api-gateway-test_flow-agent-request"
consumer="api-gateway-default-test_flow-agent-request",
subscriber="api-gateway-default-test_flow-agent-request"
)
mock_dispatcher.start.assert_called_once()
mock_dispatcher.process.assert_called_once_with("data", "responder")
# Verify dispatcher was cached
assert manager.dispatchers[("test_flow", "agent")] == mock_dispatcher
assert manager.dispatchers[("default", "test_flow", "agent")] == mock_dispatcher
assert result == "new_result"
@pytest.mark.asyncio
@ -479,36 +479,36 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"text-load": {"queue": "text_load_queue"}
"text-load": {"flow": "text_load_queue"}
}
}
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \
patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers:
mock_rr_dispatchers.__contains__.return_value = False
mock_sender_dispatchers.__contains__.return_value = True
mock_dispatcher_class = Mock()
mock_dispatcher = Mock()
mock_dispatcher.start = AsyncMock()
mock_dispatcher.process = AsyncMock(return_value="sender_result")
mock_dispatcher_class.return_value = mock_dispatcher
mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class
result = await manager.invoke_flow_service("data", "responder", "test_flow", "text-load")
result = await manager.invoke_flow_service("data", "responder", "default", "test_flow", "text-load")
# Verify dispatcher was created with correct parameters
mock_dispatcher_class.assert_called_once_with(
backend=mock_backend,
queue={"queue": "text_load_queue"}
queue="text_load_queue"
)
mock_dispatcher.start.assert_called_once()
mock_dispatcher.process.assert_called_once_with("data", "responder")
# Verify dispatcher was cached
assert manager.dispatchers[("test_flow", "text-load")] == mock_dispatcher
assert manager.dispatchers[("default", "test_flow", "text-load")] == mock_dispatcher
assert result == "sender_result"
@pytest.mark.asyncio
@ -519,7 +519,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
with pytest.raises(RuntimeError, match="Invalid flow"):
await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent")
await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent")
@pytest.mark.asyncio
async def test_invoke_flow_service_unsupported_kind_by_flow(self):
@ -529,14 +529,14 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow without agent interface
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"text-completion": {"request": "req", "response": "resp"}
}
}
with pytest.raises(RuntimeError, match="This kind not supported by flow"):
await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
@pytest.mark.asyncio
async def test_invoke_flow_service_invalid_kind(self):
@ -546,7 +546,7 @@ class TestDispatcherManager:
manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow with interface but unsupported kind
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"invalid-kind": {"request": "req", "response": "resp"}
}
@ -558,7 +558,7 @@ class TestDispatcherManager:
mock_sender_dispatchers.__contains__.return_value = False
with pytest.raises(RuntimeError, match="Invalid kind"):
await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind")
await manager.invoke_flow_service("data", "responder", "default", "test_flow", "invalid-kind")
@pytest.mark.asyncio
async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self):
@ -608,7 +608,7 @@ class TestDispatcherManager:
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver)
manager.flows["test_flow"] = {
manager.flows[("default", "test_flow")] = {
"interfaces": {
"agent": {
"request": "agent_request_queue",
@ -630,7 +630,7 @@ class TestDispatcherManager:
mock_rr_dispatchers.__contains__.return_value = True
results = await asyncio.gather(*[
manager.invoke_flow_service("data", "responder", "test_flow", "agent")
manager.invoke_flow_service("data", "responder", "default", "test_flow", "agent")
for _ in range(5)
])
@ -638,5 +638,5 @@ class TestDispatcherManager:
"Dispatcher class instantiated more than once — duplicate consumer bug"
)
assert mock_dispatcher.start.call_count == 1
assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher
assert manager.dispatchers[("default", "test_flow", "agent")] is mock_dispatcher
assert all(r == "result" for r in results)

View file

@ -0,0 +1,75 @@
"""Tests for Gateway i18n pack endpoint."""
import json
from unittest.mock import MagicMock
import pytest
from aiohttp import web
from trustgraph.gateway.endpoint.i18n import I18nPackEndpoint
class TestI18nPackEndpoint:
def test_i18n_endpoint_initialization(self):
mock_auth = MagicMock()
endpoint = I18nPackEndpoint(
endpoint_path="/api/v1/i18n/packs/{lang}",
auth=mock_auth,
)
assert endpoint.path == "/api/v1/i18n/packs/{lang}"
assert endpoint.auth == mock_auth
assert endpoint.operation == "service"
@pytest.mark.asyncio
async def test_i18n_endpoint_start_method(self):
mock_auth = MagicMock()
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
await endpoint.start()
def test_add_routes_registers_get_handler(self):
mock_auth = MagicMock()
mock_app = MagicMock()
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
endpoint.add_routes(mock_app)
mock_app.add_routes.assert_called_once()
call_args = mock_app.add_routes.call_args[0][0]
assert len(call_args) == 1
@pytest.mark.asyncio
async def test_handle_unauthorized_on_invalid_auth_scheme(self):
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
request = MagicMock()
request.path = "/api/v1/i18n/packs/en"
request.headers = {"Authorization": "Token abc"}
request.match_info = {"lang": "en"}
resp = await endpoint.handle(request)
assert isinstance(resp, web.HTTPUnauthorized)
@pytest.mark.asyncio
async def test_handle_returns_pack_when_permitted(self):
mock_auth = MagicMock()
mock_auth.permitted.return_value = True
endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth)
request = MagicMock()
request.path = "/api/v1/i18n/packs/en"
request.headers = {}
request.match_info = {"lang": "en"}
resp = await endpoint.handle(request)
assert resp.status == 200
payload = json.loads(resp.body.decode("utf-8"))
assert isinstance(payload, dict)
assert "cli.verify_system_status.title" in payload

View file

@ -0,0 +1,241 @@
"""
Unit tests for entity contexts import dispatcher.
Tests the business logic of EntityContextsImport while mocking the
Publisher and websocket components.
Regression coverage: a previous version constructed Metadata(metadata=...)
which raised TypeError at runtime as soon as a message was received. These
tests exercise receive() end-to-end so any future schema/kwarg drift in
the Metadata or EntityContexts construction is caught immediately.
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from trustgraph.gateway.dispatch.entity_contexts_import import EntityContextsImport
from trustgraph.schema import EntityContexts, EntityContext, Metadata
@pytest.fixture
def mock_backend():
return Mock()
@pytest.fixture
def mock_running():
running = Mock()
running.get.return_value = True
running.stop = Mock()
return running
@pytest.fixture
def mock_websocket():
ws = Mock()
ws.close = AsyncMock()
return ws
@pytest.fixture
def sample_message():
"""Sample entity-contexts websocket message."""
return {
"metadata": {
"id": "doc-123",
"user": "testuser",
"collection": "testcollection",
},
"entities": [
{
"entity": {"v": "http://example.org/alice", "e": True},
"context": "Alice is a person.",
},
{
"entity": {"v": "http://example.org/bob", "e": True},
"context": "Bob is a person.",
},
],
}
@pytest.fixture
def empty_entities_message():
return {
"metadata": {
"id": "doc-empty",
"user": "u",
"collection": "c",
},
"entities": [],
}
class TestEntityContextsImportInitialization:
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
def test_init_creates_publisher_with_correct_params(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="ec-queue",
)
mock_publisher_class.assert_called_once_with(
mock_backend,
topic="ec-queue",
schema=EntityContexts,
)
assert dispatcher.ws is mock_websocket
assert dispatcher.running is mock_running
assert dispatcher.publisher is instance
class TestEntityContextsImportLifecycle:
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_start_calls_publisher_start(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.start = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.start()
instance.start.assert_called_once()
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_stops_and_closes_properly(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
mock_websocket.close.assert_called_once()
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_handles_none_websocket(
self, mock_publisher_class, mock_backend, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=None, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
class TestEntityContextsImportMessageProcessing:
"""Regression coverage for receive(): catches Metadata/schema drift."""
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_receive_constructs_entity_contexts_correctly(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
# If Metadata or EntityContexts gain/lose kwargs, this raises
# TypeError — exactly the regression we want to catch.
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
call_args = instance.send.call_args
assert call_args[0][0] is None
sent = call_args[0][1]
assert isinstance(sent, EntityContexts)
assert isinstance(sent.metadata, Metadata)
assert sent.metadata.id == "doc-123"
assert sent.metadata.collection == "testcollection"
assert len(sent.entities) == 2
assert all(isinstance(e, EntityContext) for e in sent.entities)
assert sent.entities[0].context == "Alice is a person."
assert sent.entities[1].context == "Bob is a person."
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_empty_entities(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, empty_entities_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = empty_entities_message
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
sent = instance.send.call_args[0][1]
assert isinstance(sent, EntityContexts)
assert sent.entities == []
assert sent.metadata.id == "doc-empty"
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_receive_propagates_publisher_errors(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
with pytest.raises(RuntimeError, match="publish failed"):
await dispatcher.receive(mock_msg)

View file

@ -158,7 +158,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="explain",
message_type="explain",
content="",
explain_id="urn:trustgraph:agent:session:abc123",
explain_graph="urn:graph:retrieval",
@ -179,7 +179,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="thought",
message_type="thought",
content="I need to think...",
)
@ -190,7 +190,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="explain",
message_type="explain",
explain_id="urn:trustgraph:agent:session:abc123",
explain_triples=sample_triples(),
end_of_dialog=False,
@ -203,7 +203,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="answer",
message_type="answer",
content="The answer is...",
end_of_dialog=True,
)

View file

@ -0,0 +1,246 @@
"""
Unit tests for graph embeddings import dispatcher.
Tests the business logic of GraphEmbeddingsImport while mocking the
Publisher and websocket components.
Regression coverage: a previous version of EntityContextsImport
constructed Metadata(metadata=...) which raised TypeError at runtime as
soon as a message was received. The same shape of bug can occur here, so
these tests exercise receive() end-to-end to catch any future schema or
kwarg drift in Metadata / GraphEmbeddings / EntityEmbeddings construction.
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from trustgraph.gateway.dispatch.graph_embeddings_import import GraphEmbeddingsImport
from trustgraph.schema import GraphEmbeddings, EntityEmbeddings, Metadata
@pytest.fixture
def mock_backend():
return Mock()
@pytest.fixture
def mock_running():
running = Mock()
running.get.return_value = True
running.stop = Mock()
return running
@pytest.fixture
def mock_websocket():
ws = Mock()
ws.close = AsyncMock()
return ws
@pytest.fixture
def sample_message():
"""Sample graph-embeddings websocket message."""
return {
"metadata": {
"id": "doc-123",
"user": "testuser",
"collection": "testcollection",
},
"entities": [
{
"entity": {"v": "http://example.org/alice", "e": True},
"vector": [0.1, 0.2, 0.3],
},
{
"entity": {"v": "http://example.org/bob", "e": True},
"vector": [0.4, 0.5, 0.6],
},
],
}
@pytest.fixture
def empty_entities_message():
return {
"metadata": {
"id": "doc-empty",
"user": "u",
"collection": "c",
},
"entities": [],
}
class TestGraphEmbeddingsImportInitialization:
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
def test_init_creates_publisher_with_correct_params(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="ge-queue",
)
mock_publisher_class.assert_called_once_with(
mock_backend,
topic="ge-queue",
schema=GraphEmbeddings,
)
assert dispatcher.ws is mock_websocket
assert dispatcher.running is mock_running
assert dispatcher.publisher is instance
class TestGraphEmbeddingsImportLifecycle:
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_start_calls_publisher_start(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.start = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.start()
instance.start.assert_called_once()
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_stops_and_closes_properly(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
mock_websocket.close.assert_called_once()
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_handles_none_websocket(
self, mock_publisher_class, mock_backend, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=None, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
class TestGraphEmbeddingsImportMessageProcessing:
"""Regression coverage for receive(): catches Metadata/schema drift."""
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_receive_constructs_graph_embeddings_correctly(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
# If Metadata, GraphEmbeddings, or EntityEmbeddings gain/lose
# kwargs, this raises TypeError — exactly the regression we want
# to catch.
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
call_args = instance.send.call_args
assert call_args[0][0] is None
sent = call_args[0][1]
assert isinstance(sent, GraphEmbeddings)
assert isinstance(sent.metadata, Metadata)
assert sent.metadata.id == "doc-123"
assert sent.metadata.collection == "testcollection"
assert len(sent.entities) == 2
assert all(isinstance(e, EntityEmbeddings) for e in sent.entities)
# Lock in the wire format: incoming "vector" key (singular,
# list[float]) maps to EntityEmbeddings.vector. This mirrors
# serialize_graph_embeddings() on the export side.
assert sent.entities[0].vector == [0.1, 0.2, 0.3]
assert sent.entities[1].vector == [0.4, 0.5, 0.6]
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_empty_entities(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, empty_entities_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = empty_entities_message
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
sent = instance.send.call_args[0][1]
assert isinstance(sent, GraphEmbeddings)
assert sent.entities == []
assert sent.metadata.id == "doc-empty"
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_receive_propagates_publisher_errors(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
with pytest.raises(RuntimeError, match="publish failed"):
await dispatcher.receive(mock_msg)

View file

@ -235,7 +235,6 @@ class TestRowsImportMessageProcessing:
# Check metadata
assert sent_object.metadata.id == "obj-123"
assert sent_object.metadata.user == "testuser"
assert sent_object.metadata.collection == "testcollection"
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')

View file

@ -171,6 +171,14 @@ class TestApi:
patch('aiohttp.web.run_app') as mock_run_app:
mock_get_pubsub.return_value = Mock()
# Api.run() passes self.app_factory() — a coroutine — to
# web.run_app, which would normally consume it inside its own
# event loop. Since we mock run_app, close the coroutine here
# so it doesn't leak as an "unawaited coroutine" RuntimeWarning.
def _consume_coro(coro, **kwargs):
coro.close()
mock_run_app.side_effect = _consume_coro
api = Api(port=8080)
api.run()

View file

@ -23,7 +23,6 @@ class TestTextDocumentTranslator:
)
assert msg.metadata.id == "doc-1"
assert msg.metadata.user == "alice"
assert msg.metadata.collection == "research"
assert msg.text == payload.encode("utf-8")

View file

@ -29,10 +29,9 @@ class Triple:
self.o = o
class Metadata:
def __init__(self, id, user, collection, root=""):
def __init__(self, id, collection, root=""):
self.id = id
self.root = root
self.user = user
self.collection = collection
class Triples:
@ -108,7 +107,6 @@ def sample_triples(sample_triple):
"""Sample Triples batch object"""
metadata = Metadata(
id="test-doc-123",
user="test_user",
collection="test_collection",
)
@ -123,7 +121,6 @@ def sample_chunk():
"""Sample text chunk for processing"""
metadata = Metadata(
id="test-chunk-456",
user="test_user",
collection="test_collection",
)

View file

@ -322,7 +322,6 @@ This is not JSON at all
assert isinstance(sent_triples, Triples)
# Check metadata fields individually since implementation creates new Metadata object
assert sent_triples.metadata.id == sample_metadata.id
assert sent_triples.metadata.user == sample_metadata.user
assert sent_triples.metadata.collection == sample_metadata.collection
assert len(sent_triples.triples) == 1
assert sent_triples.triples[0].s.iri == "test:subject"
@ -346,7 +345,6 @@ This is not JSON at all
assert isinstance(sent_contexts, EntityContexts)
# Check metadata fields individually since implementation creates new Metadata object
assert sent_contexts.metadata.id == sample_metadata.id
assert sent_contexts.metadata.user == sample_metadata.user
assert sent_contexts.metadata.collection == sample_metadata.collection
assert len(sent_contexts.entities) == 1
assert sent_contexts.entities[0].entity.iri == "test:entity"

View file

@ -311,8 +311,7 @@ class TestObjectExtractionBusinessLogic:
"""Test ExtractedObject creation and properties"""
# Arrange
metadata = Metadata(
id="test-extraction-001",
user="test_user",
id="test-extraction-001",
collection="test_collection",
)
@ -337,7 +336,6 @@ class TestObjectExtractionBusinessLogic:
assert extracted_obj.values[0]["customer_id"] == "CUST001"
assert extracted_obj.confidence == 0.95
assert "John Doe" in extracted_obj.source_span
assert extracted_obj.metadata.user == "test_user"
def test_config_parsing_error_handling(self):
"""Test configuration parsing with invalid JSON"""

View file

@ -371,7 +371,6 @@ class TestTripleConstructionLogic:
metadata = Metadata(
id="test-doc-123",
user="test_user",
collection="test_collection",
)
@ -384,7 +383,6 @@ class TestTripleConstructionLogic:
# Assert
assert isinstance(triples_batch, Triples)
assert triples_batch.metadata.id == "test-doc-123"
assert triples_batch.metadata.user == "test_user"
assert triples_batch.metadata.collection == "test_collection"
assert len(triples_batch.triples) == 2

View file

@ -0,0 +1,119 @@
import asyncio
import io
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from uuid import uuid4
from minio.error import S3Error
from trustgraph.librarian.blob_store import BlobStore
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_blob_store():
"""Create a BlobStore with mocked Minio client."""
mock_minio = MagicMock()
with patch('trustgraph.librarian.blob_store.Minio', return_value=mock_minio):
# Prevent ensure_bucket from making network calls during init
with patch('trustgraph.librarian.blob_store.BlobStore.ensure_bucket'):
store = BlobStore(
endpoint="localhost:9000",
access_key="access",
secret_key="secret",
bucket_name="test-bucket"
)
return store, mock_minio
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_add_success_no_retry():
store, mock_minio = _make_blob_store()
object_id = uuid4()
await store.add(object_id, b"data", "text/plain")
mock_minio.put_object.assert_called_once()
@pytest.mark.asyncio
async def test_retry_recovery_on_transient_failure():
store, mock_minio = _make_blob_store()
store.base_delay = 0 # Disable delay for fast tests
# Fail twice, succeed third time
mock_minio.put_object.side_effect = [
Exception("Error 1"),
Exception("Error 2"),
MagicMock()
]
await store.add(uuid4(), b"data", "text/plain")
assert mock_minio.put_object.call_count == 3
@pytest.mark.asyncio
async def test_retry_exhaustion_after_8_attempts():
store, mock_minio = _make_blob_store()
store.base_delay = 0
# Permanent failure
mock_minio.put_object.side_effect = Exception("Permanent failure")
with pytest.raises(Exception, match="Permanent failure"):
await store.add(uuid4(), b"data", "text/plain")
# Author requirement: exactly 8 attempts
assert mock_minio.put_object.call_count == 8
@pytest.mark.asyncio
async def test_s3_error_triggers_retry():
store, mock_minio = _make_blob_store()
store.base_delay = 0
# Mock S3Error
s3_err = S3Error("code", "msg", "res", "req", "host", None)
mock_minio.get_object.side_effect = [s3_err, MagicMock()]
await store.get(uuid4())
assert mock_minio.get_object.call_count == 2
@pytest.mark.asyncio
async def test_exponential_backoff_delays():
store, mock_minio = _make_blob_store()
# Use real base_delay to check math
store.base_delay = 0.25
# Correct method name is stat_object, not get_size
mock_minio.stat_object = MagicMock(side_effect=Exception("Wait"))
with patch('asyncio.sleep', new_callable=AsyncMock) as mock_sleep:
with pytest.raises(Exception):
await store.get_size(uuid4())
# Should have 7 sleep calls for 8 attempts
assert mock_sleep.call_count == 7
# Check actual sleep durations: 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0
sleep_args = [call[0][0] for call in mock_sleep.call_args_list]
assert sleep_args == [0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]
@pytest.mark.asyncio
async def test_runs_in_executor():
"""Verify that synchronous Minio calls are offloaded to an executor."""
store, mock_minio = _make_blob_store()
# Mock response object with .read() method
mock_response = MagicMock()
mock_response.read.return_value = b"result"
with patch('asyncio.get_event_loop') as mock_loop:
mock_loop_instance = MagicMock()
mock_loop.return_value = mock_loop_instance
mock_loop_instance.run_in_executor = AsyncMock(return_value=mock_response)
await store.get(uuid4())
mock_loop_instance.run_in_executor.assert_called_once()

View file

@ -22,6 +22,10 @@ def _make_librarian(min_chunk_size=1):
"""Create a Librarian with mocked blob_store and table_store."""
lib = Librarian.__new__(Librarian)
lib.blob_store = MagicMock()
lib.blob_store.create_multipart_upload = AsyncMock()
lib.blob_store.upload_part = AsyncMock()
lib.blob_store.complete_multipart_upload = AsyncMock()
lib.blob_store.abort_multipart_upload = AsyncMock()
lib.table_store = AsyncMock()
lib.load_document = AsyncMock()
lib.min_chunk_size = min_chunk_size
@ -29,12 +33,12 @@ def _make_librarian(min_chunk_size=1):
def _make_doc_metadata(
doc_id="doc-1", kind="application/pdf", user="alice", title="Test Doc"
doc_id="doc-1", kind="application/pdf", workspace="alice", title="Test Doc"
):
meta = MagicMock()
meta.id = doc_id
meta.kind = kind
meta.user = user
meta.workspace = workspace
meta.title = title
meta.time = 1700000000
meta.comments = ""
@ -43,27 +47,27 @@ def _make_doc_metadata(
def _make_begin_request(
doc_id="doc-1", kind="application/pdf", user="alice",
doc_id="doc-1", kind="application/pdf", workspace="alice",
total_size=10_000_000, chunk_size=0
):
req = MagicMock()
req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, user=user)
req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, workspace=workspace)
req.total_size = total_size
req.chunk_size = chunk_size
return req
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, user="alice", content=b"data"):
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, workspace="alice", content=b"data"):
req = MagicMock()
req.upload_id = upload_id
req.chunk_index = chunk_index
req.user = user
req.workspace = workspace
req.content = base64.b64encode(content)
return req
def _make_session(
user="alice", total_chunks=5, chunk_size=2_000_000,
workspace="alice", total_chunks=5, chunk_size=2_000_000,
total_size=10_000_000, chunks_received=None, object_id="obj-1",
s3_upload_id="s3-up-1", document_metadata=None, document_id="doc-1",
):
@ -72,11 +76,11 @@ def _make_session(
if document_metadata is None:
document_metadata = json.dumps({
"id": document_id, "kind": "application/pdf",
"user": user, "title": "Test", "time": 1700000000,
"workspace": workspace, "title": "Test", "time": 1700000000,
"comments": "", "tags": [],
})
return {
"user": user,
"workspace": workspace,
"total_chunks": total_chunks,
"chunk_size": chunk_size,
"total_size": total_size,
@ -255,10 +259,10 @@ class TestUploadChunk:
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
session = _make_session(workspace="alice")
lib.table_store.get_upload_session.return_value = session
req = _make_upload_chunk_request(user="bob")
req = _make_upload_chunk_request(workspace="bob")
with pytest.raises(RequestError, match="Not authorized"):
await lib.upload_chunk(req)
@ -349,7 +353,7 @@ class TestCompleteUpload:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
resp = await lib.complete_upload(req)
@ -371,7 +375,7 @@ class TestCompleteUpload:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
await lib.complete_upload(req)
@ -390,7 +394,7 @@ class TestCompleteUpload:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
with pytest.raises(RequestError, match="Missing chunks"):
await lib.complete_upload(req)
@ -402,7 +406,7 @@ class TestCompleteUpload:
req = MagicMock()
req.upload_id = "up-gone"
req.user = "alice"
req.workspace = "alice"
with pytest.raises(RequestError, match="not found"):
await lib.complete_upload(req)
@ -410,12 +414,12 @@ class TestCompleteUpload:
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
session = _make_session(workspace="alice")
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "bob"
req.workspace = "bob"
with pytest.raises(RequestError, match="Not authorized"):
await lib.complete_upload(req)
@ -435,7 +439,7 @@ class TestAbortUpload:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
resp = await lib.abort_upload(req)
@ -452,7 +456,7 @@ class TestAbortUpload:
req = MagicMock()
req.upload_id = "up-gone"
req.user = "alice"
req.workspace = "alice"
with pytest.raises(RequestError, match="not found"):
await lib.abort_upload(req)
@ -460,12 +464,12 @@ class TestAbortUpload:
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
session = _make_session(workspace="alice")
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "bob"
req.workspace = "bob"
with pytest.raises(RequestError, match="Not authorized"):
await lib.abort_upload(req)
@ -488,7 +492,7 @@ class TestGetUploadStatus:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
resp = await lib.get_upload_status(req)
@ -506,7 +510,7 @@ class TestGetUploadStatus:
req = MagicMock()
req.upload_id = "up-expired"
req.user = "alice"
req.workspace = "alice"
resp = await lib.get_upload_status(req)
@ -523,7 +527,7 @@ class TestGetUploadStatus:
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
req.workspace = "alice"
resp = await lib.get_upload_status(req)
@ -535,12 +539,12 @@ class TestGetUploadStatus:
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
session = _make_session(workspace="alice")
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "bob"
req.workspace = "bob"
with pytest.raises(RequestError, match="Not authorized"):
await lib.get_upload_status(req)
@ -560,7 +564,7 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000)
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 2000
@ -583,7 +587,7 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500)
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 2000
@ -604,7 +608,7 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100)
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 2000
@ -626,7 +630,7 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x")
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 0 # Should use default 1MB
@ -645,7 +649,7 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=raw)
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 1000
@ -662,7 +666,7 @@ class TestStreamDocument:
lib.blob_store.get_size = AsyncMock(return_value=5000)
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
req.document_id = "doc-1"
req.chunk_size = 512
@ -694,7 +698,7 @@ class TestListUploads:
]
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
resp = await lib.list_uploads(req)
@ -709,7 +713,7 @@ class TestListUploads:
lib.table_store.list_upload_sessions.return_value = []
req = MagicMock()
req.user = "alice"
req.workspace = "alice"
resp = await lib.list_uploads(req)

View file

@ -0,0 +1,590 @@
"""
DAG structure tests for provenance chains.
Verifies that the wasDerivedFrom chain has the expected shape for each
service. These tests catch structural regressions when new entities are
inserted into the chain (e.g. PatternDecision between session and first
iteration).
Expected chains:
GraphRAG: question grounding exploration focus synthesis
DocumentRAG: question grounding exploration synthesis
Agent React: session pattern-decision iteration (observation iteration)* final
Agent Plan: session pattern-decision plan step-result(s) synthesis
Agent Super: session pattern-decision decomposition (fan-out) finding(s) synthesis
"""
import json
import uuid
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.schema import (
AgentRequest, AgentResponse, AgentStep, PlanStep,
Triple, Term, IRI, LITERAL,
)
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_WAS_DERIVED_FROM, GRAPH_RETRIEVAL,
TG_AGENT_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION, TG_PATTERN_DECISION,
TG_PLAN_TYPE, TG_STEP_RESULT, TG_DECOMPOSITION,
TG_OBSERVATION_TYPE,
TG_PATTERN, TG_TASK_TYPE,
)
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _collect_events(events):
"""Build a dict of explain_id → {types, derived_from, triples}."""
result = {}
for ev in events:
eid = ev["explain_id"]
triples = ev["triples"]
types = {
t.o.iri for t in triples
if t.s.iri == eid and t.p.iri == RDF_TYPE
}
parents = [
t.o.iri for t in triples
if t.s.iri == eid and t.p.iri == PROV_WAS_DERIVED_FROM
]
result[eid] = {
"types": types,
"derived_from": parents[0] if parents else None,
"triples": triples,
}
return result
def _find_by_type(dag, rdf_type):
"""Find all event IDs that have the given rdf:type."""
return [eid for eid, info in dag.items() if rdf_type in info["types"]]
def _assert_chain(dag, chain_types):
"""Assert that a linear wasDerivedFrom chain exists through the given types."""
for i in range(1, len(chain_types)):
parent_type = chain_types[i - 1]
child_type = chain_types[i]
parents = _find_by_type(dag, parent_type)
children = _find_by_type(dag, child_type)
assert parents, f"No entity with type {parent_type}"
assert children, f"No entity with type {child_type}"
# At least one child must derive from at least one parent
linked = False
for child_id in children:
derived = dag[child_id]["derived_from"]
if derived in parents:
linked = True
break
assert linked, (
f"No {child_type} derives from {parent_type}. "
f"Children derive from: "
f"{[dag[c]['derived_from'] for c in children]}"
)
# ---------------------------------------------------------------------------
# GraphRAG DAG structure
# ---------------------------------------------------------------------------
class TestGraphRagDagStructure:
"""Verify: question → grounding → exploration → focus → synthesis"""
@pytest.fixture
def mock_clients(self):
prompt_client = AsyncMock()
embeddings_client = AsyncMock()
graph_embeddings_client = AsyncMock()
triples_client = AsyncMock()
embeddings_client.embed.return_value = [[0.1, 0.2]]
graph_embeddings_client.query.return_value = [
MagicMock(entity=Term(type=IRI, iri="http://example.com/e1")),
]
triples_client.query_stream.return_value = [
Triple(
s=Term(type=IRI, iri="http://example.com/e1"),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="value"),
)
]
triples_client.query.return_value = []
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return PromptResult(response_type="text", text="concept")
elif template_id == "kg-edge-scoring":
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[{"id": e["id"], "score": 10} for e in edges],
)
elif template_id == "kg-edge-reasoning":
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges],
)
elif template_id == "kg-synthesis":
return PromptResult(response_type="text", text="Answer.")
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
@pytest.mark.asyncio
async def test_dag_chain(self, mock_clients):
rag = GraphRag(*mock_clients)
events = []
async def explain_cb(triples, explain_id):
events.append({"explain_id": explain_id, "triples": triples})
await rag.query(
query="test", explain_callback=explain_cb, edge_score_limit=0,
)
dag = _collect_events(events)
assert len(dag) == 5, f"Expected 5 events, got {len(dag)}"
_assert_chain(dag, [
TG_GRAPH_RAG_QUESTION,
TG_GROUNDING,
TG_EXPLORATION,
TG_FOCUS,
TG_SYNTHESIS,
])
# ---------------------------------------------------------------------------
# DocumentRAG DAG structure
# ---------------------------------------------------------------------------
class TestDocumentRagDagStructure:
"""Verify: question → grounding → exploration → synthesis"""
@pytest.fixture
def mock_clients(self):
from trustgraph.schema import ChunkMatch
prompt_client = AsyncMock()
embeddings_client = AsyncMock()
doc_embeddings_client = AsyncMock()
fetch_chunk = AsyncMock(return_value="Chunk content.")
embeddings_client.embed.return_value = [[0.1, 0.2]]
doc_embeddings_client.query.return_value = [
ChunkMatch(chunk_id="doc/c1", score=0.9),
]
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return PromptResult(response_type="text", text="concept")
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
prompt_client.document_prompt.return_value = PromptResult(
response_type="text", text="Answer.",
)
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
@pytest.mark.asyncio
async def test_dag_chain(self, mock_clients):
rag = DocumentRag(*mock_clients)
events = []
async def explain_cb(triples, explain_id):
events.append({"explain_id": explain_id, "triples": triples})
await rag.query(
query="test", explain_callback=explain_cb,
)
dag = _collect_events(events)
assert len(dag) == 4, f"Expected 4 events, got {len(dag)}"
_assert_chain(dag, [
TG_DOC_RAG_QUESTION,
TG_GROUNDING,
TG_EXPLORATION,
TG_SYNTHESIS,
])
# ---------------------------------------------------------------------------
# Agent DAG structure — tested via service.agent_request()
# ---------------------------------------------------------------------------
def _make_processor(tools=None):
processor = MagicMock()
processor.max_iterations = 10
processor.save_answer_content = AsyncMock()
def mock_session_uri(sid):
return f"urn:trustgraph:agent:session:{sid}"
processor.provenance_session_uri.side_effect = mock_session_uri
agent = MagicMock()
agent.tools = tools or {}
agent.additional_context = ""
processor.agents = {"default": agent}
processor.aggregator = MagicMock()
return processor
def _make_flow():
producers = {}
def factory(name):
if name not in producers:
producers[name] = AsyncMock()
return producers[name]
flow = MagicMock(side_effect=factory)
flow.workspace = "default"
return flow
def _collect_agent_events(respond_mock):
events = []
for call in respond_mock.call_args_list:
resp = call[0][0]
if isinstance(resp, AgentResponse) and resp.message_type == "explain":
events.append({
"explain_id": resp.explain_id,
"triples": resp.explain_triples,
})
return events
class TestAgentReactDagStructure:
"""
Via service.agent_request(), full two-iteration react chain:
session pattern-decision iteration(1) observation(1) final
Iteration 1: tool call observation
Iteration 2: final answer
"""
def _make_service(self):
from trustgraph.agent.orchestrator.service import Processor
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
mock_tool = MagicMock()
mock_tool.name = "lookup"
mock_tool.description = "Look things up"
mock_tool.arguments = []
mock_tool.groups = []
mock_tool.states = {}
mock_tool_impl = AsyncMock(return_value="42")
mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
processor = _make_processor(tools={"lookup": mock_tool})
service = Processor.__new__(Processor)
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agents = processor.agents
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
service.plan_pattern = PlanThenExecutePattern(service)
service.supervisor_pattern = SupervisorPattern(service)
service.meta_router = None
return service
@pytest.mark.asyncio
async def test_dag_chain(self):
from trustgraph.agent.react.types import Action, Final
service = self._make_service()
respond = AsyncMock()
next_fn = AsyncMock()
flow = _make_flow()
session_id = str(uuid.uuid4())
# Iteration 1: tool call → returns Action, triggers on_action + tool exec
action = Action(
thought="I need to look this up",
name="lookup",
arguments={"question": "6x7"},
observation="",
)
with patch(
"trustgraph.agent.orchestrator.react_pattern.AgentManager"
) as MockAM:
mock_am = AsyncMock()
MockAM.return_value = mock_am
async def mock_react_iter1(on_action=None, **kwargs):
if on_action:
await on_action(action)
action.observation = "42"
return action
mock_am.react.side_effect = mock_react_iter1
request1 = AgentRequest(
question="What is 6x7?",
collection="default",
streaming=False,
session_id=session_id,
pattern="react",
history=[],
)
await service.agent_request(request1, respond, next_fn, flow)
# next_fn should have been called with updated history
assert next_fn.called
# Iteration 2: final answer
final = Final(thought="The answer is 42", final="42")
next_request = next_fn.call_args[0][0]
with patch(
"trustgraph.agent.orchestrator.react_pattern.AgentManager"
) as MockAM:
mock_am = AsyncMock()
MockAM.return_value = mock_am
async def mock_react_iter2(**kwargs):
return final
mock_am.react.side_effect = mock_react_iter2
await service.agent_request(next_request, respond, next_fn, flow)
# Collect and verify DAG
events = _collect_agent_events(respond)
dag = _collect_events(events)
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
analysis_ids = _find_by_type(dag, TG_ANALYSIS)
observation_ids = _find_by_type(dag, TG_OBSERVATION_TYPE)
final_ids = _find_by_type(dag, TG_CONCLUSION)
assert len(session_ids) == 1, f"Expected 1 session, got {len(session_ids)}"
assert len(pd_ids) == 1, f"Expected 1 pattern-decision, got {len(pd_ids)}"
assert len(analysis_ids) >= 1, f"Expected >=1 analysis, got {len(analysis_ids)}"
assert len(observation_ids) >= 1, f"Expected >=1 observation, got {len(observation_ids)}"
assert len(final_ids) == 1, f"Expected 1 final, got {len(final_ids)}"
# Full chain:
# session → pattern-decision
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
# pattern-decision → iteration(1)
assert dag[analysis_ids[0]]["derived_from"] == pd_ids[0]
# iteration(1) → observation(1)
assert dag[observation_ids[0]]["derived_from"] == analysis_ids[0]
# observation(1) → final
assert dag[final_ids[0]]["derived_from"] == observation_ids[0]
class TestAgentPlanDagStructure:
"""
Via service.agent_request():
session pattern-decision plan step-result synthesis
"""
@pytest.mark.asyncio
async def test_dag_chain(self):
from trustgraph.agent.orchestrator.service import Processor
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
# Mock tool
mock_tool = MagicMock()
mock_tool.name = "knowledge-query"
mock_tool.description = "Query KB"
mock_tool.arguments = []
mock_tool.groups = []
mock_tool.states = {}
mock_tool_impl = AsyncMock(return_value="Found it")
mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
processor = _make_processor(tools={"knowledge-query": mock_tool})
service = Processor.__new__(Processor)
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agents = processor.agents
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
service.plan_pattern = PlanThenExecutePattern(service)
service.supervisor_pattern = SupervisorPattern(service)
service.meta_router = None
respond = AsyncMock()
next_fn = AsyncMock()
flow = _make_flow()
# Mock prompt client
mock_prompt_client = AsyncMock()
call_count = 0
async def mock_prompt(id, variables=None, **kwargs):
nonlocal call_count
call_count += 1
if id == "plan-create":
return PromptResult(
response_type="jsonl",
objects=[{"goal": "Find info", "tool_hint": "knowledge-query", "depends_on": []}],
)
elif id == "plan-step-execute":
return PromptResult(
response_type="json",
object={"tool": "knowledge-query", "arguments": {"question": "test"}},
)
elif id == "plan-synthesise":
return PromptResult(response_type="text", text="Final answer.")
return PromptResult(response_type="text", text="")
mock_prompt_client.prompt.side_effect = mock_prompt
def flow_factory(name):
if name == "prompt-request":
return mock_prompt_client
return AsyncMock()
flow.side_effect = flow_factory
session_id = str(uuid.uuid4())
# Iteration 1: planning
request1 = AgentRequest(
question="Test?",
collection="default",
streaming=False,
session_id=session_id,
pattern="plan-then-execute",
history=[],
)
await service.agent_request(request1, respond, next_fn, flow)
# Iteration 2: execute step (next_fn was called with updated request)
assert next_fn.called
next_request = next_fn.call_args[0][0]
# Iteration 3: all steps done → synthesis
# Simulate completed step in history
next_request.history[-1].plan[0].status = "completed"
next_request.history[-1].plan[0].result = "Found it"
await service.agent_request(next_request, respond, next_fn, flow)
events = _collect_agent_events(respond)
dag = _collect_events(events)
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
plan_ids = _find_by_type(dag, TG_PLAN_TYPE)
synthesis_ids = _find_by_type(dag, TG_SYNTHESIS)
assert len(session_ids) == 1
assert len(pd_ids) == 1
assert len(plan_ids) == 1
assert len(synthesis_ids) == 1
# Chain: session → pattern-decision → plan → ... → synthesis
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
assert dag[plan_ids[0]]["derived_from"] == pd_ids[0]
class TestAgentSupervisorDagStructure:
"""
Via service.agent_request():
session pattern-decision decomposition (fan-out)
"""
@pytest.mark.asyncio
async def test_dag_chain(self):
from trustgraph.agent.orchestrator.service import Processor
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
processor = _make_processor()
service = Processor.__new__(Processor)
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agents = processor.agents
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
service.plan_pattern = PlanThenExecutePattern(service)
service.supervisor_pattern = SupervisorPattern(service)
service.meta_router = None
respond = AsyncMock()
next_fn = AsyncMock()
flow = _make_flow()
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=["Goal A", "Goal B"],
)
def flow_factory(name):
if name == "prompt-request":
return mock_prompt_client
return AsyncMock()
flow.side_effect = flow_factory
request = AgentRequest(
question="Research quantum computing",
collection="default",
streaming=False,
session_id=str(uuid.uuid4()),
pattern="supervisor",
history=[],
)
await service.agent_request(request, respond, next_fn, flow)
events = _collect_agent_events(respond)
dag = _collect_events(events)
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
decomp_ids = _find_by_type(dag, TG_DECOMPOSITION)
assert len(session_ids) == 1
assert len(pd_ids) == 1
assert len(decomp_ids) == 1
# Chain: session → pattern-decision → decomposition
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
assert dag[decomp_ids[0]]["derived_from"] == pd_ids[0]
# Fan-out should have been called
assert next_fn.call_count == 2 # One per goal

View file

@ -223,7 +223,7 @@ class TestDerivedEntityTriples:
assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
def test_chunk_entity_has_chunk_type(self):
def test_chunk_entity_has_message_type(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"chunker", "1.0",

View file

@ -0,0 +1,131 @@
"""
Unit tests for Kafka backend topic parsing and factory dispatch.
Does not require a running Kafka instance.
"""
import pytest
import argparse
from trustgraph.base.kafka_backend import KafkaBackend
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
class TestKafkaParseTopic:
@pytest.fixture
def backend(self):
b = object.__new__(KafkaBackend)
return b
def test_flow_is_durable(self, backend):
name, cls, durable = backend._parse_topic('flow:tg:text-completion-request')
assert durable is True
assert cls == 'flow'
assert name == 'tg.flow.text-completion-request'
def test_notify_is_not_durable(self, backend):
name, cls, durable = backend._parse_topic('notify:tg:config')
assert durable is False
assert cls == 'notify'
assert name == 'tg.notify.config'
def test_request_is_not_durable(self, backend):
name, cls, durable = backend._parse_topic('request:tg:config')
assert durable is False
assert cls == 'request'
assert name == 'tg.request.config'
def test_response_is_not_durable(self, backend):
name, cls, durable = backend._parse_topic('response:tg:librarian')
assert durable is False
assert cls == 'response'
assert name == 'tg.response.librarian'
def test_custom_topicspace(self, backend):
name, cls, durable = backend._parse_topic('flow:prod:my-queue')
assert name == 'prod.flow.my-queue'
assert durable is True
def test_no_colon_defaults_to_flow(self, backend):
name, cls, durable = backend._parse_topic('simple-queue')
assert name == 'tg.flow.simple-queue'
assert cls == 'flow'
assert durable is True
def test_invalid_class_raises(self, backend):
with pytest.raises(ValueError, match="Invalid topic class"):
backend._parse_topic('unknown:tg:topic')
def test_topic_with_flow_suffix(self, backend):
"""Topic names with flow suffix (e.g. :default) have colons replaced with dots."""
name, cls, durable = backend._parse_topic('request:tg:prompt:default')
assert name == 'tg.request.prompt.default'
class TestKafkaRetention:
@pytest.fixture
def backend(self):
b = object.__new__(KafkaBackend)
return b
def test_flow_gets_long_retention(self, backend):
assert backend._retention_ms('flow') == 7 * 24 * 60 * 60 * 1000
def test_request_gets_short_retention(self, backend):
assert backend._retention_ms('request') == 300 * 1000
def test_response_gets_short_retention(self, backend):
assert backend._retention_ms('response') == 300 * 1000
def test_notify_gets_short_retention(self, backend):
assert backend._retention_ms('notify') == 300 * 1000
class TestGetPubsubKafka:
def test_factory_creates_kafka_backend(self):
backend = get_pubsub(pubsub_backend='kafka')
assert isinstance(backend, KafkaBackend)
def test_factory_passes_config(self):
backend = get_pubsub(
pubsub_backend='kafka',
kafka_bootstrap_servers='myhost:9093',
kafka_security_protocol='SASL_SSL',
kafka_sasl_mechanism='PLAIN',
kafka_sasl_username='user',
kafka_sasl_password='pass',
)
assert isinstance(backend, KafkaBackend)
assert backend._bootstrap_servers == 'myhost:9093'
assert backend._admin_config['security.protocol'] == 'SASL_SSL'
assert backend._admin_config['sasl.mechanism'] == 'PLAIN'
assert backend._admin_config['sasl.username'] == 'user'
assert backend._admin_config['sasl.password'] == 'pass'
class TestAddPubsubArgsKafka:
def test_kafka_args_present(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser)
args = parser.parse_args([
'--pubsub-backend', 'kafka',
'--kafka-bootstrap-servers', 'myhost:9093',
])
assert args.pubsub_backend == 'kafka'
assert args.kafka_bootstrap_servers == 'myhost:9093'
def test_kafka_defaults_container(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser)
args = parser.parse_args([])
assert args.kafka_bootstrap_servers == 'kafka:9092'
assert args.kafka_security_protocol == 'PLAINTEXT'
def test_kafka_standalone_defaults_to_localhost(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser, standalone=True)
args = parser.parse_args([])
assert args.kafka_bootstrap_servers == 'localhost:9092'

View file

@ -1,18 +1,16 @@
"""
Unit tests for RabbitMQ backend queue name mapping and factory dispatch.
Unit tests for RabbitMQ backend topic parsing and factory dispatch.
Does not require a running RabbitMQ instance.
"""
import pytest
import argparse
pika = pytest.importorskip("pika", reason="pika not installed")
from trustgraph.base.rabbitmq_backend import RabbitMQBackend
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
class TestRabbitMQMapQueueName:
class TestRabbitMQParseTopic:
@pytest.fixture
def backend(self):
@ -20,43 +18,48 @@ class TestRabbitMQMapQueueName:
return b
def test_flow_is_durable(self, backend):
name, durable = backend.map_queue_name('flow:tg:text-completion-request')
exchange, cls, durable = backend._parse_topic('flow:tg:text-completion-request')
assert durable is True
assert name == 'tg.flow.text-completion-request'
assert cls == 'flow'
assert exchange == 'tg.flow.text-completion-request'
def test_notify_is_not_durable(self, backend):
name, durable = backend.map_queue_name('notify:tg:config')
exchange, cls, durable = backend._parse_topic('notify:tg:config')
assert durable is False
assert name == 'tg.notify.config'
assert cls == 'notify'
assert exchange == 'tg.notify.config'
def test_request_is_not_durable(self, backend):
name, durable = backend.map_queue_name('request:tg:config')
exchange, cls, durable = backend._parse_topic('request:tg:config')
assert durable is False
assert name == 'tg.request.config'
assert cls == 'request'
assert exchange == 'tg.request.config'
def test_response_is_not_durable(self, backend):
name, durable = backend.map_queue_name('response:tg:librarian')
exchange, cls, durable = backend._parse_topic('response:tg:librarian')
assert durable is False
assert name == 'tg.response.librarian'
assert cls == 'response'
assert exchange == 'tg.response.librarian'
def test_custom_topicspace(self, backend):
name, durable = backend.map_queue_name('flow:prod:my-queue')
assert name == 'prod.flow.my-queue'
exchange, cls, durable = backend._parse_topic('flow:prod:my-queue')
assert exchange == 'prod.flow.my-queue'
assert durable is True
def test_no_colon_defaults_to_flow(self, backend):
name, durable = backend.map_queue_name('simple-queue')
assert name == 'tg.simple-queue'
assert durable is False
exchange, cls, durable = backend._parse_topic('simple-queue')
assert exchange == 'tg.flow.simple-queue'
assert cls == 'flow'
assert durable is True
def test_invalid_class_raises(self, backend):
with pytest.raises(ValueError, match="Invalid queue class"):
backend.map_queue_name('unknown:tg:topic')
with pytest.raises(ValueError, match="Invalid topic class"):
backend._parse_topic('unknown:tg:topic')
def test_flow_with_flow_suffix(self, backend):
"""Queue names with flow suffix (e.g. :default) are preserved."""
name, durable = backend.map_queue_name('request:tg:prompt:default')
assert name == 'tg.request.prompt:default'
def test_topic_with_flow_suffix(self, backend):
"""Topic names with flow suffix (e.g. :default) are preserved."""
exchange, cls, durable = backend._parse_topic('request:tg:prompt:default')
assert exchange == 'tg.request.prompt:default'
class TestGetPubsubRabbitMQ:

View file

@ -304,14 +304,14 @@ class TestStreamingTypes:
assert chunk.content == "thinking..."
assert chunk.end_of_message is False
assert chunk.chunk_type == "thought"
assert chunk.message_type == "thought"
def test_agent_observation_creation(self):
"""Test creating AgentObservation chunk"""
chunk = AgentObservation(content="observing...", end_of_message=False)
assert chunk.content == "observing..."
assert chunk.chunk_type == "observation"
assert chunk.message_type == "observation"
def test_agent_answer_creation(self):
"""Test creating AgentAnswer chunk"""
@ -324,7 +324,7 @@ class TestStreamingTypes:
assert chunk.content == "answer"
assert chunk.end_of_message is True
assert chunk.end_of_dialog is True
assert chunk.chunk_type == "final-answer"
assert chunk.message_type == "final-answer"
def test_rag_chunk_creation(self):
"""Test creating RAGChunk"""

View file

@ -31,7 +31,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
def mock_query_request(self):
"""Create a mock query request for testing"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10
@ -69,7 +68,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_single_vector(self, processor):
"""Test querying document embeddings with a single vector"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -83,7 +81,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called with correct parameters including user/collection
processor.vecstore.search.assert_called_once_with(
@ -101,7 +99,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_longer_vector(self, processor):
"""Test querying document embeddings with a longer vector"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=3
@ -115,7 +112,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once_with(
@ -133,7 +130,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_with_limit(self, processor):
"""Test querying document embeddings respects limit parameter"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=2
@ -148,7 +144,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called with the specified limit
processor.vecstore.search.assert_called_once_with(
@ -162,13 +158,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_empty_vectors(self, processor):
"""Test querying document embeddings with empty vectors list"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[],
limit=5
)
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify no search was called
processor.vecstore.search.assert_not_called()
@ -180,7 +175,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_empty_search_results(self, processor):
"""Test querying document embeddings with empty search results"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -189,7 +183,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock empty search results
processor.vecstore.search.return_value = []
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called
processor.vecstore.search.assert_called_once_with(
@ -203,7 +197,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_unicode_documents(self, processor):
"""Test querying document embeddings with Unicode document content"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -217,7 +210,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify Unicode content is preserved in ChunkMatch objects
assert len(result) == 3
@ -230,7 +223,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_large_documents(self, processor):
"""Test querying document embeddings with large document content"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -244,7 +236,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify large content is preserved in ChunkMatch objects
assert len(result) == 2
@ -256,7 +248,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_special_characters(self, processor):
"""Test querying document embeddings with special characters in documents"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -270,7 +261,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify special characters are preserved in ChunkMatch objects
assert len(result) == 3
@ -283,13 +274,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_zero_limit(self, processor):
"""Test querying document embeddings with zero limit"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=0
)
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify no search was called (optimization for zero limit)
processor.vecstore.search.assert_not_called()
@ -301,13 +291,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_negative_limit(self, processor):
"""Test querying document embeddings with negative limit"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=-1
)
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify no search was called (optimization for negative limit)
processor.vecstore.search.assert_not_called()
@ -319,7 +308,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_exception_handling(self, processor):
"""Test exception handling during query processing"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -330,13 +318,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Should raise the exception
with pytest.raises(Exception, match="Milvus connection failed"):
await processor.query_document_embeddings(query)
await processor.query_document_embeddings('test_user', query)
@pytest.mark.asyncio
async def test_query_document_embeddings_different_vector_dimensions(self, processor):
"""Test querying document embeddings with different vector dimensions"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector
limit=5
@ -349,7 +336,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify search was called with the vector
processor.vecstore.search.assert_called_once()
@ -364,7 +351,6 @@ class TestMilvusDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_multiple_results(self, processor):
"""Test querying document embeddings with multiple results"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
@ -378,7 +364,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
result = await processor.query_document_embeddings('test_user', query)
# Verify results are ChunkMatch objects
assert len(result) == 3

View file

@ -117,7 +117,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify index was accessed correctly (with dimension suffix)
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
@ -160,7 +160,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = [mock_results1, mock_results2]
chunks = await processor.query_document_embeddings(mock_query_message)
chunks = await processor.query_document_embeddings('default', mock_query_message)
# Verify both queries were made
assert mock_index.query.call_count == 2
@ -191,7 +191,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify limit is passed to query
mock_index.query.assert_called_once()
@ -213,7 +213,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -231,7 +231,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -259,7 +259,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify different indexes used for different dimensions
assert processor.pinecone.Index.call_count == 2
@ -287,7 +287,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify no queries were made and empty result returned
processor.pinecone.Index.assert_not_called()
@ -310,7 +310,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_results.matches = []
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify empty results
assert chunks == []
@ -334,7 +334,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify Unicode content is properly handled
assert len(chunks) == 2
@ -361,7 +361,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify large content is properly handled
assert len(chunks) == 1
@ -389,7 +389,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify all content types are properly handled
assert len(chunks) == 5
@ -413,7 +413,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = Exception("Query failed")
with pytest.raises(Exception, match="Query failed"):
await processor.query_document_embeddings(message)
await processor.query_document_embeddings('test_user', message)
@pytest.mark.asyncio
async def test_query_document_embeddings_index_access_failure(self, processor):
@ -427,7 +427,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
processor.pinecone.Index.side_effect = Exception("Index access failed")
with pytest.raises(Exception, match="Index access failed"):
await processor.query_document_embeddings(message)
await processor.query_document_embeddings('test_user', message)
@pytest.mark.asyncio
async def test_query_document_embeddings_vector_accumulation(self, processor):
@ -462,7 +462,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
mock_index.query.side_effect = [mock_results1, mock_results2, mock_results3]
chunks = await processor.query_document_embeddings(message)
chunks = await processor.query_document_embeddings('test_user', message)
# Verify all queries were made
assert mock_index.query.call_count == 3

View file

@ -101,7 +101,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'test_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('test_user', mock_message)
# Assert
# Verify query was called with correct parameters (with dimension suffix)
@ -158,7 +158,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('multi_user', mock_message)
# Assert
# Verify query was called once
@ -212,7 +212,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'limit_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('limit_user', mock_message)
# Assert
# Verify query was called with exact limit (no multiplication)
@ -252,7 +252,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'empty_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('empty_user', mock_message)
# Assert
assert result == []
@ -291,7 +291,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('dim_user', mock_message)
# Assert
# Verify query was called once with correct collection
@ -342,7 +342,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'utf8_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('utf8_user', mock_message)
# Assert
assert len(result) == 2
@ -380,7 +380,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
await processor.query_document_embeddings(mock_message)
await processor.query_document_embeddings('error_user', mock_message)
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -413,7 +413,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'zero_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('zero_user', mock_message)
# Assert
# Should still query (with limit 0)
@ -460,7 +460,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'large_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
result = await processor.query_document_embeddings('large_user', mock_message)
# Assert
# Should query with full limit
@ -512,7 +512,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Act & Assert
# This should raise a KeyError when trying to access payload['chunk_id']
with pytest.raises(KeyError):
await processor.query_document_embeddings(mock_message)
await processor.query_document_embeddings('payload_user', mock_message)
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')

View file

@ -31,7 +31,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
def mock_query_request(self):
"""Create a mock query request for testing"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10
@ -117,7 +116,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_single_vector(self, processor):
"""Test querying graph embeddings with a single vector"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -131,7 +129,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called with correct parameters including user/collection
processor.vecstore.search.assert_called_once_with(
@ -154,7 +152,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_multiple_results(self, processor):
"""Test querying graph embeddings returns multiple results"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
@ -168,7 +165,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once_with(
@ -186,7 +183,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_with_limit(self, processor):
"""Test querying graph embeddings respects limit parameter"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=2
@ -201,7 +197,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called with 2*limit for better deduplication
processor.vecstore.search.assert_called_once_with(
@ -215,7 +211,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_preserves_order(self, processor):
"""Test that query results preserve order from the vector store"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
@ -229,7 +224,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify results are in the same order as returned by the store
assert len(result) == 3
@ -241,7 +236,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_results_limited(self, processor):
"""Test that results are properly limited when store returns more than requested"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=2
@ -255,7 +249,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called with the full vector
processor.vecstore.search.assert_called_once_with(
@ -269,13 +263,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_empty_vectors(self, processor):
"""Test querying graph embeddings with empty vectors list"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[],
limit=5
)
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify no search was called
processor.vecstore.search.assert_not_called()
@ -287,7 +280,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_empty_search_results(self, processor):
"""Test querying graph embeddings with empty search results"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -296,7 +288,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Mock empty search results
processor.vecstore.search.return_value = []
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called
processor.vecstore.search.assert_called_once_with(
@ -310,7 +302,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_mixed_uri_literal_results(self, processor):
"""Test querying graph embeddings with mixed URI and literal results"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -325,7 +316,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify all results are properly typed
assert len(result) == 4
@ -348,7 +339,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_exception_handling(self, processor):
"""Test exception handling during query processing"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=5
@ -359,7 +349,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Should raise the exception
with pytest.raises(Exception, match="Milvus connection failed"):
await processor.query_graph_embeddings(query)
await processor.query_graph_embeddings('test_user', query)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
@ -430,13 +420,12 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_zero_limit(self, processor):
"""Test querying graph embeddings with zero limit"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3],
limit=0
)
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify no search was called (optimization for zero limit)
processor.vecstore.search.assert_not_called()
@ -448,7 +437,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_longer_vector(self, processor):
"""Test querying graph embeddings with a longer vector"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
limit=5
@ -461,7 +449,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
result = await processor.query_graph_embeddings('test_user', query)
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once()

View file

@ -145,7 +145,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify index was accessed correctly (with dimension suffix)
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
@ -185,7 +185,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(mock_query_message)
entities = await processor.query_graph_embeddings('default', mock_query_message)
# Verify query was made once
assert mock_index.query.call_count == 1
@ -216,7 +216,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify limit is respected
assert len(entities) == 2
@ -233,7 +233,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -251,7 +251,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify no query was made and empty result returned
mock_index.query.assert_not_called()
@ -276,7 +276,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify correct index used for 2D vector
processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
@ -300,7 +300,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify no queries were made and empty result returned
processor.pinecone.Index.assert_not_called()
@ -323,7 +323,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_results.matches = []
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Verify empty results
assert entities == []
@ -352,7 +352,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Should get exactly 3 unique entities (respecting limit)
assert len(entities) == 3
@ -380,7 +380,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
entities = await processor.query_graph_embeddings('test_user', message)
# Should only return 2 entities (respecting limit)
mock_index.query.assert_called_once()
@ -400,7 +400,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
mock_index.query.side_effect = Exception("Query failed")
with pytest.raises(Exception, match="Query failed"):
await processor.query_graph_embeddings(message)
await processor.query_graph_embeddings('test_user', message)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""

View file

@ -173,7 +173,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'test_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('test_user', mock_message)
# Assert
# Verify query was called with correct parameters (with dimension suffix)
@ -230,7 +230,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('multi_user', mock_message)
# Assert
# Verify query was called once
@ -283,7 +283,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'limit_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('limit_user', mock_message)
# Assert
# Verify query was called with limit * 2
@ -323,7 +323,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'empty_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('empty_user', mock_message)
# Assert
assert result == []
@ -364,7 +364,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('dim_user', mock_message)
# Assert
# Verify query was called once
@ -415,7 +415,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'uri_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('uri_user', mock_message)
# Assert
assert len(result) == 3
@ -460,7 +460,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
await processor.query_graph_embeddings(mock_message)
await processor.query_graph_embeddings('error_user', mock_message)
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -493,7 +493,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'zero_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
result = await processor.query_graph_embeddings('zero_user', mock_message)
# Assert
# Should still query (with limit 0)

View file

@ -9,12 +9,12 @@ from trustgraph.query.triples.memgraph.service import Processor
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
class TestMemgraphQueryUserCollectionIsolation:
class TestMemgraphQueryWorkspaceCollectionIsolation:
"""Test cases for Memgraph query service with user/collection isolation"""
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_spo_query_with_user_collection(self, mock_graph_db):
async def test_spo_query_with_workspace_collection(self, mock_graph_db):
"""Test SPO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -22,7 +22,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
@ -32,13 +31,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SPO query for literal includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT 1000"
)
@ -48,14 +47,14 @@ class TestMemgraphQueryUserCollectionIsolation:
src="http://example.com/s",
rel="http://example.com/p",
value="test_object",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_sp_query_with_user_collection(self, mock_graph_db):
async def test_sp_query_with_workspace_collection(self, mock_graph_db):
"""Test SP query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -63,7 +62,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
@ -73,13 +71,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SP query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT 1000"
)
@ -88,14 +86,14 @@ class TestMemgraphQueryUserCollectionIsolation:
expected_literal_query,
src="http://example.com/s",
rel="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_so_query_with_user_collection(self, mock_graph_db):
async def test_so_query_with_workspace_collection(self, mock_graph_db):
"""Test SO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -103,7 +101,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -113,13 +110,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SO query for nodes includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT 1000"
)
@ -128,14 +125,14 @@ class TestMemgraphQueryUserCollectionIsolation:
expected_query,
src="http://example.com/s",
uri="http://example.com/o",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_s_only_query_with_user_collection(self, mock_graph_db):
async def test_s_only_query_with_workspace_collection(self, mock_graph_db):
"""Test S-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -143,7 +140,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -153,13 +149,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify S query includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT 1000"
)
@ -167,14 +163,14 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
src="http://example.com/s",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_po_query_with_user_collection(self, mock_graph_db):
async def test_po_query_with_workspace_collection(self, mock_graph_db):
"""Test PO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -182,7 +178,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Term(type=IRI, iri="http://example.com/p"),
@ -192,13 +187,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify PO query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT 1000"
)
@ -207,14 +202,14 @@ class TestMemgraphQueryUserCollectionIsolation:
expected_query,
uri="http://example.com/p",
value="literal",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_p_only_query_with_user_collection(self, mock_graph_db):
async def test_p_only_query_with_workspace_collection(self, mock_graph_db):
"""Test P-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -222,7 +217,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Term(type=IRI, iri="http://example.com/p"),
@ -232,13 +226,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify P query includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT 1000"
)
@ -246,14 +240,14 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
uri="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_o_only_query_with_user_collection(self, mock_graph_db):
async def test_o_only_query_with_workspace_collection(self, mock_graph_db):
"""Test O-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -261,7 +255,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
@ -271,13 +264,13 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify O query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT 1000"
)
@ -285,14 +278,14 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
value="test_value",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
async def test_wildcard_query_with_workspace_collection(self, mock_graph_db):
"""Test wildcard query (all None) includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -300,7 +293,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
@ -310,36 +302,36 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify wildcard query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT 1000"
)
mock_driver.execute_query.assert_any_call(
expected_literal_query,
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
# Verify wildcard query for nodes includes user/collection
expected_node_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT 1000"
)
mock_driver.execute_query.assert_any_call(
expected_node_query,
user="test_user",
workspace="test_user",
collection="test_collection",
database_='memgraph'
)
@ -363,7 +355,7 @@ class TestMemgraphQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples('default', query)
# Verify defaults were used
calls = mock_driver.execute_query.call_args_list
@ -383,7 +375,6 @@ class TestMemgraphQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -410,7 +401,7 @@ class TestMemgraphQueryUserCollectionIsolation:
([mock_record2], MagicMock(), MagicMock()) # Node query
]
result = await processor.query_triples(query)
result = await processor.query_triples("test_user", query)
# Verify results are proper Triple objects
assert len(result) == 2

View file

@ -9,12 +9,12 @@ from trustgraph.query.triples.neo4j.service import Processor
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
class TestNeo4jQueryUserCollectionIsolation:
class TestNeo4jQueryWorkspaceCollectionIsolation:
"""Test cases for Neo4j query service with user/collection isolation"""
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_spo_query_with_user_collection(self, mock_graph_db):
async def test_spo_query_with_workspace_collection(self, mock_graph_db):
"""Test SPO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -22,7 +22,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
@ -32,13 +31,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SPO query for literal includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT 10"
)
@ -48,14 +47,14 @@ class TestNeo4jQueryUserCollectionIsolation:
src="http://example.com/s",
rel="http://example.com/p",
value="test_object",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_sp_query_with_user_collection(self, mock_graph_db):
async def test_sp_query_with_workspace_collection(self, mock_graph_db):
"""Test SP query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -63,7 +62,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
@ -73,13 +71,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SP query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT 10"
)
@ -88,16 +86,16 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_literal_query,
src="http://example.com/s",
rel="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
# Verify SP query for nodes includes user/collection
expected_node_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN dest.uri as dest "
"LIMIT 10"
)
@ -106,14 +104,14 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_node_query,
src="http://example.com/s",
rel="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_so_query_with_user_collection(self, mock_graph_db):
async def test_so_query_with_workspace_collection(self, mock_graph_db):
"""Test SO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -121,7 +119,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -131,13 +128,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify SO query for nodes includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT 10"
)
@ -146,14 +143,14 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_query,
src="http://example.com/s",
uri="http://example.com/o",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_s_only_query_with_user_collection(self, mock_graph_db):
async def test_s_only_query_with_workspace_collection(self, mock_graph_db):
"""Test S-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -161,7 +158,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -171,13 +167,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify S query includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT 10"
)
@ -185,14 +181,14 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
src="http://example.com/s",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_po_query_with_user_collection(self, mock_graph_db):
async def test_po_query_with_workspace_collection(self, mock_graph_db):
"""Test PO query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -200,7 +196,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Term(type=IRI, iri="http://example.com/p"),
@ -210,13 +205,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify PO query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT 10"
)
@ -225,14 +220,14 @@ class TestNeo4jQueryUserCollectionIsolation:
expected_query,
uri="http://example.com/p",
value="literal",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_p_only_query_with_user_collection(self, mock_graph_db):
async def test_p_only_query_with_workspace_collection(self, mock_graph_db):
"""Test P-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -240,7 +235,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=Term(type=IRI, iri="http://example.com/p"),
@ -250,13 +244,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify P query includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT 10"
)
@ -264,14 +258,14 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
uri="http://example.com/p",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_o_only_query_with_user_collection(self, mock_graph_db):
async def test_o_only_query_with_workspace_collection(self, mock_graph_db):
"""Test O-only query pattern includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -279,7 +273,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
@ -289,13 +282,13 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify O query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT 10"
)
@ -303,14 +296,14 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.assert_any_call(
expected_query,
value="test_value",
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_wildcard_query_with_user_collection(self, mock_graph_db):
async def test_wildcard_query_with_workspace_collection(self, mock_graph_db):
"""Test wildcard query (all None) includes user/collection filtering"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
@ -318,7 +311,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
@ -328,36 +320,36 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples("test_user", query)
# Verify wildcard query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
expected_literal_query,
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
# Verify wildcard query for nodes includes user/collection
expected_node_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
expected_node_query,
user="test_user",
workspace="test_user",
collection="test_collection",
database_='neo4j'
)
@ -381,7 +373,7 @@ class TestNeo4jQueryUserCollectionIsolation:
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
await processor.query_triples('default', query)
# Verify defaults were used
calls = mock_driver.execute_query.call_args_list
@ -401,7 +393,6 @@ class TestNeo4jQueryUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
@ -428,7 +419,7 @@ class TestNeo4jQueryUserCollectionIsolation:
([mock_record2], MagicMock(), MagicMock()) # Node query
]
result = await processor.query_triples(query)
result = await processor.query_triples("test_user", query)
# Verify results are proper Triple objects
assert len(result) == 2

View file

@ -91,11 +91,10 @@ class TestRowsGraphQLQueryLogic:
"""Test parsing of schema configuration"""
processor = MagicMock()
processor.schemas = {}
processor.schema_builders = {}
processor.graphql_schemas = {}
processor.config_key = "schema"
processor.schema_builder = MagicMock()
processor.schema_builder.clear = MagicMock()
processor.schema_builder.add_schema = MagicMock()
processor.schema_builder.build = MagicMock(return_value=MagicMock())
processor.query_cassandra = MagicMock()
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test config
@ -129,11 +128,11 @@ class TestRowsGraphQLQueryLogic:
}
# Process config
await processor.on_schema_config(schema_config, version=1)
await processor.on_schema_config("default", schema_config, version=1)
# Verify schema was loaded
assert "customer" in processor.schemas
schema = processor.schemas["customer"]
assert "customer" in processor.schemas["default"]
schema = processor.schemas["default"]["customer"]
assert schema.name == "customer"
assert len(schema.fields) == 3
@ -147,39 +146,40 @@ class TestRowsGraphQLQueryLogic:
status_field = next(f for f in schema.fields if f.name == "status")
assert status_field.enum_values == ["active", "inactive"]
# Verify schema builder was called
processor.schema_builder.add_schema.assert_called_once()
processor.schema_builder.build.assert_called_once()
# Verify per-workspace schema builder was created and graphql schema built
assert "default" in processor.schema_builders
assert "default" in processor.graphql_schemas
@pytest.mark.asyncio
async def test_graphql_context_handling(self):
"""Test GraphQL execution context setup"""
processor = MagicMock()
processor.graphql_schema = AsyncMock()
graphql_schema = AsyncMock()
processor.graphql_schemas = {"default": graphql_schema}
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Mock schema execution
mock_result = MagicMock()
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
mock_result.errors = None
processor.graphql_schema.execute.return_value = mock_result
graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
workspace="default",
query='{ customers { id name } }',
variables={},
operation_name=None,
user="test_user",
collection="test_collection"
)
# Verify schema.execute was called with correct context
processor.graphql_schema.execute.assert_called_once()
call_args = processor.graphql_schema.execute.call_args
graphql_schema.execute.assert_called_once()
call_args = graphql_schema.execute.call_args
# Verify context was passed
context = call_args[1]['context_value']
assert context["processor"] == processor
assert context["user"] == "test_user"
assert context["workspace"] == "default"
assert context["collection"] == "test_collection"
# Verify result structure
@ -190,7 +190,8 @@ class TestRowsGraphQLQueryLogic:
async def test_error_handling_graphql_errors(self):
"""Test GraphQL error handling and conversion"""
processor = MagicMock()
processor.graphql_schema = AsyncMock()
graphql_schema = AsyncMock()
processor.graphql_schemas = {"default": graphql_schema}
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Create a simple object to simulate GraphQL error
@ -212,13 +213,13 @@ class TestRowsGraphQLQueryLogic:
mock_result = MagicMock()
mock_result.data = None
mock_result.errors = [mock_error]
processor.graphql_schema.execute.return_value = mock_result
graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
workspace="default",
query='{ customers { invalid_field } }',
variables={},
operation_name=None,
user="test_user",
collection="test_collection"
)
@ -248,7 +249,6 @@ class TestRowsGraphQLQueryLogic:
# Create mock message
mock_msg = MagicMock()
mock_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
query='{ customers { id name } }',
variables={},
@ -259,6 +259,7 @@ class TestRowsGraphQLQueryLogic:
# Mock flow
mock_flow = MagicMock()
mock_flow.workspace = "default"
mock_response_flow = AsyncMock()
mock_flow.return_value = mock_response_flow
@ -267,10 +268,10 @@ class TestRowsGraphQLQueryLogic:
# Verify query was executed
processor.execute_graphql_query.assert_called_once_with(
workspace="default",
query='{ customers { id name } }',
variables={},
operation_name=None,
user="test_user",
collection="test_collection"
)
@ -297,7 +298,6 @@ class TestRowsGraphQLQueryLogic:
# Create mock message
mock_msg = MagicMock()
mock_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
query='{ invalid_query }',
variables={},
@ -330,7 +330,8 @@ class TestUnifiedTableQueries:
"""Test queries against the unified rows table"""
@pytest.mark.asyncio
async def test_query_with_index_match(self):
@patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
async def test_query_with_index_match(self, mock_async_execute):
"""Test query execution with matching index"""
processor = MagicMock()
processor.session = MagicMock()
@ -340,10 +341,10 @@ class TestUnifiedTableQueries:
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock session execute to return test data
# Mock async_execute to return test data
mock_row = MagicMock()
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
processor.session.execute.return_value = [mock_row]
mock_async_execute.return_value = [mock_row]
schema = RowSchema(
name="products",
@ -356,7 +357,7 @@ class TestUnifiedTableQueries:
# Query with filter on indexed field
results = await processor.query_cassandra(
user="test_user",
workspace="test_workspace",
collection="test_collection",
schema_name="products",
row_schema=schema,
@ -366,14 +367,14 @@ class TestUnifiedTableQueries:
# Verify Cassandra was connected and queried
processor.connect_cassandra.assert_called_once()
processor.session.execute.assert_called_once()
mock_async_execute.assert_called_once()
# Verify query structure - should query unified rows table
call_args = processor.session.execute.call_args
query = call_args[0][0]
params = call_args[0][1]
call_args = mock_async_execute.call_args
query = call_args[0][1]
params = call_args[0][2]
assert "SELECT data, source FROM test_user.rows" in query
assert "SELECT data, source FROM test_workspace.rows" in query
assert "collection = %s" in query
assert "schema_name = %s" in query
assert "index_name = %s" in query
@ -390,7 +391,8 @@ class TestUnifiedTableQueries:
assert results[0]["category"] == "electronics"
@pytest.mark.asyncio
async def test_query_without_index_match(self):
@patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
async def test_query_without_index_match(self, mock_async_execute):
"""Test query execution without matching index (scan mode)"""
processor = MagicMock()
processor.session = MagicMock()
@ -401,12 +403,12 @@ class TestUnifiedTableQueries:
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock session execute to return test data
# Mock async_execute to return test data
mock_row1 = MagicMock()
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
mock_row2 = MagicMock()
mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
processor.session.execute.return_value = [mock_row1, mock_row2]
mock_async_execute.return_value = [mock_row1, mock_row2]
schema = RowSchema(
name="products",
@ -419,7 +421,7 @@ class TestUnifiedTableQueries:
# Query with filter on non-indexed field
results = await processor.query_cassandra(
user="test_user",
workspace="test_workspace",
collection="test_collection",
schema_name="products",
row_schema=schema,
@ -428,8 +430,8 @@ class TestUnifiedTableQueries:
)
# Query should use ALLOW FILTERING for scan
call_args = processor.session.execute.call_args
query = call_args[0][0]
call_args = mock_async_execute.call_args
query = call_args[0][1]
assert "ALLOW FILTERING" in query

View file

@ -95,7 +95,6 @@ class TestCassandraQueryProcessor:
# Create query request with all SPO values
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
@ -103,7 +102,7 @@ class TestCassandraQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify KnowledgeGraph was created with correct parameters
mock_kg_class.assert_called_once_with(
@ -170,7 +169,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
@ -178,7 +176,7 @@ class TestCassandraQueryProcessor:
limit=50
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
assert len(result) == 1
@ -207,7 +205,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=None,
@ -215,7 +212,7 @@ class TestCassandraQueryProcessor:
limit=25
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
assert len(result) == 1
@ -244,7 +241,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Term(type=LITERAL, value='test_predicate'),
@ -252,7 +248,7 @@ class TestCassandraQueryProcessor:
limit=10
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
assert len(result) == 1
@ -281,7 +277,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
@ -289,7 +284,7 @@ class TestCassandraQueryProcessor:
limit=75
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
assert len(result) == 1
@ -319,7 +314,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
@ -327,7 +321,7 @@ class TestCassandraQueryProcessor:
limit=1000
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
assert len(result) == 1
@ -425,7 +419,6 @@ class TestCassandraQueryProcessor:
)
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
@ -433,7 +426,7 @@ class TestCassandraQueryProcessor:
limit=100
)
await processor.query_triples(query)
await processor.query_triples('test_user', query)
# Verify KnowledgeGraph was created with authentication
mock_kg_class.assert_called_once_with(
@ -463,7 +456,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
@ -472,11 +464,11 @@ class TestCassandraQueryProcessor:
)
# First query should create TrustGraph
await processor.query_triples(query)
await processor.query_triples('test_user', query)
assert mock_kg_class.call_count == 1
# Second query with same table should reuse TrustGraph
await processor.query_triples(query)
await processor.query_triples('test_user', query)
assert mock_kg_class.call_count == 1 # Should not increase
@pytest.mark.asyncio
@ -504,7 +496,6 @@ class TestCassandraQueryProcessor:
# First query
query1 = TriplesQueryRequest(
user='user1',
collection='collection1',
s=Term(type=LITERAL, value='test_subject'),
p=None,
@ -512,12 +503,11 @@ class TestCassandraQueryProcessor:
limit=100
)
await processor.query_triples(query1)
await processor.query_triples('user1', query1)
assert processor.table == 'user1'
# Second query with different table
query2 = TriplesQueryRequest(
user='user2',
collection='collection2',
s=Term(type=LITERAL, value='test_subject'),
p=None,
@ -525,7 +515,7 @@ class TestCassandraQueryProcessor:
limit=100
)
await processor.query_triples(query2)
await processor.query_triples('user2', query2)
assert processor.table == 'user2'
# Verify TrustGraph was created twice
@ -544,7 +534,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
@ -553,7 +542,7 @@ class TestCassandraQueryProcessor:
)
with pytest.raises(Exception, match="Query failed"):
await processor.query_triples(query)
await processor.query_triples('test_user', query)
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
@ -582,7 +571,6 @@ class TestCassandraQueryProcessor:
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
@ -590,7 +578,7 @@ class TestCassandraQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
assert len(result) == 2
assert result[0].o.value == 'object1'
@ -621,7 +609,6 @@ class TestCassandraQueryPerformanceOptimizations:
# PO query pattern (predicate + object, find subjects)
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Term(type=LITERAL, value='test_predicate'),
@ -629,7 +616,7 @@ class TestCassandraQueryPerformanceOptimizations:
limit=50
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify get_po was called (should use optimized po_table)
mock_tg_instance.get_po.assert_called_once_with(
@ -662,7 +649,6 @@ class TestCassandraQueryPerformanceOptimizations:
# OS query pattern (object + subject, find predicates)
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value='test_subject'),
p=None,
@ -670,7 +656,7 @@ class TestCassandraQueryPerformanceOptimizations:
limit=25
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify get_os was called (should use optimized subject_table with clustering)
mock_tg_instance.get_os.assert_called_once_with(
@ -721,7 +707,6 @@ class TestCassandraQueryPerformanceOptimizations:
mock_tg_instance.reset_mock()
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=LITERAL, value=s) if s else None,
p=Term(type=LITERAL, value=p) if p else None,
@ -729,7 +714,7 @@ class TestCassandraQueryPerformanceOptimizations:
limit=10
)
await processor.query_triples(query)
await processor.query_triples('test_user', query)
# Verify the correct method was called
method = getattr(mock_tg_instance, expected_method)
@ -780,7 +765,6 @@ class TestCassandraQueryPerformanceOptimizations:
# This is the query pattern that was slow with ALLOW FILTERING
query = TriplesQueryRequest(
user='large_dataset_user',
collection='massive_collection',
s=None,
p=Term(type=IRI, iri='http://www.w3.org/1999/02/22-rdf-syntax-ns#type'),
@ -788,7 +772,7 @@ class TestCassandraQueryPerformanceOptimizations:
limit=1000
)
result = await processor.query_triples(query)
result = await processor.query_triples('large_dataset_user', query)
# Verify optimized get_po was used (no ALLOW FILTERING needed!)
mock_tg_instance.get_po.assert_called_once_with(

View file

@ -123,7 +123,6 @@ class TestFalkorDBQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
@ -131,7 +130,7 @@ class TestFalkorDBQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
@ -164,7 +163,6 @@ class TestFalkorDBQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
@ -172,7 +170,7 @@ class TestFalkorDBQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
@ -209,7 +207,6 @@ class TestFalkorDBQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
@ -217,7 +214,7 @@ class TestFalkorDBQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
@ -254,7 +251,6 @@ class TestFalkorDBQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
@ -262,7 +258,7 @@ class TestFalkorDBQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
@ -299,7 +295,6 @@ class TestFalkorDBQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Term(type=IRI, iri="http://example.com/predicate"),
@ -307,7 +302,7 @@ class TestFalkorDBQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
@ -344,7 +339,6 @@ class TestFalkorDBQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Term(type=IRI, iri="http://example.com/predicate"),
@ -352,7 +346,7 @@ class TestFalkorDBQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
@ -389,7 +383,6 @@ class TestFalkorDBQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
@ -397,7 +390,7 @@ class TestFalkorDBQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
@ -434,7 +427,6 @@ class TestFalkorDBQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
@ -442,7 +434,7 @@ class TestFalkorDBQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_graph.query.call_count == 2
@ -474,7 +466,6 @@ class TestFalkorDBQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
@ -484,7 +475,7 @@ class TestFalkorDBQueryProcessor:
# Should raise the exception
with pytest.raises(Exception, match="Database connection failed"):
await processor.query_triples(query)
await processor.query_triples('test_user', query)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""

View file

@ -122,7 +122,6 @@ class TestMemgraphQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
@ -130,7 +129,7 @@ class TestMemgraphQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
@ -164,7 +163,6 @@ class TestMemgraphQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
@ -172,7 +170,7 @@ class TestMemgraphQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
@ -210,7 +208,6 @@ class TestMemgraphQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
@ -218,7 +215,7 @@ class TestMemgraphQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
@ -256,7 +253,6 @@ class TestMemgraphQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
@ -264,7 +260,7 @@ class TestMemgraphQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
@ -302,7 +298,6 @@ class TestMemgraphQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Term(type=IRI, iri="http://example.com/predicate"),
@ -310,7 +305,7 @@ class TestMemgraphQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
@ -348,7 +343,6 @@ class TestMemgraphQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Term(type=IRI, iri="http://example.com/predicate"),
@ -356,7 +350,7 @@ class TestMemgraphQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
@ -394,7 +388,6 @@ class TestMemgraphQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
@ -402,7 +395,7 @@ class TestMemgraphQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
@ -440,7 +433,6 @@ class TestMemgraphQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
@ -448,7 +440,7 @@ class TestMemgraphQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
@ -478,7 +470,6 @@ class TestMemgraphQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
@ -488,7 +479,7 @@ class TestMemgraphQueryProcessor:
# Should raise the exception
with pytest.raises(Exception, match="Database connection failed"):
await processor.query_triples(query)
await processor.query_triples('test_user', query)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""

View file

@ -122,7 +122,6 @@ class TestNeo4jQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
@ -130,7 +129,7 @@ class TestNeo4jQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
@ -164,7 +163,6 @@ class TestNeo4jQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
@ -172,7 +170,7 @@ class TestNeo4jQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
@ -210,7 +208,6 @@ class TestNeo4jQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
@ -218,7 +215,7 @@ class TestNeo4jQueryProcessor:
limit=100
)
result = await processor.query_triples(query)
result = await processor.query_triples('test_user', query)
# Verify both literal and URI queries were executed
assert mock_driver.execute_query.call_count == 2
@ -248,7 +245,6 @@ class TestNeo4jQueryProcessor:
# Create query request
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
@ -258,7 +254,7 @@ class TestNeo4jQueryProcessor:
# Should raise the exception
with pytest.raises(Exception, match="Database connection failed"):
await processor.query_triples(query)
await processor.query_triples('test_user', query)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""

View file

@ -30,7 +30,7 @@ class TestDocumentMetadataTranslator:
"title": "Test Document",
"comments": "No comments",
"metadata": [],
"user": "alice",
"workspace": "alice",
"tags": ["finance", "q4"],
"parent-id": "doc-100",
"document-type": "page",
@ -40,14 +40,14 @@ class TestDocumentMetadataTranslator:
assert obj.time == 1710000000
assert obj.kind == "application/pdf"
assert obj.title == "Test Document"
assert obj.user == "alice"
assert obj.workspace == "alice"
assert obj.tags == ["finance", "q4"]
assert obj.parent_id == "doc-100"
assert obj.document_type == "page"
wire = self.tx.encode(obj)
assert wire["id"] == "doc-123"
assert wire["user"] == "alice"
assert wire["workspace"] == "alice"
assert wire["parent-id"] == "doc-100"
assert wire["document-type"] == "page"
@ -80,10 +80,10 @@ class TestDocumentMetadataTranslator:
def test_falsy_fields_omitted_from_wire(self):
"""Empty string fields should be omitted from wire format."""
obj = DocumentMetadata(id="", time=0, user="")
obj = DocumentMetadata(id="", time=0, workspace="")
wire = self.tx.encode(obj)
assert "id" not in wire
assert "user" not in wire
assert "workspace" not in wire
# ---------------------------------------------------------------------------
@ -101,7 +101,7 @@ class TestProcessingMetadataTranslator:
"document-id": "doc-123",
"time": 1710000000,
"flow": "default",
"user": "alice",
"workspace": "alice",
"collection": "my-collection",
"tags": ["tag1"],
}
@ -109,20 +109,20 @@ class TestProcessingMetadataTranslator:
assert obj.id == "proc-1"
assert obj.document_id == "doc-123"
assert obj.flow == "default"
assert obj.user == "alice"
assert obj.workspace == "alice"
assert obj.collection == "my-collection"
assert obj.tags == ["tag1"]
wire = self.tx.encode(obj)
assert wire["id"] == "proc-1"
assert wire["document-id"] == "doc-123"
assert wire["user"] == "alice"
assert wire["workspace"] == "alice"
assert wire["collection"] == "my-collection"
def test_missing_fields_use_defaults(self):
obj = self.tx.decode({})
assert obj.id is None
assert obj.user is None
assert obj.workspace is None
assert obj.collection is None
def test_tags_none_omitted(self):
@ -135,10 +135,10 @@ class TestProcessingMetadataTranslator:
wire = self.tx.encode(obj)
assert wire["tags"] == []
def test_user_and_collection_preserved(self):
def test_workspace_and_collection_preserved(self):
"""Core pipeline routing fields must survive round-trip."""
data = {"user": "bob", "collection": "research"}
data = {"workspace": "bob", "collection": "research"}
obj = self.tx.decode(data)
wire = self.tx.encode(obj)
assert wire["user"] == "bob"
assert wire["workspace"] == "bob"
assert wire["collection"] == "research"

View file

@ -61,7 +61,6 @@ class TestDocEmbeddingsNullProtection:
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
emb = MagicMock()
@ -69,7 +68,7 @@ class TestDocEmbeddingsNullProtection:
emb.vector = [] # Empty vector
msg.chunks = [emb]
await proc.store_document_embeddings(msg)
await proc.store_document_embeddings("user1", msg)
# No upsert should be called
proc.qdrant.upsert.assert_not_called()
@ -83,7 +82,6 @@ class TestDocEmbeddingsNullProtection:
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
emb = MagicMock()
@ -91,7 +89,7 @@ class TestDocEmbeddingsNullProtection:
emb.vector = None # None vector
msg.chunks = [emb]
await proc.store_document_embeddings(msg)
await proc.store_document_embeddings("user1", msg)
proc.qdrant.upsert.assert_not_called()
@pytest.mark.asyncio
@ -103,7 +101,6 @@ class TestDocEmbeddingsNullProtection:
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
emb = MagicMock()
@ -111,7 +108,7 @@ class TestDocEmbeddingsNullProtection:
emb.vector = [0.1, 0.2, 0.3]
msg.chunks = [emb]
await proc.store_document_embeddings(msg)
await proc.store_document_embeddings("user1", msg)
proc.qdrant.upsert.assert_not_called()
@pytest.mark.asyncio
@ -124,7 +121,6 @@ class TestDocEmbeddingsNullProtection:
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
emb = MagicMock()
@ -132,7 +128,7 @@ class TestDocEmbeddingsNullProtection:
emb.vector = [0.1, 0.2, 0.3]
msg.chunks = [emb]
await proc.store_document_embeddings(msg)
await proc.store_document_embeddings("user1", msg)
proc.qdrant.upsert.assert_called_once()
@pytest.mark.asyncio
@ -146,7 +142,6 @@ class TestDocEmbeddingsNullProtection:
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "alice"
msg.metadata.collection = "docs"
emb = MagicMock()
@ -154,7 +149,7 @@ class TestDocEmbeddingsNullProtection:
emb.vector = [0.0] * 384 # 384-dim vector
msg.chunks = [emb]
await proc.store_document_embeddings(msg)
await proc.store_document_embeddings("alice", msg)
call_args = proc.qdrant.upsert.call_args
assert "d_alice_docs_384" in call_args[1]["collection_name"]
@ -175,7 +170,6 @@ class TestGraphEmbeddingsNullProtection:
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
entity = MagicMock()
@ -183,7 +177,7 @@ class TestGraphEmbeddingsNullProtection:
entity.vector = [0.1, 0.2, 0.3]
msg.entities = [entity]
await proc.store_graph_embeddings(msg)
await proc.store_graph_embeddings("user1", msg)
proc.qdrant.upsert.assert_not_called()
@pytest.mark.asyncio
@ -195,7 +189,6 @@ class TestGraphEmbeddingsNullProtection:
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
entity = MagicMock()
@ -203,7 +196,7 @@ class TestGraphEmbeddingsNullProtection:
entity.vector = [0.1, 0.2, 0.3]
msg.entities = [entity]
await proc.store_graph_embeddings(msg)
await proc.store_graph_embeddings("user1", msg)
proc.qdrant.upsert.assert_not_called()
@pytest.mark.asyncio
@ -215,7 +208,6 @@ class TestGraphEmbeddingsNullProtection:
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
entity = MagicMock()
@ -223,7 +215,7 @@ class TestGraphEmbeddingsNullProtection:
entity.vector = [] # Empty vector
msg.entities = [entity]
await proc.store_graph_embeddings(msg)
await proc.store_graph_embeddings("user1", msg)
proc.qdrant.upsert.assert_not_called()
@pytest.mark.asyncio
@ -236,7 +228,6 @@ class TestGraphEmbeddingsNullProtection:
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
entity = MagicMock()
@ -245,7 +236,7 @@ class TestGraphEmbeddingsNullProtection:
entity.chunk_id = "c1"
msg.entities = [entity]
await proc.store_graph_embeddings(msg)
await proc.store_graph_embeddings("user1", msg)
proc.qdrant.upsert.assert_called_once()
@pytest.mark.asyncio
@ -258,7 +249,6 @@ class TestGraphEmbeddingsNullProtection:
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "alice"
msg.metadata.collection = "graphs"
entity = MagicMock()
@ -267,7 +257,7 @@ class TestGraphEmbeddingsNullProtection:
entity.chunk_id = ""
msg.entities = [entity]
await proc.store_graph_embeddings(msg)
await proc.store_graph_embeddings("alice", msg)
# Collection should be created with correct dimension
proc.qdrant.create_collection.assert_called_once()
@ -290,11 +280,10 @@ class TestCollectionValidation:
proc.collection_exists = MagicMock(return_value=False)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "deleted-col"
msg.chunks = [MagicMock()]
await proc.store_document_embeddings(msg)
await proc.store_document_embeddings("user1", msg)
proc.qdrant.upsert.assert_not_called()
@pytest.mark.asyncio
@ -306,9 +295,8 @@ class TestCollectionValidation:
proc.collection_exists = MagicMock(return_value=False)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "deleted-col"
msg.entities = [MagicMock()]
await proc.store_graph_embeddings(msg)
await proc.store_graph_embeddings("user1", msg)
proc.qdrant.upsert.assert_not_called()

View file

@ -6,6 +6,7 @@ import pytest
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
from trustgraph.base import PromptResult
# Sample chunk content mapping for tests
@ -91,14 +92,13 @@ class TestQuery:
# Initialize Query with defaults
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False
)
# Verify initialization
assert query.rag == mock_rag
assert query.user == "test_user"
assert query.collection == "test_collection"
assert query.verbose is False
assert query.doc_limit == 20 # Default value
@ -111,7 +111,7 @@ class TestQuery:
# Initialize Query with custom doc_limit
query = Query(
rag=mock_rag,
user="custom_user",
workspace="test_workspace",
collection="custom_collection",
verbose=True,
doc_limit=50
@ -119,7 +119,6 @@ class TestQuery:
# Verify initialization
assert query.rag == mock_rag
assert query.user == "custom_user"
assert query.collection == "custom_collection"
assert query.verbose is True
assert query.doc_limit == 50
@ -132,11 +131,11 @@ class TestQuery:
mock_rag.prompt_client = mock_prompt_client
# Mock the prompt response with concept lines
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence\ndata patterns")
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False
)
@ -157,11 +156,11 @@ class TestQuery:
mock_rag.prompt_client = mock_prompt_client
# Mock empty response
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False
)
@ -183,7 +182,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False
)
@ -222,7 +221,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False,
doc_limit=15
@ -239,7 +238,6 @@ class TestQuery:
mock_doc_embeddings_client.query.assert_called_once_with(
vector=[0.1, 0.2, 0.3],
limit=15,
user="test_user",
collection="test_collection"
)
@ -258,7 +256,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction
mock_prompt_client.prompt.return_value = "test concept"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="test concept")
# Mock embeddings - one vector per concept
test_vectors = [[0.1, 0.2, 0.3]]
@ -273,7 +271,7 @@ class TestQuery:
expected_response = "This is the document RAG response"
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
mock_prompt_client.document_prompt.return_value = expected_response
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=expected_response)
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -285,7 +283,6 @@ class TestQuery:
result = await document_rag.query(
query="test query",
user="test_user",
collection="test_collection",
doc_limit=10
)
@ -303,7 +300,6 @@ class TestQuery:
mock_doc_embeddings_client.query.assert_called_once_with(
vector=[0.1, 0.2, 0.3],
limit=10,
user="test_user",
collection="test_collection"
)
@ -315,7 +311,8 @@ class TestQuery:
assert "Relevant document content" in docs
assert "Another document" in docs
assert result == expected_response
result_text, usage = result
assert result_text == expected_response
@pytest.mark.asyncio
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
@ -325,7 +322,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction fallback (empty → raw query)
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
@ -333,7 +330,7 @@ class TestQuery:
mock_match.chunk_id = "doc/c5"
mock_match.score = 0.9
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Default response"
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Default response")
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -348,11 +345,11 @@ class TestQuery:
mock_doc_embeddings_client.query.assert_called_once_with(
vector=[[0.1, 0.2]],
limit=20, # Default doc_limit
user="trustgraph", # Default user
collection="default" # Default collection
)
assert result == "Default response"
result_text, usage = result
assert result_text == "Default response"
@pytest.mark.asyncio
async def test_get_docs_with_verbose_output(self):
@ -377,7 +374,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=True,
doc_limit=5
@ -401,7 +398,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction
mock_prompt_client.prompt.return_value = "verbose query test"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="verbose query test")
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
@ -409,7 +406,7 @@ class TestQuery:
mock_match.chunk_id = "doc/c7"
mock_match.score = 0.92
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Verbose RAG response")
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -428,7 +425,8 @@ class TestQuery:
assert call_args.kwargs["query"] == "verbose query test"
assert "Verbose doc content" in call_args.kwargs["documents"]
assert result == "Verbose RAG response"
result_text, usage = result
assert result_text == "Verbose RAG response"
@pytest.mark.asyncio
async def test_get_docs_with_empty_results(self):
@ -449,7 +447,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False
)
@ -469,11 +467,11 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction
mock_prompt_client.prompt.return_value = "query with no matching docs"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="query with no matching docs")
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
mock_doc_embeddings_client.query.return_value = []
mock_prompt_client.document_prompt.return_value = "No documents found response"
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="No documents found response")
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -490,7 +488,8 @@ class TestQuery:
documents=[]
)
assert result == "No documents found response"
result_text, usage = result
assert result_text == "No documents found response"
@pytest.mark.asyncio
async def test_get_vectors_with_verbose(self):
@ -504,7 +503,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=True
)
@ -525,7 +524,7 @@ class TestQuery:
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
# Mock concept extraction
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence")
# Mock embeddings - one vector per concept
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
@ -541,7 +540,7 @@ class TestQuery:
MagicMock(chunk_id="doc/ml3", score=0.82),
]
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
mock_prompt_client.document_prompt.return_value = final_response
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=final_response)
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -553,7 +552,6 @@ class TestQuery:
result = await document_rag.query(
query=query_text,
user="research_user",
collection="ml_knowledge",
doc_limit=25
)
@ -584,7 +582,8 @@ class TestQuery:
assert "Common ML techniques include supervised and unsupervised learning..." in docs
assert len(docs) == 3 # doc/ml2 deduplicated
assert result == final_response
result_text, usage = result
assert result_text == final_response
@pytest.mark.asyncio
async def test_get_docs_deduplicates_across_concepts(self):
@ -613,7 +612,7 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
workspace="test_workspace",
collection="test_collection",
verbose=False,
doc_limit=10

View file

@ -12,6 +12,7 @@ from unittest.mock import AsyncMock
from dataclasses import dataclass
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -89,8 +90,8 @@ def build_mock_clients():
# 1. Concept extraction
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return "return policy\nrefund"
return ""
return PromptResult(response_type="text", text="return policy\nrefund")
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
@ -113,8 +114,9 @@ def build_mock_clients():
fetch_chunk.side_effect = mock_fetch
# 5. Synthesis
prompt_client.document_prompt.return_value = (
"Items can be returned within 30 days for a full refund."
prompt_client.document_prompt.return_value = PromptResult(
response_type="text",
text="Items can be returned within 30 days for a full refund.",
)
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
@ -340,12 +342,12 @@ class TestDocumentRagQueryProvenance:
clients = build_mock_clients()
rag = DocumentRag(*clients)
result = await rag.query(
result_text, usage = await rag.query(
query="What is the return policy?",
explain_callback=AsyncMock(),
)
assert result == "Items can be returned within 30 days for a full refund."
assert result_text == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio
async def test_no_explain_callback_still_works(self):
@ -353,8 +355,8 @@ class TestDocumentRagQueryProvenance:
clients = build_mock_clients()
rag = DocumentRag(*clients)
result = await rag.query(query="What is the return policy?")
assert result == "Items can be returned within 30 days for a full refund."
result_text, usage = await rag.query(query="What is the return policy?")
assert result_text == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio
async def test_all_triples_in_retrieval_graph(self):

View file

@ -1,6 +1,6 @@
"""
Unit test for DocumentRAG service parameter passing fix.
Tests that user and collection parameters from the message are correctly
Tests that the collection parameter from the message is correctly
passed to the DocumentRag.query() method.
"""
@ -16,13 +16,13 @@ class TestDocumentRagService:
@patch('trustgraph.retrieval.document_rag.rag.DocumentRag')
@pytest.mark.asyncio
async def test_user_and_collection_parameters_passed_to_query(self, mock_document_rag_class):
async def test_collection_parameter_passed_to_query(self, mock_document_rag_class):
"""
Test that user and collection from message are passed to DocumentRag.query().
This is a regression test for the bug where user/collection parameters
were ignored, causing wrong collection names like 'd_trustgraph_default_384'
instead of 'd_my_user_test_coll_1_384'.
Test that collection from message is passed to DocumentRag.query().
This is a regression test for the bug where the collection parameter
was ignored, causing wrong collection names like 'd_trustgraph_default_384'
instead of one that reflects the requested collection.
"""
# Setup processor
processor = Processor(
@ -30,17 +30,16 @@ class TestDocumentRagService:
id="test-processor",
doc_limit=10
)
# Setup mock DocumentRag instance
mock_rag_instance = AsyncMock()
mock_document_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.return_value = "test response"
# Setup message with custom user/collection
mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None})
# Setup message with custom collection
msg = MagicMock()
msg.value.return_value = DocumentRagQuery(
query="test query",
user="my_user", # Custom user (not default "trustgraph")
collection="test_coll_1", # Custom collection (not default "default")
doc_limit=5
)
@ -64,7 +63,7 @@ class TestDocumentRagService:
# Verify: DocumentRag.query was called with correct parameters
mock_rag_instance.query.assert_called_once_with(
"test query",
user="my_user", # Must be from message, not hardcoded default
workspace=ANY, # Workspace comes from flow.workspace (mock)
collection="test_coll_1", # Must be from message, not hardcoded default
doc_limit=5,
explain_callback=ANY, # Explainability callback is always passed
@ -97,13 +96,12 @@ class TestDocumentRagService:
# Setup mock DocumentRag instance
mock_rag_instance = AsyncMock()
mock_document_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.return_value = "A document about cats."
mock_rag_instance.query.return_value = ("A document about cats.", {"in_token": None, "out_token": None, "model": None})
# Setup message with non-streaming request
msg = MagicMock()
msg.value.return_value = DocumentRagQuery(
query="What is a cat?",
user="trustgraph",
collection="default",
doc_limit=10,
streaming=False # Non-streaming mode
@ -130,4 +128,5 @@ class TestDocumentRagService:
assert isinstance(sent_response, DocumentRagResponse)
assert sent_response.response == "A document about cats."
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
assert sent_response.end_of_session is True
assert sent_response.error is None

View file

@ -7,6 +7,7 @@ import unittest.mock
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query
from trustgraph.base import PromptResult
class TestGraphRag:
@ -77,14 +78,12 @@ class TestQuery:
# Initialize Query with defaults
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
# Verify initialization
assert query.rag == mock_rag
assert query.user == "test_user"
assert query.collection == "test_collection"
assert query.verbose is False
assert query.entity_limit == 50 # Default value
@ -100,7 +99,6 @@ class TestQuery:
# Initialize Query with custom parameters
query = Query(
rag=mock_rag,
user="custom_user",
collection="custom_collection",
verbose=True,
entity_limit=100,
@ -111,7 +109,6 @@ class TestQuery:
# Verify initialization
assert query.rag == mock_rag
assert query.user == "custom_user"
assert query.collection == "custom_collection"
assert query.verbose is True
assert query.entity_limit == 100
@ -132,7 +129,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -155,7 +151,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=True
)
@ -172,11 +167,10 @@ class TestQuery:
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nneural networks\n")
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -196,11 +190,10 @@ class TestQuery:
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -220,7 +213,7 @@ class TestQuery:
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
# extract_concepts returns empty -> falls back to [query]
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
# embed returns one vector set for the single concept
test_vectors = [[0.1, 0.2, 0.3]]
@ -243,7 +236,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
entity_limit=25
@ -268,7 +260,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -276,7 +267,7 @@ class TestQuery:
result = await query.maybe_label("entity1")
assert result == "Entity One Label"
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
mock_cache.get.assert_called_once_with("test_collection:entity1")
@pytest.mark.asyncio
async def test_maybe_label_with_label_lookup(self):
@ -294,7 +285,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -306,13 +296,12 @@ class TestQuery:
p="http://www.w3.org/2000/01/rdf-schema#label",
o=None,
limit=1,
user="test_user",
collection="test_collection",
g=""
)
assert result == "Human Readable Label"
cache_key = "test_user:test_collection:http://example.com/entity"
cache_key = "test_collection:http://example.com/entity"
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
@pytest.mark.asyncio
@ -329,7 +318,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -341,13 +329,12 @@ class TestQuery:
p="http://www.w3.org/2000/01/rdf-schema#label",
o=None,
limit=1,
user="test_user",
collection="test_collection",
g=""
)
assert result == "unlabeled_entity"
cache_key = "test_user:test_collection:unlabeled_entity"
cache_key = "test_collection:unlabeled_entity"
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
@pytest.mark.asyncio
@ -374,7 +361,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
triple_limit=10
@ -387,15 +373,15 @@ class TestQuery:
mock_triples_client.query_stream.assert_any_call(
s="entity1", p=None, o=None, limit=10,
user="test_user", collection="test_collection", batch_size=20, g=""
collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p="entity1", o=None, limit=10,
user="test_user", collection="test_collection", batch_size=20, g=""
collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p=None, o="entity1", limit=10,
user="test_user", collection="test_collection", batch_size=20, g=""
collection="test_collection", batch_size=20, g=""
)
expected_subgraph = {
@ -414,7 +400,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
@ -434,7 +419,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
max_subgraph_size=2
@ -454,7 +438,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
max_path_length=1
@ -492,7 +475,6 @@ class TestQuery:
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
max_subgraph_size=100
@ -565,14 +547,14 @@ class TestQuery:
# Mock prompt responses for the multi-step process
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return "" # Falls back to raw query
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return json.dumps({"id": test_edge_id, "score": 0.9})
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
elif prompt_name == "kg-edge-reasoning":
return json.dumps({"id": test_edge_id, "reasoning": "relevant"})
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
elif prompt_name == "kg-synthesis":
return expected_response
return ""
return PromptResult(response_type="text", text=expected_response)
return PromptResult(response_type="text", text="")
mock_prompt_client.prompt = mock_prompt
@ -600,14 +582,14 @@ class TestQuery:
try:
response = await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
entity_limit=25,
triple_limit=15,
explain_callback=collect_provenance
)
assert response == expected_response
response_text, usage = response
assert response_text == expected_response
# 5 events: question, grounding, exploration, focus, synthesis
assert len(provenance_events) == 5

View file

@ -120,7 +120,6 @@ class TestGraphRagServiceExplainTriples:
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="What is quantum computing?",
user="trustgraph",
collection="default",
streaming=False,
)

View file

@ -13,6 +13,7 @@ from dataclasses import dataclass
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -136,24 +137,36 @@ def build_mock_clients():
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return prompt_responses["extract-concepts"]
return PromptResult(
response_type="text",
text=prompt_responses["extract-concepts"],
)
elif template_id == "kg-edge-scoring":
# Score all edges highly, using the IDs that GraphRag computed
edges = variables.get("knowledge", [])
return [
{"id": e["id"], "score": 10 - i}
for i, e in enumerate(edges)
]
return PromptResult(
response_type="jsonl",
objects=[
{"id": e["id"], "score": 10 - i}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-edge-reasoning":
# Provide reasoning for each edge
edges = variables.get("knowledge", [])
return [
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
for i, e in enumerate(edges)
]
return PromptResult(
response_type="jsonl",
objects=[
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-synthesis":
return synthesis_answer
return ""
return PromptResult(
response_type="text",
text=synthesis_answer,
)
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
@ -413,13 +426,13 @@ class TestGraphRagQueryProvenance:
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
result = await rag.query(
result_text, usage = await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
assert result == "Quantum computing applies physics principles to computation."
assert result_text == "Quantum computing applies physics principles to computation."
@pytest.mark.asyncio
async def test_parent_uri_links_question_to_parent(self):
@ -450,12 +463,12 @@ class TestGraphRagQueryProvenance:
clients = build_mock_clients()
rag = GraphRag(*clients)
result = await rag.query(
result_text, usage = await rag.query(
query="What is quantum computing?",
edge_score_limit=0,
)
assert result == "Quantum computing applies physics principles to computation."
assert result_text == "Quantum computing applies physics principles to computation."
@pytest.mark.asyncio
async def test_all_triples_in_retrieval_graph(self):

View file

@ -44,7 +44,7 @@ class TestGraphRagService:
await explain_callback([], "urn:trustgraph:prov:retrieval:test")
await explain_callback([], "urn:trustgraph:prov:selection:test")
await explain_callback([], "urn:trustgraph:prov:answer:test")
return "A small domesticated mammal."
return "A small domesticated mammal.", {"in_token": None, "out_token": None, "model": None}
mock_rag_instance.query.side_effect = mock_query
@ -52,7 +52,6 @@ class TestGraphRagService:
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="What is a cat?",
user="trustgraph",
collection="default",
entity_limit=50,
triple_limit=30,
@ -79,8 +78,8 @@ class TestGraphRagService:
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session)
assert mock_response_producer.send.call_count == 6
# Verify: 5 messages sent (4 provenance + 1 combined chunk with end_of_session)
assert mock_response_producer.send.call_count == 5
# First 4 messages are explain (emitted in real-time during query)
for i in range(4):
@ -88,17 +87,12 @@ class TestGraphRagService:
assert prov_msg.message_type == "explain"
assert prov_msg.explain_id is not None
# 5th message is chunk with response
# 5th message is chunk with response and end_of_session
chunk_msg = mock_response_producer.send.call_args_list[4][0][0]
assert chunk_msg.message_type == "chunk"
assert chunk_msg.response == "A small domesticated mammal."
assert chunk_msg.end_of_stream is True
# 6th message is empty chunk with end_of_session=True
close_msg = mock_response_producer.send.call_args_list[5][0][0]
assert close_msg.message_type == "chunk"
assert close_msg.response == ""
assert close_msg.end_of_session is True
assert chunk_msg.end_of_session is True
# Verify provenance triples were sent to provenance queue
assert mock_provenance_producer.send.call_count == 4
@ -128,7 +122,6 @@ class TestGraphRagService:
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="What is a cat?",
user="trustgraph",
collection="default",
entity_limit=50,
triple_limit=30,
@ -187,7 +180,7 @@ class TestGraphRagService:
async def mock_query(**kwargs):
# Don't call explain_callback
return "Response text"
return "Response text", {"in_token": None, "out_token": None, "model": None}
mock_rag_instance.query.side_effect = mock_query
@ -195,7 +188,6 @@ class TestGraphRagService:
msg = MagicMock()
msg.value.return_value = GraphRagQuery(
query="Test query",
user="trustgraph",
collection="default",
streaming=False
)
@ -218,17 +210,12 @@ class TestGraphRagService:
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: 2 messages (chunk + empty chunk to close)
assert mock_response_producer.send.call_count == 2
# Verify: 1 combined message (chunk with end_of_session)
assert mock_response_producer.send.call_count == 1
# First is the response chunk
# Single message has response and end_of_session
chunk_msg = mock_response_producer.send.call_args_list[0][0][0]
assert chunk_msg.message_type == "chunk"
assert chunk_msg.response == "Response text"
assert chunk_msg.end_of_stream is True
# Second is empty chunk to close session
close_msg = mock_response_producer.send.call_args_list[1][0][0]
assert close_msg.message_type == "chunk"
assert close_msg.response == ""
assert close_msg.end_of_session is True
assert chunk_msg.end_of_session is True

View file

@ -72,7 +72,6 @@ def processor(mock_pulsar_client, sample_schemas):
return proc
@pytest.mark.asyncio
class TestNLPQueryProcessor:
"""Test NLP Query service processor"""
@ -287,11 +286,11 @@ class TestNLPQueryProcessor:
}
# Act
await processor.on_schema_config(config, "v1")
await processor.on_schema_config("default", config, "v1")
# Assert
assert "test_schema" in processor.schemas
schema = processor.schemas["test_schema"]
assert "test_schema" in processor.schemas["default"]
schema = processor.schemas["default"]["test_schema"]
assert schema.name == "test_schema"
assert schema.description == "Test schema"
assert len(schema.fields) == 2
@ -309,10 +308,10 @@ class TestNLPQueryProcessor:
}
# Act
await processor.on_schema_config(config, "v1")
await processor.on_schema_config("default", config, "v1")
# Assert - bad schema should be ignored
assert "bad_schema" not in processor.schemas
assert "bad_schema" not in processor.schemas.get("default", {})
def test_processor_initialization(self, mock_pulsar_client):
"""Test processor initialization with correct specifications"""

View file

@ -101,7 +101,7 @@ def service(mock_schemas):
taskgroup=MagicMock(),
id="test-processor"
)
service.schemas = mock_schemas
service.schemas = {"default": dict(mock_schemas)}
return service
@ -109,6 +109,7 @@ def service(mock_schemas):
def mock_flow():
"""Create mock flow with prompt service"""
flow = MagicMock()
flow.workspace = "default"
prompt_request_flow = AsyncMock()
flow.return_value.request = prompt_request_flow
return flow, prompt_request_flow

View file

@ -36,7 +36,6 @@ def processor(mock_pulsar_client):
return proc
@pytest.mark.asyncio
class TestStructuredQueryProcessor:
"""Test Structured Query service processor"""
@ -45,7 +44,6 @@ class TestStructuredQueryProcessor:
# Arrange
request = StructuredQueryRequest(
question="Show me all customers from New York",
user="trustgraph",
collection="default"
)
@ -111,7 +109,6 @@ class TestStructuredQueryProcessor:
assert isinstance(objects_call_args, RowsQueryRequest)
assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }'
assert objects_call_args.variables == {"state": "NY"}
assert objects_call_args.user == "trustgraph"
assert objects_call_args.collection == "default"
# Verify response

View file

@ -17,7 +17,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test document embeddings
@ -80,7 +79,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for a single chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -89,7 +87,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify insert was called once for the single chunk with its vector
processor.vecstore.insert.assert_called_once_with(
@ -99,14 +97,14 @@ class TestMilvusDocEmbeddingsStorageProcessor:
@pytest.mark.asyncio
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
"""Test storing document embeddings for multiple chunks"""
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('test_workspace', mock_message)
# Verify insert was called once per chunk with user/collection parameters
# Verify insert was called once per chunk with workspace/collection parameters
expected_calls = [
# Chunk 1 - single vector
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'),
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_workspace', 'test_collection'),
# Chunk 2 - single vector
([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'),
([0.7, 0.8, 0.9], "This is the second document chunk", 'test_workspace', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 2
@ -122,7 +120,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -131,7 +128,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no insert was called for empty chunk
processor.vecstore.insert.assert_not_called()
@ -141,7 +138,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with None chunk_id"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -150,7 +146,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Note: Implementation passes through None chunk_ids (only skips empty string "")
processor.vecstore.insert.assert_called_once_with(
@ -162,7 +158,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with mix of valid and empty chunks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
valid_chunk = ChunkEmbeddings(
@ -179,7 +174,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [valid_chunk, empty_chunk, another_valid]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify valid chunks were inserted, empty string chunk was skipped
expected_calls = [
@ -200,11 +195,10 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunks list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.chunks = []
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no insert was called
processor.vecstore.insert.assert_not_called()
@ -214,7 +208,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for chunk with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -223,7 +216,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no insert was called (no vectors to insert)
processor.vecstore.insert.assert_not_called()
@ -233,7 +226,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Each chunk has a single vector of different dimensions
@ -251,7 +243,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk1, chunk2, chunk3]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify all vectors were inserted regardless of dimension with user/collection parameters
expected_calls = [
@ -273,7 +265,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with Unicode content in chunk_id"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -282,7 +273,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify Unicode chunk_id was stored correctly with user/collection parameters
processor.vecstore.insert.assert_called_once_with(
@ -294,7 +285,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with long chunk_id"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a long chunk_id
@ -305,7 +295,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify long chunk_id was inserted with user/collection parameters
processor.vecstore.insert.assert_called_once_with(
@ -317,7 +307,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with whitespace-only chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -326,7 +315,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify whitespace content was inserted (not filtered out) with user/collection parameters
processor.vecstore.insert.assert_called_once_with(
@ -343,25 +332,24 @@ class TestMilvusDocEmbeddingsStorageProcessor:
('test@domain.com', 'test-collection.v1'),
]
for user, collection in test_cases:
for workspace, collection in test_cases:
processor.vecstore.reset_mock() # Reset mock for each test case
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = user
message.metadata.collection = collection
chunk = ChunkEmbeddings(
chunk_id="Test content",
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify insert was called with the correct user/collection
await processor.store_document_embeddings(workspace, message)
# Verify insert was called with the correct workspace/collection
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Test content", user, collection
[0.1, 0.2, 0.3], "Test content", workspace, collection
)
@pytest.mark.asyncio
@ -370,7 +358,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
# Store embeddings for user1/collection1
message1 = MagicMock()
message1.metadata = MagicMock()
message1.metadata.user = 'user1'
message1.metadata.collection = 'collection1'
chunk1 = ChunkEmbeddings(
chunk_id="User1 content",
@ -381,7 +368,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
# Store embeddings for user2/collection2
message2 = MagicMock()
message2.metadata = MagicMock()
message2.metadata.user = 'user2'
message2.metadata.collection = 'collection2'
chunk2 = ChunkEmbeddings(
chunk_id="User2 content",
@ -389,8 +375,8 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message2.chunks = [chunk2]
await processor.store_document_embeddings(message1)
await processor.store_document_embeddings(message2)
await processor.store_document_embeddings('user1', message1)
await processor.store_document_embeddings('user2', message2)
# Verify both calls were made with correct parameters
expected_calls = [
@ -411,18 +397,17 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with special characters in user/collection names"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'user@domain.com' # Email-like user
message.metadata.collection = 'test-collection.v1' # Collection with special chars
chunk = ChunkEmbeddings(
chunk_id="Special chars test",
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify the exact user/collection strings are passed (sanitization happens in DocVectors)
await processor.store_document_embeddings('user@domain.com', message)
# Verify the exact workspace/collection strings are passed (sanitization happens in DocVectors)
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Special chars test", 'user@domain.com', 'test-collection.v1'
)

View file

@ -21,7 +21,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test document embeddings
@ -120,7 +119,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for a single chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -135,7 +133,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify index name and operations (with dimension suffix)
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
@ -185,7 +183,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test that writing to non-existent index creates it lazily"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -200,7 +197,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify index was created with correct dimension
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
@ -217,7 +214,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -229,7 +225,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called for empty chunk
mock_index.upsert.assert_not_called()
@ -239,7 +235,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with None chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -251,7 +246,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called for None chunk
mock_index.upsert.assert_not_called()
@ -261,7 +256,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with chunk that decodes to empty string"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -273,7 +267,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called for empty decoded chunk
mock_index.upsert.assert_not_called()
@ -283,7 +277,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Each chunk has a single vector of different dimensions
@ -325,14 +318,13 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunks list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.chunks = []
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no operations were performed
processor.pinecone.Index.assert_not_called()
@ -343,7 +335,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for chunk with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -355,7 +346,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called (no vectors to insert)
mock_index.upsert.assert_not_called()
@ -365,7 +356,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test that lazy creation happens when index doesn't exist"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -380,7 +370,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify index was created
processor.pinecone.create_index.assert_called_once()
@ -390,7 +380,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test that lazy creation works correctly"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -405,7 +394,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify index was created and used
processor.pinecone.create_index.assert_called_once()
@ -416,7 +405,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with Unicode content"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -430,7 +418,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify Unicode content was properly decoded and stored
call_args = mock_index.upsert.call_args
@ -442,7 +430,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with large document chunks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a large document chunk
@ -458,7 +445,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify large content was stored
call_args = mock_index.upsert.call_args

View file

@ -84,7 +84,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with chunks and vectors
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_chunk = MagicMock()
@ -94,7 +93,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('test_user', mock_message)
# Assert
# Verify collection existence was checked (with dimension suffix)
@ -138,7 +137,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with multiple chunks
mock_message = MagicMock()
mock_message.metadata.user = 'multi_user'
mock_message.metadata.collection = 'multi_collection'
mock_chunk1 = MagicMock()
@ -152,7 +150,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk1, mock_chunk2]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('multi_user', mock_message)
# Assert
# Should be called twice (once per chunk)
@ -198,7 +196,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with multiple chunks, each having a single vector
mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection'
mock_chunk1 = MagicMock()
@ -216,7 +213,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('vector_user', mock_message)
# Assert
# Should be called 3 times (once per chunk)
@ -255,7 +252,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with empty chunk_id
mock_message = MagicMock()
mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection'
mock_chunk_empty = MagicMock()
@ -265,7 +261,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk_empty]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('empty_user', mock_message)
# Assert
# Should not call upsert for empty chunk_ids
@ -298,7 +294,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'new_user'
mock_message.metadata.collection = 'new_collection'
mock_chunk = MagicMock()
@ -308,7 +303,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('new_user', mock_message)
# Assert - collection should be lazily created
expected_collection = 'd_new_user_new_collection_5' # 5 dimensions
@ -350,7 +345,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'error_user'
mock_message.metadata.collection = 'error_collection'
mock_chunk = MagicMock()
@ -361,7 +355,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Act & Assert - should propagate the creation error
with pytest.raises(Exception, match="Connection error"):
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('error_user', mock_message)
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@ -388,7 +382,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create first mock message
mock_message1 = MagicMock()
mock_message1.metadata.user = 'cache_user'
mock_message1.metadata.collection = 'cache_collection'
mock_chunk1 = MagicMock()
@ -398,7 +391,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message1.chunks = [mock_chunk1]
# First call
await processor.store_document_embeddings(mock_message1)
await processor.store_document_embeddings('cache_user', mock_message1)
# Reset mock to track second call
mock_qdrant_instance.reset_mock()
@ -406,7 +399,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create second mock message with same dimensions
mock_message2 = MagicMock()
mock_message2.metadata.user = 'cache_user'
mock_message2.metadata.collection = 'cache_collection'
mock_chunk2 = MagicMock()
@ -416,7 +408,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message2.chunks = [mock_chunk2]
# Act - Second call with same collection
await processor.store_document_embeddings(mock_message2)
await processor.store_document_embeddings('cache_user', mock_message2)
# Assert
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
@ -452,7 +444,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with chunks of different dimensions
mock_message = MagicMock()
mock_message.metadata.user = 'dim_user'
mock_message.metadata.collection = 'dim_collection'
mock_chunk1 = MagicMock()
@ -466,7 +457,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk1, mock_chunk2]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('dim_user', mock_message)
# Assert
# Should check existence of DIFFERENT collections for each dimension
@ -526,7 +517,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with URI-style chunk_id
mock_message = MagicMock()
mock_message.metadata.user = 'uri_user'
mock_message.metadata.collection = 'uri_collection'
mock_chunk = MagicMock()
@ -536,7 +526,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('uri_user', mock_message)
# Assert
# Verify the chunk_id was stored correctly

View file

@ -17,7 +17,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test entities with embeddings
@ -80,7 +79,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for a single entity"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -89,7 +87,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify insert was called once with the full vector
processor.vecstore.insert.assert_called_once()
@ -102,14 +100,14 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
@pytest.mark.asyncio
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
"""Test storing graph embeddings for multiple entities"""
await processor.store_graph_embeddings(mock_message)
await processor.store_graph_embeddings('test_workspace', mock_message)
# Verify insert was called once per entity with user/collection parameters
# Verify insert was called once per entity with workspace/collection parameters
expected_calls = [
# Entity 1 - single vector
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'),
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_workspace', 'test_collection'),
# Entity 2 - single vector
([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'),
([0.7, 0.8, 0.9], 'literal entity', 'test_workspace', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 2
@ -125,7 +123,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -134,7 +131,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called for empty entity
processor.vecstore.insert.assert_not_called()
@ -144,7 +141,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with None entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -153,7 +149,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called for None entity
processor.vecstore.insert.assert_not_called()
@ -163,7 +159,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with mix of valid and invalid entities"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
valid_entity = EntityEmbeddings(
@ -183,7 +178,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [valid_entity, empty_entity, none_entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify only valid entity was inserted with user/collection/chunk_id parameters
processor.vecstore.insert.assert_called_once_with(
@ -196,11 +191,10 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entities list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.entities = []
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called
processor.vecstore.insert.assert_not_called()
@ -210,7 +204,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for entity with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -219,7 +212,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called (no vectors to insert)
processor.vecstore.insert.assert_not_called()
@ -229,7 +222,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Each entity has a single vector of different dimensions
@ -247,7 +239,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [entity1, entity2, entity3]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify all vectors were inserted regardless of dimension
expected_calls = [
@ -267,7 +259,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for both URI and literal entities"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
uri_entity = EntityEmbeddings(
@ -280,7 +271,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [uri_entity, literal_entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify both entities were inserted
expected_calls = [

View file

@ -21,7 +21,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test entity embeddings (each entity has a single vector)
@ -124,7 +123,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for a single entity"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -139,7 +137,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1']):
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify index name and operations (with dimension suffix)
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
@ -189,7 +187,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test that writing to non-existent index creates it lazily"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -204,7 +201,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify index was created with correct dimension
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
@ -221,7 +218,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -233,7 +229,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no upsert was called for empty entity
mock_index.upsert.assert_not_called()
@ -243,7 +239,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with None entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -255,7 +250,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no upsert was called for None entity
mock_index.upsert.assert_not_called()
@ -265,7 +260,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Each entity has a single vector of different dimensions
@ -288,7 +282,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify different indexes were used for different dimensions
index_calls = processor.pinecone.Index.call_args_list
@ -307,14 +301,13 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entities list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.entities = []
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no operations were performed
processor.pinecone.Index.assert_not_called()
@ -325,7 +318,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for entity with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -337,7 +329,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no upsert was called (no vectors to insert)
mock_index.upsert.assert_not_called()
@ -347,7 +339,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test that lazy creation happens when index doesn't exist"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -362,7 +353,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify index was created
processor.pinecone.create_index.assert_called_once()
@ -372,7 +363,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test that lazy creation works correctly"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -387,7 +377,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify index was created and used
processor.pinecone.create_index.assert_called_once()

View file

@ -64,7 +64,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with entities and vectors
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_entity = MagicMock()
@ -75,7 +74,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity]
# Act
await processor.store_graph_embeddings(mock_message)
await processor.store_graph_embeddings('test_user', mock_message)
# Assert
# Verify collection existence was checked (with dimension suffix)
@ -118,7 +117,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with multiple entities
mock_message = MagicMock()
mock_message.metadata.user = 'multi_user'
mock_message.metadata.collection = 'multi_collection'
mock_entity1 = MagicMock()
@ -134,7 +132,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity1, mock_entity2]
# Act
await processor.store_graph_embeddings(mock_message)
await processor.store_graph_embeddings('multi_user', mock_message)
# Assert
# Should be called twice (once per entity)
@ -179,7 +177,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with three entities
mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection'
mock_entity1 = MagicMock()
@ -200,7 +197,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity1, mock_entity2, mock_entity3]
# Act
await processor.store_graph_embeddings(mock_message)
await processor.store_graph_embeddings('vector_user', mock_message)
# Assert
# Should be called 3 times (once per entity)
@ -238,7 +235,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with empty entity value
mock_message = MagicMock()
mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection'
mock_entity_empty = MagicMock()
@ -253,7 +249,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity_empty, mock_entity_none]
# Act
await processor.store_graph_embeddings(mock_message)
await processor.store_graph_embeddings('empty_user', mock_message)
# Assert
# Should not call upsert for empty entities

View file

@ -1,5 +1,5 @@
"""
Tests for Memgraph user/collection isolation in storage service
Tests for Memgraph workspace/collection isolation in storage service.
"""
import pytest
@ -8,47 +8,45 @@ from unittest.mock import MagicMock, patch
from trustgraph.storage.triples.memgraph.write import Processor
class TestMemgraphUserCollectionIsolation:
"""Test cases for Memgraph storage service with user/collection isolation"""
class TestMemgraphWorkspaceCollectionIsolation:
"""Test cases for Memgraph storage service with workspace/collection isolation"""
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_storage_creates_indexes_with_user_collection(self, mock_graph_db):
"""Test that storage creates both legacy and user/collection indexes"""
def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db):
"""Test that storage creates both legacy and workspace/collection indexes"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(taskgroup=MagicMock())
# Verify all indexes were attempted (4 legacy + 4 user/collection = 8 total)
# 4 legacy + 4 workspace/collection = 8 total
assert mock_session.run.call_count == 8
# Check some specific index creation calls
expected_calls = [
"CREATE INDEX ON :Node",
"CREATE INDEX ON :Node(uri)",
"CREATE INDEX ON :Literal",
"CREATE INDEX ON :Literal(value)",
"CREATE INDEX ON :Node(user)",
"CREATE INDEX ON :Node(workspace)",
"CREATE INDEX ON :Node(collection)",
"CREATE INDEX ON :Literal(user)",
"CREATE INDEX ON :Literal(workspace)",
"CREATE INDEX ON :Literal(collection)"
]
for expected_call in expected_calls:
mock_session.run.assert_any_call(expected_call)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_user_collection(self, mock_graph_db):
"""Test that store_triples includes user/collection in all operations"""
async def test_store_triples_with_workspace_collection(self, mock_graph_db):
"""Test that store_triples includes workspace/collection in all operations"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
@ -58,45 +56,39 @@ class TestMemgraphUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
# Create mock triple with URI object
from trustgraph.schema import IRI
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "http://example.com/object"
triple.o.is_uri = True
triple.s.type = IRI
triple.s.iri = "http://example.com/subject"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate"
triple.o.type = IRI
triple.o.iri = "http://example.com/object"
# Create mock message with metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Verify user/collection parameters were passed to all operations
# Should have: create_node (subject), create_node (object), relate_node = 3 calls
# create_node (subject), create_node (object), relate_node = 3 calls
assert mock_driver.execute_query.call_count == 3
# Check that user and collection were included in all calls
for call in mock_driver.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
assert 'user' in call_kwargs
assert 'collection' in call_kwargs
assert call_kwargs['user'] == "test_user"
assert call_kwargs['collection'] == "test_collection"
for c in mock_driver.execute_query.call_args_list:
kwargs = c.kwargs
assert kwargs['workspace'] == "test_workspace"
assert kwargs['collection'] == "test_collection"
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_default_user_collection(self, mock_graph_db):
"""Test that defaults are used when user/collection not provided in metadata"""
async def test_store_triples_with_default_collection(self, mock_graph_db):
"""Test that default collection is used when not provided in metadata"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
@ -106,157 +98,151 @@ class TestMemgraphUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
# Create mock triple
from trustgraph.schema import IRI, LITERAL
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.s.type = IRI
triple.s.iri = "http://example.com/subject"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate"
triple.o.type = LITERAL
triple.o.value = "literal_value"
triple.o.is_uri = False
# Create mock message without user/collection metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = None
mock_message.metadata.collection = None
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("default", mock_message)
# Verify defaults were used
for call in mock_driver.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
assert call_kwargs['user'] == "default"
assert call_kwargs['collection'] == "default"
for c in mock_driver.execute_query.call_args_list:
kwargs = c.kwargs
assert kwargs['workspace'] == "default"
assert kwargs['collection'] == "default"
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_create_node_includes_user_collection(self, mock_graph_db):
"""Test that create_node includes user/collection properties"""
def test_create_node_includes_workspace_collection(self, mock_graph_db):
"""Test that create_node includes workspace/collection properties"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
processor.create_node("http://example.com/node", "test_user", "test_collection")
processor.create_node("http://example.com/node", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/node",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="memgraph"
)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_create_literal_includes_user_collection(self, mock_graph_db):
"""Test that create_literal includes user/collection properties"""
def test_create_literal_includes_workspace_collection(self, mock_graph_db):
"""Test that create_literal includes workspace/collection properties"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
processor.create_literal("test_value", "test_user", "test_collection")
processor.create_literal("test_value", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="test_value",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="memgraph"
)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_relate_node_includes_user_collection(self, mock_graph_db):
"""Test that relate_node includes user/collection properties"""
def test_relate_node_includes_workspace_collection(self, mock_graph_db):
"""Test that relate_node includes workspace/collection properties"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
processor.relate_node(
"http://example.com/subject",
"http://example.com/predicate",
"http://example.com/predicate",
"http://example.com/object",
"test_user",
"test_workspace",
"test_collection"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject",
dest="http://example.com/object",
uri="http://example.com/predicate",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="memgraph"
)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_relate_literal_includes_user_collection(self, mock_graph_db):
"""Test that relate_literal includes user/collection properties"""
def test_relate_literal_includes_workspace_collection(self, mock_graph_db):
"""Test that relate_literal includes workspace/collection properties"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
processor.relate_literal(
"http://example.com/subject",
"http://example.com/predicate",
"literal_value",
"test_user",
"test_workspace",
"test_collection"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject",
dest="literal_value",
uri="http://example.com/predicate",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="memgraph"
)
@ -264,20 +250,15 @@ class TestMemgraphUserCollectionIsolation:
def test_add_args_includes_memgraph_parameters(self):
"""Test that add_args properly configures Memgraph-specific parameters"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added with Memgraph defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert args.graph_host == 'bolt://memgraph:7687'
assert hasattr(args, 'username')
@ -288,19 +269,18 @@ class TestMemgraphUserCollectionIsolation:
assert args.database == 'memgraph'
class TestMemgraphUserCollectionRegression:
"""Regression tests to ensure user/collection isolation prevents data leakage"""
class TestMemgraphWorkspaceCollectionRegression:
"""Regression tests to ensure workspace/collection isolation prevents data leakage"""
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio
async def test_regression_no_cross_user_data_access(self, mock_graph_db):
"""Regression test: Ensure users cannot access each other's data"""
async def test_regression_no_cross_workspace_data_access(self, mock_graph_db):
"""Regression test: Ensure workspaces cannot access each other's data"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
@ -310,60 +290,55 @@ class TestMemgraphUserCollectionRegression:
processor = Processor(taskgroup=MagicMock())
# Store data for user1
from trustgraph.schema import IRI, LITERAL
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "user1_data"
triple.o.is_uri = False
triple.s.type = IRI
triple.s.iri = "http://example.com/subject"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate"
triple.o.type = LITERAL
triple.o.value = "ws1_data"
message_user1 = MagicMock()
message_user1.triples = [triple]
message_user1.metadata.user = "user1"
message_user1.metadata.collection = "collection1"
message_ws1 = MagicMock()
message_ws1.triples = [triple]
message_ws1.metadata.collection = "collection1"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message_user1)
await processor.store_triples("workspace1", message_ws1)
# Verify that all storage operations included user1/collection1 parameters
for call in mock_driver.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
if 'user' in call_kwargs:
assert call_kwargs['user'] == "user1"
assert call_kwargs['collection'] == "collection1"
for c in mock_driver.execute_query.call_args_list:
kwargs = c.kwargs
if 'workspace' in kwargs:
assert kwargs['workspace'] == "workspace1"
assert kwargs['collection'] == "collection1"
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio
async def test_regression_same_uri_different_users(self, mock_graph_db):
"""Regression test: Same URI can exist for different users without conflict"""
async def test_regression_same_uri_different_workspaces(self, mock_graph_db):
"""Regression test: Same URI can exist in different workspaces without conflict"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
# Same URI for different users should create separate nodes
processor.create_node("http://example.com/same-uri", "user1", "collection1")
processor.create_node("http://example.com/same-uri", "user2", "collection2")
# Verify both calls were made with different user/collection parameters
calls = mock_driver.execute_query.call_args_list[-2:] # Get last 2 calls
call1_kwargs = calls[0].kwargs if hasattr(calls[0], 'kwargs') else calls[0][1]
call2_kwargs = calls[1].kwargs if hasattr(calls[1], 'kwargs') else calls[1][1]
assert call1_kwargs['user'] == "user1" and call1_kwargs['collection'] == "collection1"
assert call2_kwargs['user'] == "user2" and call2_kwargs['collection'] == "collection2"
# Both should have the same URI but different user/collection
assert call1_kwargs['uri'] == call2_kwargs['uri'] == "http://example.com/same-uri"
processor.create_node("http://example.com/same-uri", "workspace1", "collection1")
processor.create_node("http://example.com/same-uri", "workspace2", "collection2")
calls = mock_driver.execute_query.call_args_list[-2:]
k1 = calls[0].kwargs
k2 = calls[1].kwargs
assert k1['workspace'] == "workspace1" and k1['collection'] == "collection1"
assert k2['workspace'] == "workspace2" and k2['collection'] == "collection2"
assert k1['uri'] == k2['uri'] == "http://example.com/same-uri"

View file

@ -1,5 +1,5 @@
"""
Tests for Neo4j user/collection isolation in triples storage and query
Tests for Neo4j workspace/collection isolation in triples storage and query.
"""
import pytest
@ -11,468 +11,406 @@ from trustgraph.schema import Triples, Triple, Term, Metadata, IRI, LITERAL
from trustgraph.schema import TriplesQueryRequest
class TestNeo4jUserCollectionIsolation:
"""Test cases for Neo4j user/collection isolation functionality"""
class TestNeo4jWorkspaceCollectionIsolation:
"""Test cases for Neo4j workspace/collection isolation functionality"""
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_storage_creates_indexes_with_user_collection(self, mock_graph_db):
"""Test that storage service creates compound indexes for user/collection"""
def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db):
"""Test that storage service creates compound indexes for workspace/collection"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = StorageProcessor(taskgroup=taskgroup_mock)
# Verify both legacy and new compound indexes are created
expected_indexes = [
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
"CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
"CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)",
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)",
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)",
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)",
"CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)",
"CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)",
"CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)",
"CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)"
]
# Check that all expected indexes were created
for expected_query in expected_indexes:
mock_session.run.assert_any_call(expected_query)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_user_collection(self, mock_graph_db):
"""Test that triples are stored with user/collection properties"""
async def test_store_triples_with_workspace_collection(self, mock_graph_db):
"""Test that triples are stored with workspace/collection properties"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = StorageProcessor(taskgroup=taskgroup_mock)
# Create test message with user/collection metadata
metadata = Metadata(
id="test-id",
user="test_user",
collection="test_collection"
)
metadata = Metadata(id="test-id", collection="test_collection")
triple = Triple(
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="literal_value")
)
message = Triples(
metadata=metadata,
triples=[triple]
)
# Mock execute_query to return summaries
message = Triples(metadata=metadata, triples=[triple])
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify nodes and relationships were created with user/collection properties
await processor.store_triples("test_workspace", message)
expected_calls = [
call(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/subject",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_='neo4j'
),
call(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="literal_value",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_='neo4j'
),
call(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject",
dest="literal_value",
uri="http://example.com/predicate",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_='neo4j'
)
]
for expected_call in expected_calls:
mock_driver.execute_query.assert_any_call(*expected_call.args, **expected_call.kwargs)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_default_user_collection(self, mock_graph_db):
"""Test that default user/collection are used when not provided"""
async def test_store_triples_with_default_collection(self, mock_graph_db):
"""Test that default collection is used when not provided"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = StorageProcessor(taskgroup=taskgroup_mock)
# Create test message without user/collection
metadata = Metadata(id="test-id")
triple = Triple(
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=IRI, iri="http://example.com/object")
)
message = Triples(
metadata=metadata,
triples=[triple]
)
# Mock execute_query
message = Triples(metadata=metadata, triples=[triple])
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify defaults were used
await processor.store_triples("default", message)
mock_driver.execute_query.assert_any_call(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/subject",
user="default",
workspace="default",
collection="default",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_filters_by_user_collection(self, mock_graph_db):
"""Test that query service filters results by user/collection"""
async def test_query_triples_filters_by_workspace_collection(self, mock_graph_db):
"""Test that query service filters results by workspace/collection"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock())
# Create test query
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=None
)
# Mock query results
mock_records = [
MagicMock(data=lambda: {"dest": "http://example.com/object1"}),
MagicMock(data=lambda: {"dest": "literal_value"})
]
mock_driver.execute_query.return_value = (mock_records, MagicMock(), MagicMock())
result = await processor.query_triples(query)
# Verify queries include user/collection filters
await processor.query_triples("test_workspace", query)
expected_literal_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest"
)
expected_node_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN dest.uri as dest"
)
# Check that queries were executed with user/collection parameters
calls = mock_driver.execute_query.call_args_list
assert any(
expected_literal_query in str(call) and
"user='test_user'" in str(call) and
"collection='test_collection'" in str(call)
for call in calls
expected_literal_query in str(c) and
"workspace='test_workspace'" in str(c) and
"collection='test_collection'" in str(c)
for c in calls
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_with_default_user_collection(self, mock_graph_db):
"""Test that query service uses defaults when user/collection not provided"""
async def test_query_triples_with_default_collection(self, mock_graph_db):
"""Test that query service uses default collection when not provided"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock())
# Create test query without user/collection
query = TriplesQueryRequest(
s=None,
p=None,
o=None
)
# Mock empty results
query = TriplesQueryRequest(s=None, p=None, o=None)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
result = await processor.query_triples(query)
# Verify defaults were used in queries
await processor.query_triples("default", query)
calls = mock_driver.execute_query.call_args_list
assert any(
"user='default'" in str(call) and "collection='default'" in str(call)
for call in calls
"workspace='default'" in str(c) and "collection='default'" in str(c)
for c in calls
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_data_isolation_between_users(self, mock_graph_db):
"""Test that data from different users is properly isolated"""
async def test_data_isolation_between_workspaces(self, mock_graph_db):
"""Test that data from different workspaces is properly isolated"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = StorageProcessor(taskgroup=taskgroup_mock)
# Create messages for different users
message_user1 = Triples(
metadata=Metadata(user="user1", collection="coll1"),
message_ws1 = Triples(
metadata=Metadata(collection="coll1"),
triples=[
Triple(
s=Term(type=IRI, iri="http://example.com/user1/subject"),
s=Term(type=IRI, iri="http://example.com/ws1/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="user1_data")
o=Term(type=LITERAL, value="ws1_data")
)
]
)
message_user2 = Triples(
metadata=Metadata(user="user2", collection="coll2"),
message_ws2 = Triples(
metadata=Metadata(collection="coll2"),
triples=[
Triple(
s=Term(type=IRI, iri="http://example.com/user2/subject"),
s=Term(type=IRI, iri="http://example.com/ws2/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="user2_data")
o=Term(type=LITERAL, value="ws2_data")
)
]
)
# Mock execute_query
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
# Store data for both users
await processor.store_triples(message_user1)
await processor.store_triples(message_user2)
# Verify user1 data was stored with user1/coll1
await processor.store_triples("workspace1", message_ws1)
await processor.store_triples("workspace2", message_ws2)
mock_driver.execute_query.assert_any_call(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value="user1_data",
user="user1",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="ws1_data",
workspace="workspace1",
collection="coll1",
database_='neo4j'
)
# Verify user2 data was stored with user2/coll2
mock_driver.execute_query.assert_any_call(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value="user2_data",
user="user2",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="ws2_data",
workspace="workspace2",
collection="coll2",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_wildcard_query_respects_user_collection(self, mock_graph_db):
"""Test that wildcard queries still filter by user/collection"""
async def test_wildcard_query_respects_workspace_collection(self, mock_graph_db):
"""Test that wildcard queries still filter by workspace/collection"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock())
# Create wildcard query (all nulls) with user/collection
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
o=None
s=None, p=None, o=None,
)
# Mock results
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
result = await processor.query_triples(query)
# Verify wildcard queries include user/collection filters
await processor.query_triples("test_workspace", query)
wildcard_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest"
)
calls = mock_driver.execute_query.call_args_list
assert any(
wildcard_query in str(call) and
"user='test_user'" in str(call) and
"collection='test_collection'" in str(call)
for call in calls
wildcard_query in str(c) and
"workspace='test_workspace'" in str(c) and
"collection='test_collection'" in str(c)
for c in calls
)
def test_add_args_includes_neo4j_parameters(self):
"""Test that add_args includes Neo4j-specific parameters"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'):
StorageProcessor.add_args(parser)
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert hasattr(args, 'username')
assert hasattr(args, 'password')
assert hasattr(args, 'database')
# Check defaults
assert args.graph_host == 'bolt://neo4j:7687'
assert args.username == 'neo4j'
assert args.password == 'password'
assert args.database == 'neo4j'
class TestNeo4jUserCollectionRegression:
"""Regression tests to ensure user/collection isolation prevents data leaks"""
class TestNeo4jWorkspaceCollectionRegression:
"""Regression tests to ensure workspace/collection isolation prevents data leaks"""
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_regression_no_cross_user_data_access(self, mock_graph_db):
@pytest.mark.asyncio
async def test_regression_no_cross_workspace_data_access(self, mock_graph_db):
"""
Regression test: Ensure user1 cannot access user2's data
This test guards against the bug where all users shared the same
Neo4j graph space, causing data contamination between users.
Regression test: Ensure workspace1 cannot access workspace2's data.
Guards against a bug where all data shared the same Neo4j graph
space, causing data contamination between workspaces.
"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock())
# User1 queries for all triples
query_user1 = TriplesQueryRequest(
user="user1",
query_ws1 = TriplesQueryRequest(
collection="collection1",
s=None, p=None, o=None
)
# Mock that the database has data but none matching user1/collection1
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
result = await processor.query_triples(query_user1)
# Verify empty results (user1 cannot see other users' data)
result = await processor.query_triples("workspace1", query_ws1)
assert len(result) == 0
# Verify the query included user/collection filters
calls = mock_driver.execute_query.call_args_list
for call in calls:
query_str = str(call)
for c in calls:
query_str = str(c)
if "MATCH" in query_str:
assert "user: $user" in query_str or "user='user1'" in query_str
assert "workspace: $workspace" in query_str or "workspace='workspace1'" in query_str
assert "collection: $collection" in query_str or "collection='collection1'" in query_str
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_regression_same_uri_different_users(self, mock_graph_db):
async def test_regression_same_uri_different_workspaces(self, mock_graph_db):
"""
Regression test: Same URI in different user contexts should create separate nodes
This ensures that http://example.com/entity for user1 is completely separate
from http://example.com/entity for user2.
Regression test: Same URI in different workspace contexts should create separate nodes.
Ensures http://example.com/entity in workspace1 is completely
separate from the same URI in workspace2.
"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = StorageProcessor(taskgroup=taskgroup_mock)
# Same URI for different users
shared_uri = "http://example.com/shared_entity"
message_user1 = Triples(
metadata=Metadata(user="user1", collection="coll1"),
message_ws1 = Triples(
metadata=Metadata(collection="coll1"),
triples=[
Triple(
s=Term(type=IRI, iri=shared_uri),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="user1_value")
o=Term(type=LITERAL, value="ws1_value")
)
]
)
message_user2 = Triples(
metadata=Metadata(user="user2", collection="coll2"),
message_ws2 = Triples(
metadata=Metadata(collection="coll2"),
triples=[
Triple(
s=Term(type=IRI, iri=shared_uri),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="user2_value")
o=Term(type=LITERAL, value="ws2_value")
)
]
)
# Mock execute_query
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message_user1)
await processor.store_triples(message_user2)
# Verify two separate nodes were created with same URI but different user/collection
user1_node_call = call(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
await processor.store_triples("workspace1", message_ws1)
await processor.store_triples("workspace2", message_ws2)
ws1_node_call = call(
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=shared_uri,
user="user1",
workspace="workspace1",
collection="coll1",
database_='neo4j'
)
user2_node_call = call(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
ws2_node_call = call(
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=shared_uri,
user="user2",
workspace="workspace2",
collection="coll2",
database_='neo4j'
)
mock_driver.execute_query.assert_has_calls([user1_node_call, user2_node_call], any_order=True)
mock_driver.execute_query.assert_has_calls([ws1_node_call, ws2_node_call], any_order=True)

View file

@ -1,3 +1,12 @@
def _flow_mock(workspace):
"""Build a mock flow object that is callable and exposes .workspace."""
from unittest.mock import MagicMock
f = MagicMock()
f.workspace = workspace
return f
"""
Unit tests for trustgraph.storage.row_embeddings.qdrant.write
Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant.
@ -92,13 +101,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config)
collection_name = processor.get_collection_name(
user="test_user",
workspace="test_workspace",
collection="test_collection",
schema_name="customer_data",
dimension=384
)
assert collection_name == "rows_test_user_test_collection_customer_data_384"
assert collection_name == "rows_test_workspace_test_collection_customer_data_384"
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_creates_new(self, mock_qdrant_client):
@ -185,11 +194,10 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
processor.known_collections[('test_workspace', 'test_collection')] = {}
# Create embeddings message
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
@ -210,14 +218,14 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Verify upsert was called
mock_qdrant_instance.upsert.assert_called_once()
# Verify upsert parameters
upsert_call_args = mock_qdrant_instance.upsert.call_args
assert upsert_call_args[1]['collection_name'] == 'rows_test_user_test_collection_customers_3'
assert upsert_call_args[1]['collection_name'] == 'rows_test_workspace_test_collection_customers_3'
point = upsert_call_args[1]['points'][0]
assert point.vector == [0.1, 0.2, 0.3]
@ -243,10 +251,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
processor.known_collections[('test_workspace', 'test_collection')] = {}
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
@ -267,7 +274,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Should be called once for the single embedding
assert mock_qdrant_instance.upsert.call_count == 1
@ -287,10 +294,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
processor.known_collections[('test_workspace', 'test_collection')] = {}
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
@ -311,7 +317,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Should not call upsert for empty vectors
mock_qdrant_instance.upsert.assert_not_called()
@ -334,7 +340,6 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
# No collections registered
metadata = MagicMock()
metadata.user = 'unknown_user'
metadata.collection = 'unknown_collection'
metadata.id = 'doc-123'
@ -354,7 +359,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Should not call upsert for unknown collection
mock_qdrant_instance.upsert.assert_not_called()
@ -368,11 +373,11 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
# Mock collections list
mock_coll1 = MagicMock()
mock_coll1.name = 'rows_test_user_test_collection_schema1_384'
mock_coll1.name = 'rows_test_workspace_test_collection_schema1_384'
mock_coll2 = MagicMock()
mock_coll2.name = 'rows_test_user_test_collection_schema2_384'
mock_coll2.name = 'rows_test_workspace_test_collection_schema2_384'
mock_coll3 = MagicMock()
mock_coll3.name = 'rows_other_user_other_collection_schema_384'
mock_coll3.name = 'rows_other_workspace_other_collection_schema_384'
mock_collections = MagicMock()
mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3]
@ -386,15 +391,15 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.created_collections.add('rows_test_user_test_collection_schema1_384')
processor.created_collections.add('rows_test_workspace_test_collection_schema1_384')
await processor.delete_collection('test_user', 'test_collection')
await processor.delete_collection('test_workspace', 'test_collection')
# Should delete only the matching collections
assert mock_qdrant_instance.delete_collection.call_count == 2
# Verify the cached collection was removed
assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections
assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_delete_collection_schema(self, mock_qdrant_client):
@ -404,9 +409,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_qdrant_instance = MagicMock()
mock_coll1 = MagicMock()
mock_coll1.name = 'rows_test_user_test_collection_customers_384'
mock_coll1.name = 'rows_test_workspace_test_collection_customers_384'
mock_coll2 = MagicMock()
mock_coll2.name = 'rows_test_user_test_collection_orders_384'
mock_coll2.name = 'rows_test_workspace_test_collection_orders_384'
mock_collections = MagicMock()
mock_collections.collections = [mock_coll1, mock_coll2]
@ -422,13 +427,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config)
await processor.delete_collection_schema(
'test_user', 'test_collection', 'customers'
'test_workspace', 'test_collection', 'customers'
)
# Should only delete the customers schema collection
mock_qdrant_instance.delete_collection.assert_called_once()
call_args = mock_qdrant_instance.delete_collection.call_args[0]
assert call_args[0] == 'rows_test_user_test_collection_customers_384'
assert call_args[0] == 'rows_test_workspace_test_collection_customers_384'
if __name__ == '__main__':

View file

@ -17,6 +17,17 @@ from trustgraph.storage.rows.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class _MockFlowDefault:
"""Mock Flow with default workspace for testing."""
workspace = "default"
name = "default"
id = "test-processor"
mock_flow_default = _MockFlowDefault()
class TestRowsCassandraStorageLogic:
"""Test business logic for unified table implementation"""
@ -145,11 +156,11 @@ class TestRowsCassandraStorageLogic:
}
# Process configuration
await processor.on_schema_config(config, version=1)
await processor.on_schema_config("default", config, version=1)
# Verify schema was loaded
assert "customer_records" in processor.schemas
schema = processor.schemas["customer_records"]
assert "customer_records" in processor.schemas["default"]
schema = processor.schemas["default"]["customer_records"]
assert schema.name == "customer_records"
assert len(schema.fields) == 3
@ -160,20 +171,23 @@ class TestRowsCassandraStorageLogic:
assert id_field.primary is True
@pytest.mark.asyncio
async def test_object_processing_stores_data_map(self):
@patch('trustgraph.storage.rows.cassandra.write.async_execute', new_callable=AsyncMock)
async def test_object_processing_stores_data_map(self, mock_async_execute):
"""Test that row processing stores data as map<text, text>"""
processor = MagicMock()
processor.schemas = {
"test_schema": RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="value", type="string", size=100)
]
)
"default": {
"test_schema": RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="value", type="string", size=100)
]
)
}
}
processor.tables_initialized = {"test_user"}
processor.tables_initialized = {"default"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@ -184,11 +198,12 @@ class TestRowsCassandraStorageLogic:
processor.collection_exists = MagicMock(return_value=True)
processor.on_object = Processor.on_object.__get__(processor, Processor)
mock_async_execute.return_value = []
# Create test object
test_obj = ExtractedObject(
metadata=Metadata(
id="test-001",
user="test_user",
collection="test_collection",
),
schema_name="test_schema",
@ -202,16 +217,16 @@ class TestRowsCassandraStorageLogic:
msg.value.return_value = test_obj
# Process object
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify insert was executed
processor.session.execute.assert_called()
insert_call = processor.session.execute.call_args
insert_cql = insert_call[0][0]
values = insert_call[0][1]
mock_async_execute.assert_called()
insert_call = mock_async_execute.call_args
insert_cql = insert_call[0][1]
values = insert_call[0][2]
# Verify using unified rows table
assert "INSERT INTO test_user.rows" in insert_cql
assert "INSERT INTO default.rows" in insert_cql
# Values should be: (collection, schema_name, index_name, index_value, data, source)
assert values[0] == "test_collection" # collection
@ -222,20 +237,23 @@ class TestRowsCassandraStorageLogic:
assert values[5] == "" # source
@pytest.mark.asyncio
async def test_object_processing_multiple_indexes(self):
@patch('trustgraph.storage.rows.cassandra.write.async_execute', new_callable=AsyncMock)
async def test_object_processing_multiple_indexes(self, mock_async_execute):
"""Test that row is written once per indexed field"""
processor = MagicMock()
processor.schemas = {
"multi_index_schema": RowSchema(
name="multi_index_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="status", type="string", indexed=True)
]
)
"default": {
"multi_index_schema": RowSchema(
name="multi_index_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="status", type="string", indexed=True)
]
)
}
}
processor.tables_initialized = {"test_user"}
processor.tables_initialized = {"default"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@ -246,10 +264,11 @@ class TestRowsCassandraStorageLogic:
processor.collection_exists = MagicMock(return_value=True)
processor.on_object = Processor.on_object.__get__(processor, Processor)
mock_async_execute.return_value = []
test_obj = ExtractedObject(
metadata=Metadata(
id="test-001",
user="test_user",
collection="test_collection",
),
schema_name="multi_index_schema",
@ -261,15 +280,15 @@ class TestRowsCassandraStorageLogic:
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Should have 3 inserts (one per indexed field: id, category, status)
assert processor.session.execute.call_count == 3
assert mock_async_execute.call_count == 3
# Check that different index_names were used
index_names_used = set()
for call in processor.session.execute.call_args_list:
values = call[0][1]
for call in mock_async_execute.call_args_list:
values = call[0][2]
index_names_used.add(values[2]) # index_name is 3rd value
assert index_names_used == {"id", "category", "status"}
@ -279,19 +298,22 @@ class TestRowsCassandraStorageBatchLogic:
"""Test batch processing logic for unified table implementation"""
@pytest.mark.asyncio
async def test_batch_object_processing(self):
@patch('trustgraph.storage.rows.cassandra.write.async_execute', new_callable=AsyncMock)
async def test_batch_object_processing(self, mock_async_execute):
"""Test processing of batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"batch_schema": RowSchema(
name="batch_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string")
]
)
"default": {
"batch_schema": RowSchema(
name="batch_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string")
]
)
}
}
processor.tables_initialized = {"test_user"}
processor.tables_initialized = {"default"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@ -302,11 +324,12 @@ class TestRowsCassandraStorageBatchLogic:
processor.collection_exists = MagicMock(return_value=True)
processor.on_object = Processor.on_object.__get__(processor, Processor)
mock_async_execute.return_value = []
# Create batch object with multiple values
batch_obj = ExtractedObject(
metadata=Metadata(
id="batch-001",
user="test_user",
collection="batch_collection",
),
schema_name="batch_schema",
@ -322,15 +345,15 @@ class TestRowsCassandraStorageBatchLogic:
msg = MagicMock()
msg.value.return_value = batch_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Should have 3 inserts (one per row, one index per row since only primary key)
assert processor.session.execute.call_count == 3
assert mock_async_execute.call_count == 3
# Check each insert has different id
ids_inserted = set()
for call in processor.session.execute.call_args_list:
values = call[0][1]
for call in mock_async_execute.call_args_list:
values = call[0][2]
ids_inserted.add(tuple(values[3])) # index_value is 4th value
assert ids_inserted == {("001",), ("002",), ("003",)}
@ -340,12 +363,14 @@ class TestRowsCassandraStorageBatchLogic:
"""Test processing of empty batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"empty_schema": RowSchema(
name="empty_schema",
fields=[Field(name="id", type="string", primary=True)]
)
"default": {
"empty_schema": RowSchema(
name="empty_schema",
fields=[Field(name="id", type="string", primary=True)]
)
}
}
processor.tables_initialized = {"test_user"}
processor.tables_initialized = {"default"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@ -360,7 +385,6 @@ class TestRowsCassandraStorageBatchLogic:
empty_batch_obj = ExtractedObject(
metadata=Metadata(
id="empty-001",
user="test_user",
collection="empty_collection",
),
schema_name="empty_schema",
@ -372,7 +396,7 @@ class TestRowsCassandraStorageBatchLogic:
msg = MagicMock()
msg.value.return_value = empty_batch_obj
await processor.on_object(msg, None, None)
await processor.on_object(msg, None, mock_flow_default)
# Verify no insert calls for empty batch
processor.session.execute.assert_not_called()
@ -437,19 +461,21 @@ class TestPartitionRegistration:
processor.registered_partitions = set()
processor.session = MagicMock()
processor.schemas = {
"test_schema": RowSchema(
name="test_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True)
]
)
"default": {
"test_schema": RowSchema(
name="test_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True)
]
)
}
}
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
processor.register_partitions("test_user", "test_collection", "test_schema")
processor.register_partitions("test_user", "test_collection", "test_schema", "default")
# Should have 2 inserts (one per index: id, category)
assert processor.session.execute.call_count == 2
@ -464,7 +490,7 @@ class TestPartitionRegistration:
processor.session = MagicMock()
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
processor.register_partitions("test_user", "test_collection", "test_schema")
processor.register_partitions("test_user", "test_collection", "test_schema", "default")
# Should not execute any CQL since already registered
processor.session.execute.assert_not_called()

View file

@ -102,11 +102,10 @@ class TestCassandraStorageProcessor:
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
# Verify KnowledgeGraph was called with auth parameters
mock_kg_class.assert_called_once_with(
@ -129,11 +128,10 @@ class TestCassandraStorageProcessor:
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user2'
mock_message.metadata.collection = 'collection2'
mock_message.triples = []
await processor.store_triples(mock_message)
await processor.store_triples('user2', mock_message)
# Verify KnowledgeGraph was called without auth parameters
mock_kg_class.assert_called_once_with(
@ -154,16 +152,15 @@ class TestCassandraStorageProcessor:
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
# First call should create TrustGraph
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
assert mock_kg_class.call_count == 1
# Second call with same table should reuse TrustGraph
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
assert mock_kg_class.call_count == 1 # Should not increase
@pytest.mark.asyncio
@ -205,11 +202,10 @@ class TestCassandraStorageProcessor:
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = [triple1, triple2]
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
assert mock_tg_instance.insert.call_count == 2
@ -234,11 +230,10 @@ class TestCassandraStorageProcessor:
# Create mock message with empty triples
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
# Verify no triples were inserted
mock_tg_instance.insert.assert_not_called()
@ -255,12 +250,11 @@ class TestCassandraStorageProcessor:
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
# Verify sleep was called before re-raising
mock_sleep.assert_called_once_with(1)
@ -361,21 +355,19 @@ class TestCassandraStorageProcessor:
# First message with table1
mock_message1 = MagicMock()
mock_message1.metadata.user = 'user1'
mock_message1.metadata.collection = 'collection1'
mock_message1.triples = []
await processor.store_triples(mock_message1)
await processor.store_triples('user1', mock_message1)
assert processor.table == 'user1'
assert processor.tg == mock_tg_instance1
# Second message with different table
mock_message2 = MagicMock()
mock_message2.metadata.user = 'user2'
mock_message2.metadata.collection = 'collection2'
mock_message2.triples = []
await processor.store_triples(mock_message2)
await processor.store_triples('user2', mock_message2)
assert processor.table == 'user2'
assert processor.tg == mock_tg_instance2
@ -407,11 +399,10 @@ class TestCassandraStorageProcessor:
triple.g = None
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_message.triples = [triple]
await processor.store_triples(mock_message)
await processor.store_triples('test_workspace', mock_message)
# Verify the triple was inserted with special characters preserved
mock_tg_instance.insert.assert_called_once_with(
@ -440,12 +431,11 @@ class TestCassandraStorageProcessor:
mock_kg_class.side_effect = Exception("Connection failed")
mock_message = MagicMock()
mock_message.metadata.user = 'new_user'
mock_message.metadata.collection = 'new_collection'
mock_message.triples = []
with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples(mock_message)
await processor.store_triples('new_user', mock_message)
# Table should remain unchanged since self.table = table happens after try/except
assert processor.table == ('old_user', 'old_collection')
@ -468,11 +458,10 @@ class TestCassandraPerformanceOptimizations:
processor = Processor(taskgroup=taskgroup_mock)
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
# Verify KnowledgeGraph instance uses legacy mode
assert mock_tg_instance is not None
@ -489,11 +478,10 @@ class TestCassandraPerformanceOptimizations:
processor = Processor(taskgroup=taskgroup_mock)
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
# Verify KnowledgeGraph instance is in optimized mode
assert mock_tg_instance is not None
@ -523,11 +511,10 @@ class TestCassandraPerformanceOptimizations:
triple.g = None
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = [triple]
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
mock_tg_instance.insert.assert_called_once_with(

View file

@ -17,7 +17,6 @@ class TestFalkorDBStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a test triple
@ -89,13 +88,13 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result
processor.create_node(test_uri, 'test_user', 'test_collection')
processor.create_node(test_uri, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
params={
"uri": test_uri,
"user": 'test_user',
"workspace": 'test_workspace',
"collection": 'test_collection',
},
)
@ -109,13 +108,13 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result
processor.create_literal(test_value, 'test_user', 'test_collection')
processor.create_literal(test_value, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
params={
"value": test_value,
"user": 'test_user',
"workspace": 'test_workspace',
"collection": 'test_collection',
},
)
@ -132,17 +131,17 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result
processor.relate_node(src_uri, pred_uri, dest_uri, 'test_user', 'test_collection')
processor.relate_node(src_uri, pred_uri, dest_uri, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
params={
"src": src_uri,
"dest": dest_uri,
"uri": pred_uri,
"user": 'test_user',
"workspace": 'test_workspace',
"collection": 'test_collection',
},
)
@ -159,17 +158,17 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result
processor.relate_literal(src_uri, pred_uri, literal_value, 'test_user', 'test_collection')
processor.relate_literal(src_uri, pred_uri, literal_value, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
params={
"src": src_uri,
"dest": literal_value,
"uri": pred_uri,
"user": 'test_user',
"workspace": 'test_workspace',
"collection": 'test_collection',
},
)
@ -179,7 +178,6 @@ class TestFalkorDBStorageProcessor:
"""Test storing triple with URI object"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple = Triple(
@ -200,21 +198,21 @@ class TestFalkorDBStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
await processor.store_triples('test_workspace', message)
# Verify queries were called in the correct order
expected_calls = [
# Create subject node
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
{"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}),
(("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
{"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection"}}),
# Create object node
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
{"params": {"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection"}}),
(("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
{"params": {"uri": "http://example.com/object", "workspace": "test_workspace", "collection": "test_collection"}}),
# Create relationship
(("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}),
(("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}),
]
assert processor.io.query.call_count == 3
@ -237,21 +235,21 @@ class TestFalkorDBStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples('test_workspace', mock_message)
# Verify queries were called in the correct order
expected_calls = [
# Create subject node
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
{"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}),
(("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
{"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection"}}),
# Create literal object
(("MERGE (n:Literal {value: $value, user: $user, collection: $collection})",),
{"params": {"value": "literal object", "user": "test_user", "collection": "test_collection"}}),
(("MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",),
{"params": {"value": "literal object", "workspace": "test_workspace", "collection": "test_collection"}}),
# Create relationship
(("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}),
(("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}),
]
assert processor.io.query.call_count == 3
@ -265,7 +263,6 @@ class TestFalkorDBStorageProcessor:
"""Test storing multiple triples"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple1 = Triple(
@ -291,7 +288,7 @@ class TestFalkorDBStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
await processor.store_triples('test_workspace', message)
# Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6
@ -313,7 +310,6 @@ class TestFalkorDBStorageProcessor:
"""Test storing empty triples list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.triples = []
@ -323,7 +319,7 @@ class TestFalkorDBStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
await processor.store_triples('test_workspace', message)
# Verify no queries were made
processor.io.query.assert_not_called()
@ -333,7 +329,6 @@ class TestFalkorDBStorageProcessor:
"""Test storing triples with mixed URI and literal objects"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple1 = Triple(
@ -359,7 +354,7 @@ class TestFalkorDBStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
await processor.store_triples('test_workspace', message)
# Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6
@ -450,13 +445,13 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result
processor.create_node(test_uri, 'test_user', 'test_collection')
processor.create_node(test_uri, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
params={
"uri": test_uri,
"user": 'test_user',
"workspace": 'test_workspace',
"collection": 'test_collection',
},
)
@ -470,13 +465,13 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result
processor.create_literal(test_value, 'test_user', 'test_collection')
processor.create_literal(test_value, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
params={
"value": test_value,
"user": 'test_user',
"workspace": 'test_workspace',
"collection": 'test_collection',
},
)

View file

@ -17,7 +17,6 @@ class TestMemgraphStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a test triple
@ -43,7 +42,7 @@ class TestMemgraphStorageProcessor:
taskgroup=MagicMock(),
id='test-memgraph-storage',
graph_host='bolt://localhost:7687',
username='test_user',
username='test_workspace',
password='test_pass',
database='test_db'
)
@ -105,9 +104,9 @@ class TestMemgraphStorageProcessor:
"CREATE INDEX ON :Node(uri)",
"CREATE INDEX ON :Literal",
"CREATE INDEX ON :Literal(value)",
"CREATE INDEX ON :Node(user)",
"CREATE INDEX ON :Node(workspace)",
"CREATE INDEX ON :Node(collection)",
"CREATE INDEX ON :Literal(user)",
"CREATE INDEX ON :Literal(workspace)",
"CREATE INDEX ON :Literal(collection)"
]
@ -145,12 +144,12 @@ class TestMemgraphStorageProcessor:
processor.io.execute_query.return_value = mock_result
processor.create_node(test_uri, "test_user", "test_collection")
processor.create_node(test_uri, "test_workspace", "test_collection")
processor.io.execute_query.assert_called_once_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=test_uri,
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_=processor.db
)
@ -166,12 +165,12 @@ class TestMemgraphStorageProcessor:
processor.io.execute_query.return_value = mock_result
processor.create_literal(test_value, "test_user", "test_collection")
processor.create_literal(test_value, "test_workspace", "test_collection")
processor.io.execute_query.assert_called_once_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value=test_value,
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_=processor.db
)
@ -190,14 +189,14 @@ class TestMemgraphStorageProcessor:
processor.io.execute_query.return_value = mock_result
processor.relate_node(src_uri, pred_uri, dest_uri, "test_user", "test_collection")
processor.relate_node(src_uri, pred_uri, dest_uri, "test_workspace", "test_collection")
processor.io.execute_query.assert_called_once_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src_uri, dest=dest_uri, uri=pred_uri,
user="test_user", collection="test_collection",
workspace="test_workspace", collection="test_collection",
database_=processor.db
)
@ -215,14 +214,14 @@ class TestMemgraphStorageProcessor:
processor.io.execute_query.return_value = mock_result
processor.relate_literal(src_uri, pred_uri, literal_value, "test_user", "test_collection")
processor.relate_literal(src_uri, pred_uri, literal_value, "test_workspace", "test_collection")
processor.io.execute_query.assert_called_once_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src_uri, dest=literal_value, uri=pred_uri,
user="test_user", collection="test_collection",
workspace="test_workspace", collection="test_collection",
database_=processor.db
)
@ -236,22 +235,22 @@ class TestMemgraphStorageProcessor:
o=Term(type=IRI, iri='http://example.com/object')
)
processor.create_triple(mock_tx, triple, "test_user", "test_collection")
processor.create_triple(mock_tx, triple, "test_workspace", "test_collection")
# Verify transaction calls
expected_calls = [
# Create subject node
("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
{'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}),
("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{'uri': 'http://example.com/subject', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
# Create object node
("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
{'uri': 'http://example.com/object', 'user': 'test_user', 'collection': 'test_collection'}),
("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{'uri': 'http://example.com/object', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
# Create relationship
("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
{'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate',
'user': 'test_user', 'collection': 'test_collection'})
'workspace': 'test_workspace', 'collection': 'test_collection'})
]
assert mock_tx.run.call_count == 3
@ -270,22 +269,22 @@ class TestMemgraphStorageProcessor:
o=Term(type=LITERAL, value='literal object')
)
processor.create_triple(mock_tx, triple, "test_user", "test_collection")
processor.create_triple(mock_tx, triple, "test_workspace", "test_collection")
# Verify transaction calls
expected_calls = [
# Create subject node
("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
{'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}),
("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{'uri': 'http://example.com/subject', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
# Create literal object
("MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
{'value': 'literal object', 'user': 'test_user', 'collection': 'test_collection'}),
("MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
{'value': 'literal object', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
# Create relationship
("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
{'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate',
'user': 'test_user', 'collection': 'test_collection'})
'workspace': 'test_workspace', 'collection': 'test_collection'})
]
assert mock_tx.run.call_count == 3
@ -314,8 +313,8 @@ class TestMemgraphStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples('test_workspace', mock_message)
# Verify execute_query was called for create_node, create_literal, and relate_literal
# (since mock_message has a literal object)
assert processor.io.execute_query.call_count == 3
@ -323,7 +322,7 @@ class TestMemgraphStorageProcessor:
# Verify user/collection parameters were included
for call in processor.io.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
assert 'user' in call_kwargs
assert 'workspace' in call_kwargs
assert 'collection' in call_kwargs
@pytest.mark.asyncio
@ -343,7 +342,6 @@ class TestMemgraphStorageProcessor:
# Create message with multiple triples
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple1 = Triple(
@ -364,7 +362,7 @@ class TestMemgraphStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
await processor.store_triples('test_workspace', message)
# Verify execute_query was called:
# Triple1: create_node(s) + create_literal(o) + relate_literal = 3 calls
@ -375,7 +373,7 @@ class TestMemgraphStorageProcessor:
# Verify user/collection parameters were included in all calls
for call in processor.io.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
assert call_kwargs['user'] == 'test_user'
assert call_kwargs['workspace'] == 'test_workspace'
assert call_kwargs['collection'] == 'test_collection'
@pytest.mark.asyncio
@ -389,7 +387,6 @@ class TestMemgraphStorageProcessor:
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.triples = []
@ -399,7 +396,7 @@ class TestMemgraphStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
await processor.store_triples('test_workspace', message)
# Verify no session calls were made (no triples to process)
processor.io.session.assert_not_called()

View file

@ -68,9 +68,9 @@ class TestNeo4jStorageProcessor:
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
"CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
"CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)",
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)",
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)",
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)",
"CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)",
"CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)",
"CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)",
"CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)"
]
@ -116,12 +116,12 @@ class TestNeo4jStorageProcessor:
processor = Processor(taskgroup=taskgroup_mock)
# Test create_node
processor.create_node("http://example.com/node", "test_user", "test_collection")
processor.create_node("http://example.com/node", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/node",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
@ -146,12 +146,12 @@ class TestNeo4jStorageProcessor:
processor = Processor(taskgroup=taskgroup_mock)
# Test create_literal
processor.create_literal("literal value", "test_user", "test_collection")
processor.create_literal("literal value", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="literal value",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
@ -180,18 +180,18 @@ class TestNeo4jStorageProcessor:
"http://example.com/subject",
"http://example.com/predicate",
"http://example.com/object",
"test_user",
"test_workspace",
"test_collection"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject",
dest="http://example.com/object",
uri="http://example.com/predicate",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
@ -220,18 +220,18 @@ class TestNeo4jStorageProcessor:
"http://example.com/subject",
"http://example.com/predicate",
"literal value",
"test_user",
"test_workspace",
"test_collection"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject",
dest="literal value",
uri="http://example.com/predicate",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
@ -268,36 +268,35 @@ class TestNeo4jStorageProcessor:
# Create mock message with metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Verify create_node was called for subject and object
# Verify relate_node was called
expected_calls = [
# Subject node creation
(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
{"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
),
# Object node creation
(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
{"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{"uri": "http://example.com/object", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
),
# Relationship creation
(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
{
"src": "http://example.com/subject",
"dest": "http://example.com/object",
"uri": "http://example.com/predicate",
"user": "test_user",
"workspace": "test_workspace",
"collection": "test_collection",
"database_": "neo4j"
}
@ -340,12 +339,11 @@ class TestNeo4jStorageProcessor:
# Create mock message with metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Verify create_node was called for subject
# Verify create_literal was called for object
@ -353,24 +351,24 @@ class TestNeo4jStorageProcessor:
expected_calls = [
# Subject node creation
(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
{"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
),
# Literal creation
(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
{"value": "literal value", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
{"value": "literal value", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
),
# Relationship creation
(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
{
"src": "http://example.com/subject",
"dest": "literal value",
"uri": "http://example.com/predicate",
"user": "test_user",
"workspace": "test_workspace",
"collection": "test_collection",
"database_": "neo4j"
}
@ -421,12 +419,11 @@ class TestNeo4jStorageProcessor:
# Create mock message with metadata
mock_message = MagicMock()
mock_message.triples = [triple1, triple2]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Should have processed both triples
# Triple1: 2 nodes + 1 relationship = 3 calls
@ -449,12 +446,11 @@ class TestNeo4jStorageProcessor:
# Create mock message with empty triples and metadata
mock_message = MagicMock()
mock_message.triples = []
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Should not have made any execute_query calls beyond index creation
# Only index creation calls should have been made during initialization
@ -568,38 +564,37 @@ class TestNeo4jStorageProcessor:
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Verify the triple was processed with special characters preserved
mock_driver.execute_query.assert_any_call(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/subject with spaces",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
mock_driver.execute_query.assert_any_call(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value='literal with "quotes" and unicode: ñáéíóú',
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
mock_driver.execute_query.assert_any_call(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject with spaces",
dest='literal with "quotes" and unicode: ñáéíóú',
uri="http://example.com/predicate:with/symbols",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)

View file

@ -24,11 +24,10 @@ def _make_processor(qdrant_client=None):
return proc
def _make_request(vector=None, user="test-user", collection="test-col",
def _make_request(vector=None, collection="test-col",
schema_name="customers", limit=10, index_name=None):
return RowEmbeddingsRequest(
vector=vector or [0.1, 0.2, 0.3],
user=user,
collection=collection,
schema_name=schema_name,
limit=limit,
@ -36,6 +35,14 @@ def _make_request(vector=None, user="test-user", collection="test-col",
)
def _make_flow(workspace="test-workspace", pub=None):
"""Make a mock flow object that is callable and has .workspace."""
flow = MagicMock()
flow.return_value = pub if pub is not None else AsyncMock()
flow.workspace = workspace
return flow
def _make_search_point(index_name, index_value, text, score):
point = MagicMock()
point.payload = {
@ -85,34 +92,33 @@ class TestFindCollection:
def test_finds_matching_collection(self):
proc = _make_processor()
mock_coll = MagicMock()
mock_coll.name = "rows_test_user_test_col_customers_384"
mock_coll.name = "rows_test_workspace_test_col_customers_384"
mock_collections = MagicMock()
mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-user", "test-col", "customers")
result = proc.find_collection("test-workspace", "test-col", "customers")
# Prefix: rows_test_user_test_col_customers_
assert result == "rows_test_user_test_col_customers_384"
assert result == "rows_test_workspace_test_col_customers_384"
def test_returns_none_when_no_match(self):
proc = _make_processor()
mock_coll = MagicMock()
mock_coll.name = "rows_other_user_other_col_schema_768"
mock_coll.name = "rows_other_workspace_other_col_schema_768"
mock_collections = MagicMock()
mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-user", "test-col", "customers")
result = proc.find_collection("test-workspace", "test-col", "customers")
assert result is None
def test_returns_none_on_error(self):
proc = _make_processor()
proc.qdrant.get_collections.side_effect = Exception("connection error")
result = proc.find_collection("user", "col", "schema")
result = proc.find_collection("workspace", "col", "schema")
assert result is None
@ -127,7 +133,7 @@ class TestQueryRowEmbeddings:
proc = _make_processor()
request = _make_request(vector=[])
result = await proc.query_row_embeddings(request)
result = await proc.query_row_embeddings("test-workspace", request)
assert result == []
@pytest.mark.asyncio
@ -136,13 +142,13 @@ class TestQueryRowEmbeddings:
proc.find_collection = MagicMock(return_value=None)
request = _make_request()
result = await proc.query_row_embeddings(request)
result = await proc.query_row_embeddings("test-workspace", request)
assert result == []
@pytest.mark.asyncio
async def test_successful_query_returns_matches(self):
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
points = [
_make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
@ -153,7 +159,7 @@ class TestQueryRowEmbeddings:
proc.qdrant.query_points.return_value = mock_result
request = _make_request()
result = await proc.query_row_embeddings(request)
result = await proc.query_row_embeddings("test-workspace", request)
assert len(result) == 2
assert isinstance(result[0], RowIndexMatch)
@ -166,14 +172,14 @@ class TestQueryRowEmbeddings:
async def test_index_name_filter_applied(self):
"""When index_name is specified, a Qdrant filter should be used."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
mock_result = MagicMock()
mock_result.points = []
proc.qdrant.query_points.return_value = mock_result
request = _make_request(index_name="address")
await proc.query_row_embeddings(request)
await proc.query_row_embeddings("test-workspace", request)
call_kwargs = proc.qdrant.query_points.call_args[1]
assert call_kwargs["query_filter"] is not None
@ -182,14 +188,14 @@ class TestQueryRowEmbeddings:
async def test_no_index_name_no_filter(self):
"""When index_name is empty, no filter should be applied."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
mock_result = MagicMock()
mock_result.points = []
proc.qdrant.query_points.return_value = mock_result
request = _make_request(index_name="")
await proc.query_row_embeddings(request)
await proc.query_row_embeddings("test-workspace", request)
call_kwargs = proc.qdrant.query_points.call_args[1]
assert call_kwargs["query_filter"] is None
@ -198,7 +204,7 @@ class TestQueryRowEmbeddings:
async def test_missing_payload_fields_default(self):
"""Points with missing payload fields should use defaults."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
point = MagicMock()
point.payload = {} # Empty payload
@ -209,7 +215,7 @@ class TestQueryRowEmbeddings:
proc.qdrant.query_points.return_value = mock_result
request = _make_request()
result = await proc.query_row_embeddings(request)
result = await proc.query_row_embeddings("test-workspace", request)
assert len(result) == 1
assert result[0].index_name == ""
@ -219,13 +225,13 @@ class TestQueryRowEmbeddings:
@pytest.mark.asyncio
async def test_qdrant_error_propagates(self):
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
proc.qdrant.query_points.side_effect = Exception("qdrant down")
request = _make_request()
with pytest.raises(Exception, match="qdrant down"):
await proc.query_row_embeddings(request)
await proc.query_row_embeddings("test-workspace", request)
# ---------------------------------------------------------------------------
@ -243,7 +249,7 @@ class TestOnMessage:
])
mock_pub = AsyncMock()
flow = lambda name: mock_pub
flow = _make_flow(pub=mock_pub)
msg = MagicMock()
msg.value.return_value = _make_request()
@ -264,7 +270,7 @@ class TestOnMessage:
)
mock_pub = AsyncMock()
flow = lambda name: mock_pub
flow = _make_flow(pub=mock_pub)
msg = MagicMock()
msg.value.return_value = _make_request()
@ -284,7 +290,7 @@ class TestOnMessage:
proc.query_row_embeddings = AsyncMock(return_value=[])
mock_pub = AsyncMock()
flow = lambda name: mock_pub
flow = _make_flow(pub=mock_pub)
msg = MagicMock()
msg.value.return_value = _make_request()

Some files were not shown because too many files have changed in this diff Show more