diff --git a/tests/unit/test_agent/test_tool_service_lifecycle.py b/tests/unit/test_agent/test_tool_service_lifecycle.py new file mode 100644 index 00000000..65cdb542 --- /dev/null +++ b/tests/unit/test_agent/test_tool_service_lifecycle.py @@ -0,0 +1,624 @@ +""" +Tests for tool service lifecycle, invoke contract, streaming responses, +multi-tenancy, and error propagation. + +Tests the actual DynamicToolService, ToolService, and ToolServiceClient +classes rather than plain dicts. +""" + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from trustgraph.schema import ( + ToolServiceRequest, ToolServiceResponse, Error, + ToolRequest, ToolResponse, +) +from trustgraph.exceptions import TooManyRequests + + +# --------------------------------------------------------------------------- +# DynamicToolService tests +# --------------------------------------------------------------------------- + +class TestDynamicToolServiceInvokeContract: + + @pytest.mark.asyncio + async def test_base_invoke_raises_not_implemented(self): + """Base class invoke() should raise NotImplementedError.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + + with pytest.raises(NotImplementedError): + await svc.invoke("user", {}, {}) + + @pytest.mark.asyncio + async def test_on_request_calls_invoke_with_parsed_args(self): + """on_request should JSON-parse config/arguments and pass to invoke.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test-svc" + svc.producer = AsyncMock() + + calls = [] + + async def tracking_invoke(user, config, arguments): + calls.append({"user": user, "config": config, "arguments": arguments}) + return "ok" + + svc.invoke = tracking_invoke + + # Ensure the class-level metric exists + if not hasattr(DynamicToolService, "tool_service_metric"): + 
DynamicToolService.tool_service_metric = MagicMock() + + msg = MagicMock() + msg.value.return_value = ToolServiceRequest( + user="alice", + config='{"style": "pun"}', + arguments='{"topic": "cats"}', + ) + msg.properties.return_value = {"id": "req-1"} + + await svc.on_request(msg, MagicMock(), None) + + assert len(calls) == 1 + assert calls[0]["user"] == "alice" + assert calls[0]["config"] == {"style": "pun"} + assert calls[0]["arguments"] == {"topic": "cats"} + + @pytest.mark.asyncio + async def test_on_request_empty_user_defaults_to_trustgraph(self): + """Empty user field should default to 'trustgraph'.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test-svc" + svc.producer = AsyncMock() + + received_user = None + + async def capture_invoke(user, config, arguments): + nonlocal received_user + received_user = user + return "ok" + + svc.invoke = capture_invoke + + if not hasattr(DynamicToolService, "tool_service_metric"): + DynamicToolService.tool_service_metric = MagicMock() + + msg = MagicMock() + msg.value.return_value = ToolServiceRequest(user="", config="", arguments="") + msg.properties.return_value = {"id": "req-2"} + + await svc.on_request(msg, MagicMock(), None) + + assert received_user == "trustgraph" + + @pytest.mark.asyncio + async def test_on_request_string_response_sent_directly(self): + """String return from invoke → response field is the string.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test-svc" + svc.producer = AsyncMock() + + async def string_invoke(user, config, arguments): + return "hello world" + + svc.invoke = string_invoke + + if not hasattr(DynamicToolService, "tool_service_metric"): + DynamicToolService.tool_service_metric = MagicMock() + + msg = MagicMock() + msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + 
msg.properties.return_value = {"id": "r1"} + + await svc.on_request(msg, MagicMock(), None) + + sent = svc.producer.send.call_args[0][0] + assert isinstance(sent, ToolServiceResponse) + assert sent.response == "hello world" + assert sent.end_of_stream is True + assert sent.error is None + + @pytest.mark.asyncio + async def test_on_request_dict_response_json_encoded(self): + """Dict return from invoke → response field is JSON-encoded.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test-svc" + svc.producer = AsyncMock() + + async def dict_invoke(user, config, arguments): + return {"result": 42} + + svc.invoke = dict_invoke + + if not hasattr(DynamicToolService, "tool_service_metric"): + DynamicToolService.tool_service_metric = MagicMock() + + msg = MagicMock() + msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.properties.return_value = {"id": "r2"} + + await svc.on_request(msg, MagicMock(), None) + + sent = svc.producer.send.call_args[0][0] + assert json.loads(sent.response) == {"result": 42} + + @pytest.mark.asyncio + async def test_on_request_error_sends_error_response(self): + """Exception in invoke → error response sent.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test-svc" + svc.producer = AsyncMock() + + async def failing_invoke(user, config, arguments): + raise ValueError("bad input") + + svc.invoke = failing_invoke + + msg = MagicMock() + msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.properties.return_value = {"id": "r3"} + + await svc.on_request(msg, MagicMock(), None) + + sent = svc.producer.send.call_args[0][0] + assert sent.error is not None + assert sent.error.type == "tool-service-error" + assert "bad input" in sent.error.message + assert sent.response == "" + + @pytest.mark.asyncio 
+ async def test_on_request_too_many_requests_propagates(self): + """TooManyRequests should propagate (not caught as error response).""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test-svc" + svc.producer = AsyncMock() + + async def rate_limited_invoke(user, config, arguments): + raise TooManyRequests("rate limited") + + svc.invoke = rate_limited_invoke + + msg = MagicMock() + msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.properties.return_value = {"id": "r4"} + + with pytest.raises(TooManyRequests): + await svc.on_request(msg, MagicMock(), None) + + @pytest.mark.asyncio + async def test_on_request_preserves_message_id(self): + """Response should include the original message id in properties.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test-svc" + svc.producer = AsyncMock() + + async def ok_invoke(user, config, arguments): + return "ok" + + svc.invoke = ok_invoke + + if not hasattr(DynamicToolService, "tool_service_metric"): + DynamicToolService.tool_service_metric = MagicMock() + + msg = MagicMock() + msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}") + msg.properties.return_value = {"id": "unique-42"} + + await svc.on_request(msg, MagicMock(), None) + + props = svc.producer.send.call_args[1]["properties"] + assert props["id"] == "unique-42" + + +# --------------------------------------------------------------------------- +# ToolService (flow-based) tests +# --------------------------------------------------------------------------- + +class TestToolServiceOnRequest: + + @pytest.mark.asyncio + async def test_string_response_sent_as_text(self): + """String return from invoke_tool → ToolResponse.text is set.""" + from trustgraph.base.tool_service import ToolService + + svc = 
ToolService.__new__(ToolService) + svc.id = "test-tool" + + async def mock_invoke(name, params): + return "tool result" + + svc.invoke_tool = mock_invoke + + if not hasattr(ToolService, "tool_invocation_metric"): + ToolService.tool_invocation_metric = MagicMock() + + mock_response_pub = AsyncMock() + flow = MagicMock() + flow.name = "test-flow" + + def flow_callable(name): + if name == "response": + return mock_response_pub + return MagicMock() + + flow_callable.producer = {"response": mock_response_pub} + flow_callable.name = "test-flow" + + msg = MagicMock() + msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}') + msg.properties.return_value = {"id": "t1"} + + await svc.on_request(msg, MagicMock(), flow_callable) + + sent = mock_response_pub.send.call_args[0][0] + assert isinstance(sent, ToolResponse) + assert sent.text == "tool result" + assert sent.object is None + + @pytest.mark.asyncio + async def test_dict_response_sent_as_json_object(self): + """Dict return from invoke_tool → ToolResponse.object is JSON.""" + from trustgraph.base.tool_service import ToolService + + svc = ToolService.__new__(ToolService) + svc.id = "test-tool" + + async def mock_invoke(name, params): + return {"data": [1, 2, 3]} + + svc.invoke_tool = mock_invoke + + if not hasattr(ToolService, "tool_invocation_metric"): + ToolService.tool_invocation_metric = MagicMock() + + mock_response_pub = AsyncMock() + flow = MagicMock() + + def flow_callable(name): + if name == "response": + return mock_response_pub + return MagicMock() + + flow_callable.producer = {"response": mock_response_pub} + flow_callable.name = "test-flow" + + msg = MagicMock() + msg.value.return_value = ToolRequest(name="my-tool", parameters="{}") + msg.properties.return_value = {"id": "t2"} + + await svc.on_request(msg, MagicMock(), flow_callable) + + sent = mock_response_pub.send.call_args[0][0] + assert sent.text is None + assert json.loads(sent.object) == {"data": [1, 2, 3]} + + 
@pytest.mark.asyncio + async def test_error_sends_error_response(self): + """Exception in invoke_tool → error response via flow producer.""" + from trustgraph.base.tool_service import ToolService + + svc = ToolService.__new__(ToolService) + svc.id = "test-tool" + + async def failing_invoke(name, params): + raise RuntimeError("tool broke") + + svc.invoke_tool = failing_invoke + + mock_response_pub = AsyncMock() + flow = MagicMock() + + def flow_callable(name): + return MagicMock() + + flow_callable.producer = {"response": mock_response_pub} + flow_callable.name = "test-flow" + + msg = MagicMock() + msg.value.return_value = ToolRequest(name="my-tool", parameters="{}") + msg.properties.return_value = {"id": "t3"} + + await svc.on_request(msg, MagicMock(), flow_callable) + + sent = mock_response_pub.send.call_args[0][0] + assert sent.error is not None + assert sent.error.type == "tool-error" + assert "tool broke" in sent.error.message + + @pytest.mark.asyncio + async def test_too_many_requests_propagates(self): + """TooManyRequests should propagate from ToolService.on_request.""" + from trustgraph.base.tool_service import ToolService + + svc = ToolService.__new__(ToolService) + svc.id = "test-tool" + + async def rate_limited(name, params): + raise TooManyRequests("slow down") + + svc.invoke_tool = rate_limited + + msg = MagicMock() + msg.value.return_value = ToolRequest(name="my-tool", parameters="{}") + msg.properties.return_value = {"id": "t4"} + + flow = MagicMock() + flow.producer = {"response": AsyncMock()} + flow.name = "test-flow" + + with pytest.raises(TooManyRequests): + await svc.on_request(msg, MagicMock(), flow) + + @pytest.mark.asyncio + async def test_parameters_json_parsed(self): + """Parameters should be JSON-parsed before passing to invoke_tool.""" + from trustgraph.base.tool_service import ToolService + + svc = ToolService.__new__(ToolService) + svc.id = "test-tool" + + received = {} + + async def capture_invoke(name, params): + received["name"] = 
name + received["params"] = params + return "ok" + + svc.invoke_tool = capture_invoke + + if not hasattr(ToolService, "tool_invocation_metric"): + ToolService.tool_invocation_metric = MagicMock() + + mock_pub = AsyncMock() + flow = lambda name: mock_pub + flow.producer = {"response": mock_pub} + flow.name = "f" + + msg = MagicMock() + msg.value.return_value = ToolRequest( + name="search", + parameters='{"query": "test", "limit": 10}', + ) + msg.properties.return_value = {"id": "t5"} + + await svc.on_request(msg, MagicMock(), flow) + + assert received["name"] == "search" + assert received["params"] == {"query": "test", "limit": 10} + + +# --------------------------------------------------------------------------- +# ToolServiceClient tests +# --------------------------------------------------------------------------- + +class TestToolServiceClientCall: + + @pytest.mark.asyncio + async def test_call_sends_request_and_returns_response(self): + """call() should send ToolServiceRequest and return response string.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + client.request = AsyncMock(return_value=ToolServiceResponse( + error=None, response="joke result", end_of_stream=True, + )) + + result = await client.call( + user="alice", + config={"style": "pun"}, + arguments={"topic": "cats"}, + ) + + assert result == "joke result" + + req = client.request.call_args[0][0] + assert isinstance(req, ToolServiceRequest) + assert req.user == "alice" + assert json.loads(req.config) == {"style": "pun"} + assert json.loads(req.arguments) == {"topic": "cats"} + + @pytest.mark.asyncio + async def test_call_raises_on_error(self): + """call() should raise RuntimeError when response has error.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + client.request = AsyncMock(return_value=ToolServiceResponse( + 
error=Error(type="tool-service-error", message="service down"), + response="", + )) + + with pytest.raises(RuntimeError, match="service down"): + await client.call(user="u", config={}, arguments={}) + + @pytest.mark.asyncio + async def test_call_empty_config_sends_empty_json(self): + """Empty config/arguments should be sent as '{}'.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + client.request = AsyncMock(return_value=ToolServiceResponse( + error=None, response="ok", + )) + + await client.call(user="u", config=None, arguments=None) + + req = client.request.call_args[0][0] + assert req.config == "{}" + assert req.arguments == "{}" + + @pytest.mark.asyncio + async def test_call_passes_timeout(self): + """call() should forward timeout to underlying request.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + client.request = AsyncMock(return_value=ToolServiceResponse( + error=None, response="ok", + )) + + await client.call(user="u", config={}, arguments={}, timeout=30) + + _, kwargs = client.request.call_args + assert kwargs["timeout"] == 30 + + +class TestToolServiceClientStreaming: + + @pytest.mark.asyncio + async def test_call_streaming_collects_chunks(self): + """call_streaming should accumulate chunks and return full result.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + + # Simulate streaming: request() calls recipient with each chunk + chunks = [ + ToolServiceResponse(error=None, response="chunk1", end_of_stream=False), + ToolServiceResponse(error=None, response="chunk2", end_of_stream=True), + ] + + async def mock_request(req, timeout=600, recipient=None): + for chunk in chunks: + done = await recipient(chunk) + if done: + break + + client.request = mock_request + + received = [] + + async def callback(text): + 
received.append(text) + + result = await client.call_streaming( + user="u", config={}, arguments={}, callback=callback, + ) + + assert result == "chunk1chunk2" + assert received == ["chunk1", "chunk2"] + + @pytest.mark.asyncio + async def test_call_streaming_raises_on_error(self): + """call_streaming should raise RuntimeError on error chunk.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + + async def mock_request(req, timeout=600, recipient=None): + error_resp = ToolServiceResponse( + error=Error(type="tool-service-error", message="stream failed"), + response="", + end_of_stream=True, + ) + await recipient(error_resp) + + client.request = mock_request + + with pytest.raises(RuntimeError, match="stream failed"): + await client.call_streaming( + user="u", config={}, arguments={}, + callback=AsyncMock(), + ) + + @pytest.mark.asyncio + async def test_call_streaming_skips_empty_response(self): + """Empty response chunks should not be added to result.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + + chunks = [ + ToolServiceResponse(error=None, response="", end_of_stream=False), + ToolServiceResponse(error=None, response="data", end_of_stream=True), + ] + + async def mock_request(req, timeout=600, recipient=None): + for chunk in chunks: + done = await recipient(chunk) + if done: + break + + client.request = mock_request + + received = [] + + async def callback(text): + received.append(text) + + result = await client.call_streaming( + user="u", config={}, arguments={}, callback=callback, + ) + + # Empty response is falsy, so callback shouldn't be called for it + assert result == "data" + assert received == ["data"] + + +# --------------------------------------------------------------------------- +# Multi-tenancy +# --------------------------------------------------------------------------- + +class 
TestMultiTenancy: + + @pytest.mark.asyncio + async def test_user_propagated_to_invoke(self): + """User from request should reach the invoke method.""" + from trustgraph.base.dynamic_tool_service import DynamicToolService + + svc = DynamicToolService.__new__(DynamicToolService) + svc.id = "test" + svc.producer = AsyncMock() + + users_seen = [] + + async def tracking(user, config, arguments): + users_seen.append(user) + return "ok" + + svc.invoke = tracking + + if not hasattr(DynamicToolService, "tool_service_metric"): + DynamicToolService.tool_service_metric = MagicMock() + + for u in ["tenant-a", "tenant-b", "tenant-c"]: + msg = MagicMock() + msg.value.return_value = ToolServiceRequest( + user=u, config="{}", arguments="{}", + ) + msg.properties.return_value = {"id": f"req-{u}"} + await svc.on_request(msg, MagicMock(), None) + + assert users_seen == ["tenant-a", "tenant-b", "tenant-c"] + + @pytest.mark.asyncio + async def test_client_sends_user_in_request(self): + """ToolServiceClient.call should include user in request.""" + from trustgraph.base.tool_service_client import ToolServiceClient + + client = ToolServiceClient.__new__(ToolServiceClient) + client.request = AsyncMock(return_value=ToolServiceResponse( + error=None, response="ok", + )) + + await client.call(user="isolated-tenant", config={}, arguments={}) + + req = client.request.call_args[0][0] + assert req.user == "isolated-tenant" diff --git a/tests/unit/test_concurrency/__init__.py b/tests/unit/test_concurrency/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/unit/test_concurrency/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/test_concurrency/test_consumer_concurrency.py b/tests/unit/test_concurrency/test_consumer_concurrency.py new file mode 100644 index 00000000..32a6559b --- /dev/null +++ b/tests/unit/test_concurrency/test_consumer_concurrency.py @@ -0,0 +1,286 @@ +""" +Tests for Consumer concurrency: TaskGroup-based concurrent message processing, +rate-limit 
retry with backpressure, and message acknowledgement. +""" + +import asyncio +import time + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.base.consumer import Consumer +from trustgraph.exceptions import TooManyRequests + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_consumer( + concurrency=1, + handler=None, + rate_limit_retry_time=0.01, + rate_limit_timeout=1, +): + """Create a Consumer with mocked infrastructure.""" + taskgroup = MagicMock() + flow = MagicMock() + backend = MagicMock() + schema = MagicMock() + handler = handler or AsyncMock() + + consumer = Consumer( + taskgroup=taskgroup, + flow=flow, + backend=backend, + topic="test-topic", + subscriber="test-sub", + schema=schema, + handler=handler, + rate_limit_retry_time=rate_limit_retry_time, + rate_limit_timeout=rate_limit_timeout, + concurrency=concurrency, + ) + + return consumer + + +def _make_msg(): + """Create a mock Pulsar message.""" + return MagicMock() + + +# --------------------------------------------------------------------------- +# Concurrency configuration tests +# --------------------------------------------------------------------------- + +class TestConcurrencyConfiguration: + + def test_default_concurrency_is_1(self): + consumer = _make_consumer() + assert consumer.concurrency == 1 + + def test_custom_concurrency(self): + consumer = _make_consumer(concurrency=10) + assert consumer.concurrency == 10 + + def test_concurrency_stored(self): + for n in [1, 5, 20, 100]: + consumer = _make_consumer(concurrency=n) + assert consumer.concurrency == n + + +class TestTaskGroupConcurrency: + + @pytest.mark.asyncio + async def test_creates_n_concurrent_tasks(self): + """consumer_run should create exactly N concurrent consume_from_queue tasks.""" + concurrency = 5 + consumer = _make_consumer(concurrency=concurrency) + + # 
Track how many consume_from_queue calls are made + call_count = 0 + original_running = True + + async def mock_consume(): + nonlocal call_count + call_count += 1 + # Wait a bit to let all tasks start, then signal stop + await asyncio.sleep(0.05) + consumer.running = False + + consumer.consume_from_queue = mock_consume + + # Mock the backend.create_consumer + consumer.backend.create_consumer = MagicMock(return_value=MagicMock()) + + # Run consumer_run - it will create TaskGroup with N tasks + consumer.running = True + await consumer.consumer_run() + + assert call_count == concurrency + + @pytest.mark.asyncio + async def test_single_concurrency_creates_one_task(self): + """With concurrency=1, only one consume_from_queue task is created.""" + consumer = _make_consumer(concurrency=1) + call_count = 0 + + async def mock_consume(): + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.01) + consumer.running = False + + consumer.consume_from_queue = mock_consume + consumer.backend.create_consumer = MagicMock(return_value=MagicMock()) + + consumer.running = True + await consumer.consumer_run() + + assert call_count == 1 + + +# --------------------------------------------------------------------------- +# Rate-limit retry tests +# --------------------------------------------------------------------------- + +class TestRateLimitRetry: + + @pytest.mark.asyncio + async def test_rate_limit_retries_then_succeeds(self): + """TooManyRequests should cause retry, then succeed on next attempt.""" + call_count = 0 + + async def handler_with_retry(msg, consumer_ref, flow): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise TooManyRequests("rate limited") + # Second call succeeds + + consumer = _make_consumer( + handler=handler_with_retry, + rate_limit_retry_time=0.01, + ) + mock_msg = _make_msg() + consumer.consumer = MagicMock() + + await consumer.handle_one_from_queue(mock_msg) + + assert call_count == 2 + 
consumer.consumer.acknowledge.assert_called_once_with(mock_msg) + + @pytest.mark.asyncio + async def test_rate_limit_timeout_negative_acks(self): + """If rate limit retries exhaust the timeout, message is negative-acked.""" + async def always_rate_limited(msg, consumer_ref, flow): + raise TooManyRequests("rate limited") + + consumer = _make_consumer( + handler=always_rate_limited, + rate_limit_retry_time=0.01, + rate_limit_timeout=0.05, + ) + mock_msg = _make_msg() + consumer.consumer = MagicMock() + + await consumer.handle_one_from_queue(mock_msg) + + consumer.consumer.negative_acknowledge.assert_called_with(mock_msg) + consumer.consumer.acknowledge.assert_not_called() + + @pytest.mark.asyncio + async def test_non_rate_limit_error_negative_acks_immediately(self): + """Non-TooManyRequests errors should negative-ack immediately (no retry).""" + call_count = 0 + + async def failing_handler(msg, consumer_ref, flow): + nonlocal call_count + call_count += 1 + raise ValueError("bad data") + + consumer = _make_consumer(handler=failing_handler) + mock_msg = _make_msg() + consumer.consumer = MagicMock() + + await consumer.handle_one_from_queue(mock_msg) + + assert call_count == 1 + consumer.consumer.negative_acknowledge.assert_called_once_with(mock_msg) + + @pytest.mark.asyncio + async def test_successful_message_acknowledged(self): + """Successfully processed messages are acknowledged.""" + consumer = _make_consumer(handler=AsyncMock()) + mock_msg = _make_msg() + consumer.consumer = MagicMock() + + await consumer.handle_one_from_queue(mock_msg) + + consumer.consumer.acknowledge.assert_called_once_with(mock_msg) + + +# --------------------------------------------------------------------------- +# Metrics integration +# --------------------------------------------------------------------------- + +class TestMetricsIntegration: + + @pytest.mark.asyncio + async def test_success_metric_on_success(self): + consumer = _make_consumer(handler=AsyncMock()) + mock_msg = _make_msg() + 
consumer.consumer = MagicMock() + + mock_metrics = MagicMock() + mock_metrics.record_time.return_value.__enter__ = MagicMock() + mock_metrics.record_time.return_value.__exit__ = MagicMock() + consumer.metrics = mock_metrics + + await consumer.handle_one_from_queue(mock_msg) + + mock_metrics.process.assert_called_once_with("success") + + @pytest.mark.asyncio + async def test_error_metric_on_failure(self): + async def failing(msg, c, f): + raise ValueError("fail") + + consumer = _make_consumer(handler=failing) + mock_msg = _make_msg() + consumer.consumer = MagicMock() + + mock_metrics = MagicMock() + consumer.metrics = mock_metrics + + await consumer.handle_one_from_queue(mock_msg) + + mock_metrics.process.assert_called_once_with("error") + + @pytest.mark.asyncio + async def test_rate_limit_metric_on_too_many_requests(self): + call_count = 0 + + async def handler(msg, c, f): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise TooManyRequests("limited") + + consumer = _make_consumer( + handler=handler, + rate_limit_retry_time=0.01, + ) + mock_msg = _make_msg() + consumer.consumer = MagicMock() + + mock_metrics = MagicMock() + mock_metrics.record_time.return_value.__enter__ = MagicMock() + mock_metrics.record_time.return_value.__exit__ = MagicMock(return_value=False) + consumer.metrics = mock_metrics + + await consumer.handle_one_from_queue(mock_msg) + + mock_metrics.rate_limit.assert_called_once() + + +# --------------------------------------------------------------------------- +# Stop / running flag +# --------------------------------------------------------------------------- + +class TestStopBehaviour: + + @pytest.mark.asyncio + async def test_stop_sets_running_false(self): + consumer = _make_consumer() + consumer.running = True + + await consumer.stop() + + assert consumer.running is False + + def test_initial_running_state(self): + consumer = _make_consumer() + assert consumer.running is True diff --git 
a/tests/unit/test_concurrency/test_dispatcher_semaphore.py b/tests/unit/test_concurrency/test_dispatcher_semaphore.py new file mode 100644 index 00000000..6a1ae8ab --- /dev/null +++ b/tests/unit/test_concurrency/test_dispatcher_semaphore.py @@ -0,0 +1,136 @@ +""" +Tests for MessageDispatcher semaphore-based concurrency enforcement. + +Verifies that the dispatcher limits concurrent message processing to +max_workers via asyncio.Semaphore. +""" + +import asyncio + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.rev_gateway.dispatcher import MessageDispatcher + + +class TestSemaphoreEnforcement: + + @pytest.mark.asyncio + async def test_semaphore_limits_concurrent_processing(self): + """Only max_workers messages should be processed concurrently.""" + max_workers = 2 + dispatcher = MessageDispatcher(max_workers=max_workers) + + concurrent_count = 0 + max_concurrent = 0 + processing_event = asyncio.Event() + + async def slow_process(message): + nonlocal concurrent_count, max_concurrent + concurrent_count += 1 + max_concurrent = max(max_concurrent, concurrent_count) + await asyncio.sleep(0.05) + concurrent_count -= 1 + return {"id": message.get("id"), "response": {"ok": True}} + + dispatcher._process_message = slow_process + + # Launch more tasks than max_workers + messages = [ + {"id": f"msg-{i}", "service": "test", "request": {}} + for i in range(5) + ] + + tasks = [ + asyncio.create_task(dispatcher.handle_message(m)) + for m in messages + ] + + await asyncio.gather(*tasks) + + # At no point should more than max_workers have been active + assert max_concurrent <= max_workers + + @pytest.mark.asyncio + async def test_semaphore_value_matches_max_workers(self): + for n in [1, 5, 20]: + dispatcher = MessageDispatcher(max_workers=n) + assert dispatcher.semaphore._value == n + + @pytest.mark.asyncio + async def test_active_tasks_tracked(self): + """Active tasks should be added/removed during processing.""" + dispatcher = 
MessageDispatcher(max_workers=5) + + task_was_tracked = False + + original_process = dispatcher._process_message + + async def tracking_process(message): + nonlocal task_was_tracked + # During processing, our task should be in active_tasks + if len(dispatcher.active_tasks) > 0: + task_was_tracked = True + return {"id": message.get("id"), "response": {"ok": True}} + + dispatcher._process_message = tracking_process + + await dispatcher.handle_message( + {"id": "test", "service": "test", "request": {}} + ) + + assert task_was_tracked + # After completion, task should be discarded + assert len(dispatcher.active_tasks) == 0 + + @pytest.mark.asyncio + async def test_semaphore_released_on_error(self): + """Semaphore should be released even if processing raises.""" + dispatcher = MessageDispatcher(max_workers=2) + + async def failing_process(message): + raise RuntimeError("process failed") + + dispatcher._process_message = failing_process + + # Should not deadlock — semaphore must be released on error + with pytest.raises(RuntimeError): + await dispatcher.handle_message( + {"id": "test", "service": "test", "request": {}} + ) + + # Semaphore should be back at max + assert dispatcher.semaphore._value == 2 + + @pytest.mark.asyncio + async def test_single_worker_serializes_processing(self): + """With max_workers=1, messages are processed one at a time.""" + dispatcher = MessageDispatcher(max_workers=1) + + order = [] + + async def ordered_process(message): + msg_id = message["id"] + order.append(f"start-{msg_id}") + await asyncio.sleep(0.02) + order.append(f"end-{msg_id}") + return {"id": msg_id, "response": {"ok": True}} + + dispatcher._process_message = ordered_process + + messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)] + tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages] + await asyncio.gather(*tasks) + + # With semaphore=1, each message should complete before next starts + # Check that no two "start" entries appear 
without an intervening "end" + active = 0 + max_active = 0 + for event in order: + if event.startswith("start"): + active += 1 + max_active = max(max_active, active) + elif event.startswith("end"): + active -= 1 + + assert max_active == 1 diff --git a/tests/unit/test_concurrency/test_graph_rag_concurrency.py b/tests/unit/test_concurrency/test_graph_rag_concurrency.py new file mode 100644 index 00000000..8287427b --- /dev/null +++ b/tests/unit/test_concurrency/test_graph_rag_concurrency.py @@ -0,0 +1,268 @@ +""" +Tests for Graph RAG concurrent query execution. + +Covers: execute_batch_triple_queries concurrent task spawning, +exception handling in gather, and result aggregation. +""" + +import asyncio + +import pytest +from unittest.mock import MagicMock, AsyncMock + +from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_query( + triples_client=None, + entity_limit=50, + triple_limit=30, + max_subgraph_size=1000, + max_path_length=2, +): + """Create a Query object with mocked rag dependencies.""" + rag = MagicMock() + rag.triples_client = triples_client or AsyncMock() + rag.label_cache = LRUCacheWithTTL() + + query = Query( + rag=rag, + user="test-user", + collection="test-collection", + verbose=False, + entity_limit=entity_limit, + triple_limit=triple_limit, + max_subgraph_size=max_subgraph_size, + max_path_length=max_path_length, + ) + return query + + +def _make_triple(s, p, o): + """Create a simple mock triple.""" + t = MagicMock() + t.s = s + t.p = p + t.o = o + return t + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestBatchTripleQueries: + + @pytest.mark.asyncio + async def test_three_queries_per_entity(self): + """Each 
entity should generate 3 concurrent queries (s, p, o positions).""" + client = AsyncMock() + client.query_stream = AsyncMock(return_value=[]) + query = _make_query(triples_client=client) + + entities = ["entity-1"] + await query.execute_batch_triple_queries(entities, limit_per_entity=10) + + assert client.query_stream.call_count == 3 + + @pytest.mark.asyncio + async def test_multiple_entities_multiply_queries(self): + """N entities should produce N*3 concurrent queries.""" + client = AsyncMock() + client.query_stream = AsyncMock(return_value=[]) + query = _make_query(triples_client=client) + + entities = ["e1", "e2", "e3"] + await query.execute_batch_triple_queries(entities, limit_per_entity=10) + + assert client.query_stream.call_count == 9 # 3 * 3 + + @pytest.mark.asyncio + async def test_queries_executed_concurrently(self): + """All queries should run concurrently via asyncio.gather.""" + concurrent_count = 0 + max_concurrent = 0 + + async def tracking_query(**kwargs): + nonlocal concurrent_count, max_concurrent + concurrent_count += 1 + max_concurrent = max(max_concurrent, concurrent_count) + await asyncio.sleep(0.02) + concurrent_count -= 1 + return [] + + client = AsyncMock() + client.query_stream = tracking_query + query = _make_query(triples_client=client) + + entities = ["e1", "e2", "e3"] + await query.execute_batch_triple_queries(entities, limit_per_entity=5) + + # All 9 queries should have run concurrently + assert max_concurrent == 9 + + @pytest.mark.asyncio + async def test_results_aggregated(self): + """Results from all queries should be combined into a single list.""" + triple_a = _make_triple("a", "p", "b") + triple_b = _make_triple("c", "p", "d") + + call_count = 0 + + async def alternating_results(**kwargs): + nonlocal call_count + call_count += 1 + if call_count % 2 == 0: + return [triple_a] + return [triple_b] + + client = AsyncMock() + client.query_stream = alternating_results + query = _make_query(triples_client=client) + + result = await 
query.execute_batch_triple_queries( + ["e1"], limit_per_entity=10 + ) + + # 3 queries, alternating results + assert len(result) == 3 + + @pytest.mark.asyncio + async def test_exception_in_one_query_does_not_block_others(self): + """If one query raises, other results are still collected.""" + good_triple = _make_triple("a", "p", "b") + call_count = 0 + + async def mixed_results(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 2: + raise RuntimeError("query failed") + return [good_triple] + + client = AsyncMock() + client.query_stream = mixed_results + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries( + ["e1"], limit_per_entity=10 + ) + + # 3 queries: 2 succeed, 1 fails → 2 triples + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_none_results_filtered(self): + """None results from queries should be filtered out.""" + call_count = 0 + + async def sometimes_none(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return None + return [_make_triple("a", "p", "b")] + + client = AsyncMock() + client.query_stream = sometimes_none + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries( + ["e1"], limit_per_entity=10 + ) + + # 3 queries: 1 returns None, 2 return triples + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_empty_entities_no_queries(self): + """Empty entity list should produce no queries.""" + client = AsyncMock() + client.query_stream = AsyncMock(return_value=[]) + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries([], limit_per_entity=10) + + assert result == [] + client.query_stream.assert_not_called() + + @pytest.mark.asyncio + async def test_query_params_correct(self): + """Each query should use correct s/p/o positions and params.""" + client = AsyncMock() + client.query_stream = AsyncMock(return_value=[]) + query = 
_make_query(triples_client=client) + + entities = ["ent-1"] + await query.execute_batch_triple_queries(entities, limit_per_entity=15) + + calls = client.query_stream.call_args_list + assert len(calls) == 3 + + # First call: s=entity, p=None, o=None + assert calls[0].kwargs["s"] == "ent-1" + assert calls[0].kwargs["p"] is None + assert calls[0].kwargs["o"] is None + assert calls[0].kwargs["limit"] == 15 + assert calls[0].kwargs["user"] == "test-user" + assert calls[0].kwargs["collection"] == "test-collection" + assert calls[0].kwargs["batch_size"] == 20 + + # Second call: s=None, p=entity, o=None + assert calls[1].kwargs["s"] is None + assert calls[1].kwargs["p"] == "ent-1" + assert calls[1].kwargs["o"] is None + + # Third call: s=None, p=None, o=entity + assert calls[2].kwargs["s"] is None + assert calls[2].kwargs["p"] is None + assert calls[2].kwargs["o"] == "ent-1" + + +class TestLRUCacheWithTTL: + + def test_put_and_get(self): + cache = LRUCacheWithTTL(max_size=10, ttl=60) + cache.put("key1", "value1") + assert cache.get("key1") == "value1" + + def test_get_missing_returns_none(self): + cache = LRUCacheWithTTL() + assert cache.get("nonexistent") is None + + def test_max_size_eviction(self): + cache = LRUCacheWithTTL(max_size=2, ttl=60) + cache.put("a", 1) + cache.put("b", 2) + cache.put("c", 3) # Should evict "a" + assert cache.get("a") is None + assert cache.get("b") == 2 + assert cache.get("c") == 3 + + def test_lru_order(self): + cache = LRUCacheWithTTL(max_size=2, ttl=60) + cache.put("a", 1) + cache.put("b", 2) + cache.get("a") # Access "a" — now "b" is LRU + cache.put("c", 3) # Should evict "b" + assert cache.get("a") == 1 + assert cache.get("b") is None + assert cache.get("c") == 3 + + def test_ttl_expiration(self): + cache = LRUCacheWithTTL(max_size=10, ttl=0) # TTL=0 means instant expiry + cache.put("key", "value") + # With TTL=0, any time check > 0 means expired + import time + time.sleep(0.01) + assert cache.get("key") is None + + def 
test_update_existing_key(self): + cache = LRUCacheWithTTL(max_size=10, ttl=60) + cache.put("key", "v1") + cache.put("key", "v2") + assert cache.get("key") == "v2" diff --git a/tests/unit/test_direct/test_entity_centric_write_amplification.py b/tests/unit/test_direct/test_entity_centric_write_amplification.py new file mode 100644 index 00000000..1c9ad1a8 --- /dev/null +++ b/tests/unit/test_direct/test_entity_centric_write_amplification.py @@ -0,0 +1,441 @@ +""" +Tests for entity-centric KG write amplification, delete collection batching, +in-partition filtering, and term type metadata round-trips. + +Complements test_entity_centric_kg.py with deeper verification of the +2-table schema mechanics. +""" + +import pytest +from unittest.mock import MagicMock, patch, call + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def mock_cassandra(): + """Provide mocked Cassandra cluster, session, and BatchStatement.""" + with patch('trustgraph.direct.cassandra_kg.Cluster') as mock_cls, \ + patch('trustgraph.direct.cassandra_kg.BatchStatement') as mock_batch_cls: + + mock_cluster = MagicMock() + mock_session = MagicMock() + mock_cluster.connect.return_value = mock_session + mock_cls.return_value = mock_cluster + + # Track batch.add calls per batch instance + batches = [] + + def make_batch(): + batch = MagicMock() + batch._adds = [] + original_add = batch.add + + def tracking_add(stmt, params): + batch._adds.append((stmt, params)) + + batch.add = tracking_add + batches.append(batch) + return batch + + mock_batch_cls.side_effect = make_batch + + yield { + "cluster_cls": mock_cls, + "cluster": mock_cluster, + "session": mock_session, + "batch_cls": mock_batch_cls, + "batches": batches, + } + + +@pytest.fixture +def entity_kg(mock_cassandra): + """Create an EntityCentricKnowledgeGraph with mocked Cassandra.""" + from 
trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph + kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_ks') + return kg, mock_cassandra + + +# --------------------------------------------------------------------------- +# Write amplification: row count verification +# --------------------------------------------------------------------------- + +class TestWriteAmplification: + + def test_uri_object_produces_4_entity_rows_plus_collection(self, entity_kg): + """URI object → S + P + O + G-if-non-default entity rows + 1 collection row.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + collection='col', + s='http://ex.org/Alice', + p='http://ex.org/knows', + o='http://ex.org/Bob', + g='http://ex.org/g1', + otype='u', + ) + + # Should be exactly one batch + assert len(ctx["batches"]) == 1 + batch = ctx["batches"][0] + + # 4 entity rows (S, P, O, G) + 1 collection row = 5 + assert len(batch._adds) == 5 + + # Check roles assigned + roles = [params[2] for _, params in batch._adds if len(params) == 10] + assert 'S' in roles + assert 'P' in roles + assert 'O' in roles + assert 'G' in roles + + def test_literal_object_produces_3_entity_rows(self, entity_kg): + """Literal object → S + P entity rows (no O row) + collection row.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + collection='col', + s='http://ex.org/Alice', + p='http://ex.org/name', + o='Alice Smith', + g=None, # default graph + otype='l', + ) + + batch = ctx["batches"][0] + + # S + P entity rows + 1 collection = 3 (no O row for literal, no G for default) + assert len(batch._adds) == 3 + + roles = [params[2] for _, params in batch._adds if len(params) == 10] + assert 'S' in roles + assert 'P' in roles + assert 'O' not in roles + assert 'G' not in roles + + def test_triple_otype_gets_object_entity_row(self, entity_kg): + """otype='t' (quoted triple) → object gets entity row like URI.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + 
collection='col', + s='http://ex.org/s', + p='http://ex.org/p', + o='{"s":{},"p":{},"o":{}}', + g=None, + otype='t', + ) + + batch = ctx["batches"][0] + + # S + P + O entity rows + collection = 4 (no G for default graph) + assert len(batch._adds) == 4 + + roles = [params[2] for _, params in batch._adds if len(params) == 10] + assert 'O' in roles + + def test_default_graph_no_g_row(self, entity_kg): + """Default graph (g=None) → no G entity row.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + collection='col', + s='http://ex.org/s', + p='http://ex.org/p', + o='http://ex.org/o', + g=None, + otype='u', + ) + + batch = ctx["batches"][0] + + # S + P + O entity rows + collection = 4 (no G) + assert len(batch._adds) == 4 + roles = [params[2] for _, params in batch._adds if len(params) == 10] + assert 'G' not in roles + + def test_non_default_graph_gets_g_row(self, entity_kg): + """Non-default graph → gets G entity row.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + collection='col', + s='http://ex.org/s', + p='http://ex.org/p', + o='http://ex.org/o', + g='http://ex.org/graph1', + otype='u', + ) + + batch = ctx["batches"][0] + + # S + P + O + G entity rows + collection = 5 + assert len(batch._adds) == 5 + roles = [params[2] for _, params in batch._adds if len(params) == 10] + assert 'G' in roles + + def test_dtype_and_lang_passed_to_all_rows(self, entity_kg): + """dtype and lang should be stored in every entity row.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + collection='col', + s='http://ex.org/s', + p='http://ex.org/label', + o='thing', + g=None, + otype='l', + dtype='xsd:string', + lang='en', + ) + + batch = ctx["batches"][0] + + # Check entity rows carry dtype and lang + for _, params in batch._adds: + if len(params) == 10: + # Entity row: (collection, entity, role, p, otype, s, o, d, dtype, lang) + assert params[8] == 'xsd:string' + assert params[9] == 'en' + + +# 
---------------------------------------------------------------------------
# In-partition filtering: get_os, get_spo
# ---------------------------------------------------------------------------

class TestInPartitionFiltering:
    # NOTE(review): these tests assume get_os/get_spo read one partition via
    # session.execute and filter rows client-side — the mocked session simply
    # returns the whole partition; confirm against cassandra_kg.

    def test_get_os_filters_by_object(self, entity_kg):
        """get_os should filter results by matching object value."""
        kg, ctx = entity_kg

        # Simulate rows returned from subject partition (all have same s)
        mock_rows = [
            MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
                      d='', otype='u', dtype='', lang='',
                      s='http://ex.org/Alice'),
            MagicMock(p='http://ex.org/likes', o='http://ex.org/Charlie',
                      d='', otype='u', dtype='', lang='',
                      s='http://ex.org/Alice'),
        ]
        ctx["session"].execute.return_value = mock_rows

        results = kg.get_os('col', 'http://ex.org/Bob', 'http://ex.org/Alice')

        # Only the Bob row should pass the filter
        assert len(results) == 1
        assert results[0].o == 'http://ex.org/Bob'
        assert results[0].p == 'http://ex.org/knows'

    def test_get_os_returns_empty_when_no_match(self, entity_kg):
        """get_os should return empty list when object doesn't match any row."""
        kg, ctx = entity_kg

        mock_rows = [
            MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
                      d='', otype='u', dtype='', lang='',
                      s='http://ex.org/Alice'),
        ]
        ctx["session"].execute.return_value = mock_rows

        results = kg.get_os('col', 'http://ex.org/Charlie', 'http://ex.org/Alice')

        assert len(results) == 0

    def test_get_spo_filters_by_object(self, entity_kg):
        """get_spo should filter results by matching object value."""
        kg, ctx = entity_kg

        mock_rows = [
            MagicMock(o='http://ex.org/Bob', d='', otype='u', dtype='', lang=''),
            MagicMock(o='http://ex.org/Charlie', d='', otype='u', dtype='', lang=''),
        ]
        ctx["session"].execute.return_value = mock_rows

        results = kg.get_spo(
            'col', 'http://ex.org/Alice', 'http://ex.org/knows',
            'http://ex.org/Bob',
        )

        assert len(results) == 1
        assert results[0].o == 'http://ex.org/Bob'

    def test_get_os_with_graph_filter(self, entity_kg):
        """get_os with specific graph should filter both object and graph."""
        kg, ctx = entity_kg

        mock_rows = [
            MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
                      d='http://ex.org/g1', otype='u', dtype='', lang='',
                      s='http://ex.org/Alice'),
            MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
                      d='http://ex.org/g2', otype='u', dtype='', lang='',
                      s='http://ex.org/Alice'),
        ]
        ctx["session"].execute.return_value = mock_rows

        results = kg.get_os(
            'col', 'http://ex.org/Bob', 'http://ex.org/Alice',
            g='http://ex.org/g1',
        )

        assert len(results) == 1
        # NOTE(review): the mocked rows carry the graph in attribute `d`, but
        # the assertion reads `.g` — presumably get_os maps the stored `d`
        # column onto a `g` field of the result object; verify in cassandra_kg.
        assert results[0].g == 'http://ex.org/g1'


# ---------------------------------------------------------------------------
# Delete collection batching
# ---------------------------------------------------------------------------

class TestDeleteCollectionBatching:

    def test_extracts_unique_entities_from_quads(self, entity_kg):
        """delete_collection should extract s, p, and URI o as entities."""
        kg, ctx = entity_kg

        mock_rows = [
            MagicMock(d='', s='http://ex.org/A', p='http://ex.org/knows',
                      o='http://ex.org/B', otype='u', dtype='', lang=''),
            MagicMock(d='', s='http://ex.org/A', p='http://ex.org/name',
                      o='Alice', otype='l', dtype='', lang=''),
        ]
        ctx["session"].execute.return_value = mock_rows
        ctx["batches"].clear()

        kg.delete_collection('col')

        # Unique entities: A, knows, B, name (literal 'Alice' excluded)
        # The batches should include entity partition deletes
        all_adds = []
        for batch in ctx["batches"]:
            all_adds.extend(batch._adds)

        # We expect entity deletes + collection row deletes + metadata delete
        # Just verify the function completes and calls execute
        # NOTE(review): all_adds is collected but never asserted on, so this
        # test only proves delete_collection ran — consider asserting that
        # the expected entity deletes appear in all_adds.
        assert ctx["session"].execute.called

    def test_literal_objects_not_treated_as_entities(self, entity_kg):
        """Literal objects (otype='l') should not get entity partition deletes."""
        kg, ctx = 
entity_kg + + mock_rows = [ + MagicMock(d='', s='http://ex.org/A', p='http://ex.org/name', + o='Alice', otype='l', dtype='', lang=''), + ] + ctx["session"].execute.return_value = mock_rows + ctx["batches"].clear() + + kg.delete_collection('col') + + # Entity partition deletes should only include A and name, not Alice + entity_deletes = [] + for batch in ctx["batches"]: + for _, params in batch._adds: + if len(params) == 2: # delete_entity_partition takes (collection, entity) + entity_deletes.append(params[1]) + + assert 'http://ex.org/A' in entity_deletes + assert 'http://ex.org/name' in entity_deletes + assert 'Alice' not in entity_deletes + + def test_non_default_graph_treated_as_entity(self, entity_kg): + """Non-default graphs should get entity partition deletes.""" + kg, ctx = entity_kg + + mock_rows = [ + MagicMock(d='http://ex.org/g1', s='http://ex.org/A', + p='http://ex.org/p', o='http://ex.org/B', + otype='u', dtype='', lang=''), + ] + ctx["session"].execute.return_value = mock_rows + ctx["batches"].clear() + + kg.delete_collection('col') + + entity_deletes = [] + for batch in ctx["batches"]: + for _, params in batch._adds: + if len(params) == 2: + entity_deletes.append(params[1]) + + assert 'http://ex.org/g1' in entity_deletes + + def test_empty_collection_delete_completes(self, entity_kg): + """Deleting an empty collection should not error.""" + kg, ctx = entity_kg + + ctx["session"].execute.return_value = [] + ctx["batches"].clear() + + # Should not raise + kg.delete_collection('empty-col') + + +# --------------------------------------------------------------------------- +# Term type metadata round-trip +# --------------------------------------------------------------------------- + +class TestTermTypeMetadata: + + def test_query_results_include_otype(self, entity_kg): + """Query results should include otype from Cassandra rows.""" + kg, ctx = entity_kg + from trustgraph.direct.cassandra_kg import QuadResult + + mock_rows = [ + 
MagicMock(p='http://ex.org/name', o='Alice', + d='', otype='l', dtype='xsd:string', lang='en', + s='http://ex.org/Alice'), + ] + ctx["session"].execute.return_value = mock_rows + + results = kg.get_s('col', 'http://ex.org/Alice') + + assert len(results) == 1 + assert results[0].otype == 'l' + assert results[0].dtype == 'xsd:string' + assert results[0].lang == 'en' + + def test_auto_detect_otype_uri(self, entity_kg): + """Auto-detect should classify http:// as URI.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + collection='col', + s='http://ex.org/s', + p='http://ex.org/p', + o='http://ex.org/o', + ) + + batch = ctx["batches"][0] + # Check otype in entity rows (position 4) + for _, params in batch._adds: + if len(params) == 10: + assert params[4] == 'u' + + def test_auto_detect_otype_literal(self, entity_kg): + """Auto-detect should classify non-http:// as literal.""" + kg, ctx = entity_kg + ctx["batches"].clear() + + kg.insert( + collection='col', + s='http://ex.org/s', + p='http://ex.org/p', + o='plain text', + ) + + batch = ctx["batches"][0] + for _, params in batch._adds: + if len(params) == 10: + assert params[4] == 'l' diff --git a/tests/unit/test_embeddings/test_document_embeddings_processor.py b/tests/unit/test_embeddings/test_document_embeddings_processor.py new file mode 100644 index 00000000..9cd93c4f --- /dev/null +++ b/tests/unit/test_embeddings/test_document_embeddings_processor.py @@ -0,0 +1,164 @@ +""" +Tests for document embeddings processor — single-chunk embedding via batch API. 
+""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.embeddings.document_embeddings.embeddings import Processor +from trustgraph.schema import ( + Chunk, DocumentEmbeddings, ChunkEmbeddings, + EmbeddingsRequest, EmbeddingsResponse, Metadata, +) + + +@pytest.fixture +def processor(): + return Processor( + taskgroup=AsyncMock(), + id="test-doc-embeddings", + ) + + +def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", + user="test", collection="default"): + metadata = Metadata(id=doc_id, user=user, collection=collection) + value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id) + msg = MagicMock() + msg.value.return_value = value + return msg + + +class TestDocumentEmbeddingsProcessor: + + @pytest.mark.asyncio + async def test_sends_single_text_as_list(self, processor): + """Document embeddings should wrap single chunk in a list for the API.""" + msg = _make_chunk_message("test chunk text") + + mock_request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[[0.1, 0.2, 0.3]] + )) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(request=mock_request) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + # Should send EmbeddingsRequest with texts=[chunk] + mock_request.assert_called_once() + req = mock_request.call_args[0][0] + assert isinstance(req, EmbeddingsRequest) + assert req.texts == ["test chunk text"] + + @pytest.mark.asyncio + async def test_extracts_first_vector(self, processor): + """Should use vectors[0] from the response.""" + msg = _make_chunk_message("chunk") + + mock_request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[[1.0, 2.0, 3.0]] + )) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(request=mock_request) + elif name == "output": + return mock_output + return MagicMock() + 
+ await processor.on_message(msg, MagicMock(), flow) + + result = mock_output.send.call_args[0][0] + assert isinstance(result, DocumentEmbeddings) + assert len(result.chunks) == 1 + assert result.chunks[0].vector == [1.0, 2.0, 3.0] + + @pytest.mark.asyncio + async def test_empty_vectors_response(self, processor): + """Should handle empty vectors response gracefully.""" + msg = _make_chunk_message("chunk") + + mock_request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[] + )) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(request=mock_request) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + result = mock_output.send.call_args[0][0] + assert result.chunks[0].vector == [] + + @pytest.mark.asyncio + async def test_chunk_id_is_document_id(self, processor): + """ChunkEmbeddings should use document_id as chunk_id.""" + msg = _make_chunk_message(doc_id="my-doc-42") + + mock_request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[[0.0]] + )) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(request=mock_request) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + result = mock_output.send.call_args[0][0] + assert result.chunks[0].chunk_id == "my-doc-42" + + @pytest.mark.asyncio + async def test_metadata_preserved(self, processor): + """Output should carry the original metadata.""" + msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1") + + mock_request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[[0.0]] + )) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(request=mock_request) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, 
MagicMock(), flow) + + result = mock_output.send.call_args[0][0] + assert result.metadata.user == "alice" + assert result.metadata.collection == "reports" + assert result.metadata.id == "d1" + + @pytest.mark.asyncio + async def test_error_propagates(self, processor): + """Embedding errors should propagate for retry.""" + msg = _make_chunk_message() + + mock_request = AsyncMock(side_effect=RuntimeError("service down")) + + def flow(name): + if name == "embeddings-request": + return MagicMock(request=mock_request) + return MagicMock() + + with pytest.raises(RuntimeError, match="service down"): + await processor.on_message(msg, MagicMock(), flow) diff --git a/tests/unit/test_embeddings/test_embeddings_client.py b/tests/unit/test_embeddings/test_embeddings_client.py new file mode 100644 index 00000000..84305d7a --- /dev/null +++ b/tests/unit/test_embeddings/test_embeddings_client.py @@ -0,0 +1,109 @@ +""" +Tests for EmbeddingsClient — the client interface for batch embeddings. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.base.embeddings_client import EmbeddingsClient +from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error + + +class TestEmbeddingsClient: + + @pytest.mark.asyncio + async def test_embed_sends_request_and_returns_vectors(self): + """embed() should send an EmbeddingsRequest and return vectors.""" + client = EmbeddingsClient.__new__(EmbeddingsClient) + client.request = AsyncMock(return_value=EmbeddingsResponse( + error=None, + vectors=[[0.1, 0.2], [0.3, 0.4]], + )) + + result = await client.embed(texts=["hello", "world"]) + + assert result == [[0.1, 0.2], [0.3, 0.4]] + client.request.assert_called_once() + req = client.request.call_args[0][0] + assert isinstance(req, EmbeddingsRequest) + assert req.texts == ["hello", "world"] + + @pytest.mark.asyncio + async def test_embed_single_text(self): + """embed() should work with a single text.""" + client = EmbeddingsClient.__new__(EmbeddingsClient) + 
client.request = AsyncMock(return_value=EmbeddingsResponse( + error=None, + vectors=[[1.0, 2.0, 3.0]], + )) + + result = await client.embed(texts=["single"]) + + assert result == [[1.0, 2.0, 3.0]] + + @pytest.mark.asyncio + async def test_embed_raises_on_error_response(self): + """embed() should raise RuntimeError when response contains an error.""" + client = EmbeddingsClient.__new__(EmbeddingsClient) + client.request = AsyncMock(return_value=EmbeddingsResponse( + error=Error(type="embeddings-error", message="model not found"), + vectors=[], + )) + + with pytest.raises(RuntimeError, match="model not found"): + await client.embed(texts=["test"]) + + @pytest.mark.asyncio + async def test_embed_passes_timeout(self): + """embed() should pass timeout to the underlying request.""" + client = EmbeddingsClient.__new__(EmbeddingsClient) + client.request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[[0.0]], + )) + + await client.embed(texts=["test"], timeout=60) + + _, kwargs = client.request.call_args + assert kwargs["timeout"] == 60 + + @pytest.mark.asyncio + async def test_embed_default_timeout(self): + """embed() should use 300s default timeout.""" + client = EmbeddingsClient.__new__(EmbeddingsClient) + client.request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[[0.0]], + )) + + await client.embed(texts=["test"]) + + _, kwargs = client.request.call_args + assert kwargs["timeout"] == 300 + + @pytest.mark.asyncio + async def test_embed_empty_texts(self): + """embed() with empty list should still make the request.""" + client = EmbeddingsClient.__new__(EmbeddingsClient) + client.request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=[], + )) + + result = await client.embed(texts=[]) + + assert result == [] + + @pytest.mark.asyncio + async def test_embed_large_batch(self): + """embed() should handle large batches.""" + client = EmbeddingsClient.__new__(EmbeddingsClient) + n = 100 + vectors = [[float(i)] for i 
in range(n)] + client.request = AsyncMock(return_value=EmbeddingsResponse( + error=None, vectors=vectors, + )) + + texts = [f"text {i}" for i in range(n)] + result = await client.embed(texts=texts) + + assert len(result) == n + req = client.request.call_args[0][0] + assert len(req.texts) == n diff --git a/tests/unit/test_embeddings/test_embeddings_service_request.py b/tests/unit/test_embeddings/test_embeddings_service_request.py new file mode 100644 index 00000000..c57fae16 --- /dev/null +++ b/tests/unit/test_embeddings/test_embeddings_service_request.py @@ -0,0 +1,135 @@ +""" +Tests for EmbeddingsService.on_request — the request handler that dispatches +to on_embeddings and sends responses. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.base import EmbeddingsService +from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error +from trustgraph.exceptions import TooManyRequests + + +class StubEmbeddingsService(EmbeddingsService): + """Minimal concrete implementation for testing on_request.""" + + def __init__(self, embed_result=None, embed_error=None): + # Skip super().__init__ to avoid taskgroup/registration + self.embed_result = embed_result or [[0.1, 0.2]] + self.embed_error = embed_error + + async def on_embeddings(self, texts, model=None): + if self.embed_error: + raise self.embed_error + return self.embed_result + + +def _make_msg(texts, msg_id="req-1"): + request = EmbeddingsRequest(texts=texts) + msg = MagicMock() + msg.value.return_value = request + msg.properties.return_value = {"id": msg_id} + return msg + + +def _make_flow(model="test-model"): + mock_response_producer = AsyncMock() + mock_flow = MagicMock() + + def flow_callable(name): + if name == "model": + return model + if name == "response": + return mock_response_producer + return MagicMock() + + flow_callable.producer = {"response": mock_response_producer} + return flow_callable, mock_response_producer + + +class 
TestEmbeddingsServiceOnRequest: + + @pytest.mark.asyncio + async def test_successful_request(self): + """on_request should call on_embeddings and send response.""" + service = StubEmbeddingsService(embed_result=[[0.1, 0.2], [0.3, 0.4]]) + msg = _make_msg(["hello", "world"], msg_id="r1") + flow, mock_response = _make_flow(model="my-model") + + await service.on_request(msg, MagicMock(), flow) + + mock_response.send.assert_called_once() + resp = mock_response.send.call_args[0][0] + assert isinstance(resp, EmbeddingsResponse) + assert resp.error is None + assert resp.vectors == [[0.1, 0.2], [0.3, 0.4]] + + # Check id is passed through + props = mock_response.send.call_args[1]["properties"] + assert props["id"] == "r1" + + @pytest.mark.asyncio + async def test_passes_model_from_flow(self): + """on_request should pass model parameter from flow to on_embeddings.""" + calls = [] + + class TrackingService(EmbeddingsService): + def __init__(self): + pass + + async def on_embeddings(self, texts, model=None): + calls.append({"texts": texts, "model": model}) + return [[0.0]] + + service = TrackingService() + msg = _make_msg(["test"]) + flow, _ = _make_flow(model="custom-model-v2") + + await service.on_request(msg, MagicMock(), flow) + + assert len(calls) == 1 + assert calls[0]["model"] == "custom-model-v2" + assert calls[0]["texts"] == ["test"] + + @pytest.mark.asyncio + async def test_error_sends_error_response(self): + """Non-rate-limit errors should send an error response.""" + service = StubEmbeddingsService( + embed_error=ValueError("dimension mismatch") + ) + msg = _make_msg(["test"], msg_id="r2") + flow, mock_response = _make_flow() + + await service.on_request(msg, MagicMock(), flow) + + mock_response.send.assert_called_once() + resp = mock_response.send.call_args[0][0] + assert resp.error is not None + assert resp.error.type == "embeddings-error" + assert "dimension mismatch" in resp.error.message + assert resp.vectors == [] + + @pytest.mark.asyncio + async def 
test_rate_limit_propagates(self): + """TooManyRequests should propagate (not caught as error response).""" + service = StubEmbeddingsService( + embed_error=TooManyRequests("rate limited") + ) + msg = _make_msg(["test"]) + flow, _ = _make_flow() + + with pytest.raises(TooManyRequests): + await service.on_request(msg, MagicMock(), flow) + + @pytest.mark.asyncio + async def test_message_id_preserved(self): + """The request message id should be forwarded in the response properties.""" + service = StubEmbeddingsService() + msg = _make_msg(["test"], msg_id="unique-id-42") + flow, mock_response = _make_flow() + + await service.on_request(msg, MagicMock(), flow) + + props = mock_response.send.call_args[1]["properties"] + assert props["id"] == "unique-id-42" diff --git a/tests/unit/test_embeddings/test_graph_embeddings_processor.py b/tests/unit/test_embeddings/test_graph_embeddings_processor.py new file mode 100644 index 00000000..5d535349 --- /dev/null +++ b/tests/unit/test_embeddings/test_graph_embeddings_processor.py @@ -0,0 +1,233 @@ +""" +Tests for graph embeddings processor — batch embedding of entity contexts. 
+""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.embeddings.graph_embeddings.embeddings import Processor +from trustgraph.schema import ( + EntityContexts, EntityEmbeddings, GraphEmbeddings, + Term, IRI, Metadata, +) + + +@pytest.fixture +def processor(): + return Processor( + taskgroup=AsyncMock(), + id="test-graph-embeddings", + batch_size=3, + ) + + +def _make_entity_context(name, context, chunk_id="chunk-1"): + """Create an entity context for testing.""" + entity = Term(type=IRI, iri=f"urn:entity:{name}") + return MagicMock(entity=entity, context=context, chunk_id=chunk_id) + + +def _make_message(entities, doc_id="doc-1", user="test", collection="default"): + metadata = Metadata(id=doc_id, user=user, collection=collection) + value = EntityContexts(metadata=metadata, entities=entities) + msg = MagicMock() + msg.value.return_value = value + return msg + + +class TestGraphEmbeddingsInit: + + def test_default_batch_size(self): + p = Processor(taskgroup=AsyncMock(), id="test") + assert p.batch_size == 5 + + def test_custom_batch_size(self): + p = Processor(taskgroup=AsyncMock(), id="test", batch_size=20) + assert p.batch_size == 20 + + +class TestGraphEmbeddingsBatchProcessing: + + @pytest.mark.asyncio + async def test_single_batch_call_for_all_entities(self, processor): + """All entity contexts should be embedded in a single API call.""" + entities = [ + _make_entity_context("Alice", "Alice is a person"), + _make_entity_context("Bob", "Bob is a developer"), + _make_entity_context("Acme", "Acme is a company"), + ] + msg = _make_message(entities) + + mock_embed = AsyncMock(return_value=[ + [0.1, 0.2], [0.3, 0.4], [0.5, 0.6], + ]) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(embed=mock_embed) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + # Single batch call with all three texts + 
mock_embed.assert_called_once_with( + texts=["Alice is a person", "Bob is a developer", "Acme is a company"] + ) + + @pytest.mark.asyncio + async def test_vectors_paired_with_correct_entities(self, processor): + """Each vector should be paired with its corresponding entity.""" + entities = [ + _make_entity_context("Alice", "ctx-A", chunk_id="c1"), + _make_entity_context("Bob", "ctx-B", chunk_id="c2"), + ] + msg = _make_message(entities) + + vectors = [[1.0, 2.0], [3.0, 4.0]] + mock_embed = AsyncMock(return_value=vectors) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(embed=mock_embed) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + # With batch_size=3, all 2 entities fit in one output message + mock_output.send.assert_called_once() + result = mock_output.send.call_args[0][0] + assert isinstance(result, GraphEmbeddings) + assert len(result.entities) == 2 + assert result.entities[0].vector == [1.0, 2.0] + assert result.entities[0].entity.iri == "urn:entity:Alice" + assert result.entities[0].chunk_id == "c1" + assert result.entities[1].vector == [3.0, 4.0] + assert result.entities[1].entity.iri == "urn:entity:Bob" + + @pytest.mark.asyncio + async def test_output_batching(self, processor): + """Output should be split into batches of batch_size.""" + # batch_size=3, 7 entities -> 3 output messages (3+3+1) + entities = [ + _make_entity_context(f"E{i}", f"context {i}") + for i in range(7) + ] + msg = _make_message(entities) + + vectors = [[float(i)] for i in range(7)] + mock_embed = AsyncMock(return_value=vectors) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(embed=mock_embed) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + assert mock_output.send.call_count == 3 + # First batch has 3 entities + batch1 = 
mock_output.send.call_args_list[0][0][0] + assert len(batch1.entities) == 3 + # Second batch has 3 entities + batch2 = mock_output.send.call_args_list[1][0][0] + assert len(batch2.entities) == 3 + # Third batch has 1 entity + batch3 = mock_output.send.call_args_list[2][0][0] + assert len(batch3.entities) == 1 + + @pytest.mark.asyncio + async def test_output_batches_preserve_metadata(self, processor): + """Each output batch should carry the original metadata.""" + entities = [ + _make_entity_context(f"E{i}", f"ctx {i}") + for i in range(5) + ] + msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main") + + mock_embed = AsyncMock(return_value=[[0.0]] * 5) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(embed=mock_embed) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + for call in mock_output.send.call_args_list: + result = call[0][0] + assert result.metadata.id == "doc-42" + assert result.metadata.user == "alice" + assert result.metadata.collection == "main" + + @pytest.mark.asyncio + async def test_single_entity(self, processor): + """Single entity should work with one embed call and one output.""" + entities = [_make_entity_context("Solo", "solo context")] + msg = _make_message(entities) + + mock_embed = AsyncMock(return_value=[[1.0, 2.0, 3.0]]) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(embed=mock_embed) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + mock_embed.assert_called_once_with(texts=["solo context"]) + mock_output.send.assert_called_once() + + @pytest.mark.asyncio + async def test_embed_error_propagates(self, processor): + """Embedding service errors should propagate for retry.""" + entities = [_make_entity_context("E", "ctx")] + msg = _make_message(entities) + + 
mock_embed = AsyncMock(side_effect=RuntimeError("embedding failed")) + + def flow(name): + if name == "embeddings-request": + return MagicMock(embed=mock_embed) + return MagicMock() + + with pytest.raises(RuntimeError, match="embedding failed"): + await processor.on_message(msg, MagicMock(), flow) + + @pytest.mark.asyncio + async def test_exact_batch_size(self, processor): + """When entity count equals batch_size, exactly one output message.""" + entities = [ + _make_entity_context(f"E{i}", f"ctx {i}") + for i in range(3) # batch_size=3 + ] + msg = _make_message(entities) + + mock_embed = AsyncMock(return_value=[[0.0]] * 3) + mock_output = AsyncMock() + + def flow(name): + if name == "embeddings-request": + return MagicMock(embed=mock_embed) + elif name == "output": + return mock_output + return MagicMock() + + await processor.on_message(msg, MagicMock(), flow) + + mock_output.send.assert_called_once() + assert len(mock_output.send.call_args[0][0].entities) == 3 diff --git a/tests/unit/test_extract/test_streaming_triples/__init__.py b/tests/unit/test_extract/test_streaming_triples/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/unit/test_extract/test_streaming_triples/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py new file mode 100644 index 00000000..b651b59e --- /dev/null +++ b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py @@ -0,0 +1,407 @@ +""" +Tests for streaming triple and entity context batching in the definitions +KG extractor. + +Covers: triples batch splitting, entity context batch splitting, +metadata preservation, provenance, and empty/null filtering. 
+""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.extract.kg.definitions.extract import ( + Processor, default_triples_batch_size, default_entity_batch_size, +) +from trustgraph.schema import ( + Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_processor(triples_batch_size=default_triples_batch_size, + entity_batch_size=default_entity_batch_size): + proc = Processor.__new__(Processor) + proc.triples_batch_size = triples_batch_size + proc.entity_batch_size = entity_batch_size + return proc + + +def _make_defn(entity, definition): + return {"entity": entity, "definition": definition} + + +def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", + user="user-1", collection="col-1", document_id=""): + chunk = Chunk( + metadata=Metadata( + id=meta_id, root=root, user=user, collection=collection, + ), + chunk=text.encode("utf-8"), + document_id=document_id, + ) + msg = MagicMock() + msg.value.return_value = chunk + return msg + + +def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"): + mock_triples_pub = AsyncMock() + mock_ecs_pub = AsyncMock() + mock_prompt_client = AsyncMock() + mock_prompt_client.extract_definitions = AsyncMock( + return_value=prompt_result + ) + + def flow(name): + if name == "prompt-request": + return mock_prompt_client + if name == "triples": + return mock_triples_pub + if name == "entity-contexts": + return mock_ecs_pub + if name == "llm-model": + return llm_model + if name == "ontology": + return ontology_uri + return MagicMock() + + return flow, mock_triples_pub, mock_ecs_pub, mock_prompt_client + + +def _sent_triples(mock_pub): + return [call.args[0] for call in mock_pub.send.call_args_list] + + +def _sent_ecs(mock_pub): + return [call.args[0] for call in 
mock_pub.send.call_args_list] + + +def _all_triples_flat(mock_pub): + result = [] + for triples_msg in _sent_triples(mock_pub): + result.extend(triples_msg.triples) + return result + + +def _all_entities_flat(mock_pub): + result = [] + for ecs_msg in _sent_ecs(mock_pub): + result.extend(ecs_msg.entities) + return result + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestDefaults: + + def test_default_triples_batch_size(self): + assert default_triples_batch_size == 50 + + def test_default_entity_batch_size(self): + assert default_entity_batch_size == 5 + + +class TestTriplesBatching: + + @pytest.mark.asyncio + async def test_single_batch_when_under_limit(self): + proc = _make_processor(triples_batch_size=100) + defs = [_make_defn("Cat", "A feline animal")] + flow, triples_pub, _, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert triples_pub.send.call_count == 1 + + @pytest.mark.asyncio + async def test_multiple_triples_batches(self): + proc = _make_processor(triples_batch_size=2) + defs = [ + _make_defn("Cat", "A feline"), + _make_defn("Dog", "A canine"), + ] + flow, triples_pub, _, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + # 2 defs → 2 labels + 2 definitions = 4 triples + provenance + # With batch_size=2, should produce multiple batches + assert triples_pub.send.call_count > 1 + + @pytest.mark.asyncio + async def test_triples_batch_sizes_within_limit(self): + batch_size = 3 + proc = _make_processor(triples_batch_size=batch_size) + defs = [ + _make_defn("A", "def A"), + _make_defn("B", "def B"), + _make_defn("C", "def C"), + ] + flow, triples_pub, _, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + for triples_msg in _sent_triples(triples_pub): + 
assert len(triples_msg.triples) <= batch_size + + +class TestEntityContextBatching: + + @pytest.mark.asyncio + async def test_single_entity_batch_when_under_limit(self): + proc = _make_processor(entity_batch_size=100) + defs = [_make_defn("Cat", "A feline")] + flow, _, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + # 1 def → 2 entity contexts (name + definition) + assert ecs_pub.send.call_count == 1 + + @pytest.mark.asyncio + async def test_multiple_entity_batches(self): + proc = _make_processor(entity_batch_size=2) + defs = [ + _make_defn("Cat", "A feline"), + _make_defn("Dog", "A canine"), + ] + flow, _, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + # 2 defs → 4 entity contexts, batch_size=2 → 2 batches + assert ecs_pub.send.call_count == 2 + + @pytest.mark.asyncio + async def test_entity_batch_sizes_within_limit(self): + batch_size = 3 + proc = _make_processor(entity_batch_size=batch_size) + defs = [ + _make_defn("A", "def A"), + _make_defn("B", "def B"), + _make_defn("C", "def C"), + ] + flow, _, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + for ecs_msg in _sent_ecs(ecs_pub): + assert len(ecs_msg.entities) <= batch_size + + @pytest.mark.asyncio + async def test_entity_contexts_have_name_and_definition(self): + """Each definition produces 2 entity contexts: name and definition.""" + proc = _make_processor(entity_batch_size=100) + defs = [_make_defn("Cat", "A feline animal")] + flow, _, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + entities = _all_entities_flat(ecs_pub) + assert len(entities) == 2 + contexts = {e.context for e in entities} + assert "Cat" in contexts + assert "A feline animal" in contexts + + +class TestMetadataPreservation: + + @pytest.mark.asyncio + async def 
test_triples_metadata(self): + proc = _make_processor(triples_batch_size=2) + defs = [_make_defn("X", "def X")] + flow, triples_pub, _, _ = _make_flow(defs) + msg = _make_chunk_msg( + "text", meta_id="c-1", root="r-1", + user="u-1", collection="coll-1", + ) + + await proc.on_message(msg, MagicMock(), flow) + + for triples_msg in _sent_triples(triples_pub): + assert triples_msg.metadata.id == "c-1" + assert triples_msg.metadata.root == "r-1" + assert triples_msg.metadata.user == "u-1" + assert triples_msg.metadata.collection == "coll-1" + + @pytest.mark.asyncio + async def test_entity_contexts_metadata(self): + proc = _make_processor(entity_batch_size=1) + defs = [_make_defn("X", "def X")] + flow, _, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg( + "text", meta_id="c-2", root="r-2", + user="u-2", collection="coll-2", + ) + + await proc.on_message(msg, MagicMock(), flow) + + for ecs_msg in _sent_ecs(ecs_pub): + assert ecs_msg.metadata.id == "c-2" + assert ecs_msg.metadata.root == "r-2" + + +class TestEmptyAndNullFiltering: + + @pytest.mark.asyncio + async def test_empty_entity_skipped(self): + proc = _make_processor() + defs = [ + _make_defn("", "some definition"), + _make_defn("Valid", "a valid definition"), + ] + flow, triples_pub, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(triples_pub) + all_e = _all_entities_flat(ecs_pub) + # Only "Valid" should be present + entity_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri")} + assert any("valid" in iri for iri in entity_iris) + assert len(all_e) == 2 # name + definition for "Valid" only + + @pytest.mark.asyncio + async def test_empty_definition_skipped(self): + proc = _make_processor() + defs = [ + _make_defn("Entity", ""), + _make_defn("Good", "good definition"), + ] + flow, triples_pub, _, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = 
_all_triples_flat(triples_pub) + entity_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri")} + assert any("good" in iri for iri in entity_iris) + # "Entity" with empty def should have been skipped + assert not any("entity" in iri and "good" not in iri for iri in entity_iris) + + @pytest.mark.asyncio + async def test_none_fields_skipped(self): + proc = _make_processor() + defs = [ + _make_defn(None, "some definition"), + _make_defn("Entity", None), + ] + flow, triples_pub, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert triples_pub.send.call_count == 0 + assert ecs_pub.send.call_count == 0 + + @pytest.mark.asyncio + async def test_all_filtered_no_output(self): + proc = _make_processor() + defs = [_make_defn("", ""), _make_defn(None, None)] + flow, triples_pub, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert triples_pub.send.call_count == 0 + assert ecs_pub.send.call_count == 0 + + @pytest.mark.asyncio + async def test_empty_prompt_response(self): + proc = _make_processor() + flow, triples_pub, ecs_pub, _ = _make_flow([]) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert triples_pub.send.call_count == 0 + assert ecs_pub.send.call_count == 0 + + +class TestProvenanceInclusion: + + @pytest.mark.asyncio + async def test_provenance_triples_present(self): + proc = _make_processor(triples_batch_size=200) + defs = [_make_defn("Cat", "A feline")] + flow, triples_pub, _, _ = _make_flow(defs) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(triples_pub) + # 1 def → 1 label + 1 definition = 2 content triples + # Provenance adds more + assert len(all_t) > 2 + + +class TestErrorHandling: + + @pytest.mark.asyncio + async def test_prompt_error_caught(self): + proc = _make_processor() + flow, triples_pub, ecs_pub, prompt = 
_make_flow([]) + prompt.extract_definitions = AsyncMock( + side_effect=RuntimeError("LLM error") + ) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert triples_pub.send.call_count == 0 + assert ecs_pub.send.call_count == 0 + + @pytest.mark.asyncio + async def test_non_list_response_caught(self): + proc = _make_processor() + flow, triples_pub, ecs_pub, prompt = _make_flow("not a list") + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert triples_pub.send.call_count == 0 + assert ecs_pub.send.call_count == 0 + + +class TestDocumentIdProvenance: + + @pytest.mark.asyncio + async def test_document_id_used_for_chunk_id(self): + """When document_id is set, entity contexts should use it as chunk_id.""" + proc = _make_processor(entity_batch_size=100) + defs = [_make_defn("Cat", "A feline")] + flow, _, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text", document_id="doc-123") + + await proc.on_message(msg, MagicMock(), flow) + + entities = _all_entities_flat(ecs_pub) + for e in entities: + assert e.chunk_id == "doc-123" + + @pytest.mark.asyncio + async def test_metadata_id_fallback_for_chunk_id(self): + """When document_id is empty, metadata.id is used as chunk_id.""" + proc = _make_processor(entity_batch_size=100) + defs = [_make_defn("Cat", "A feline")] + flow, _, ecs_pub, _ = _make_flow(defs) + msg = _make_chunk_msg("text", meta_id="chunk-42", document_id="") + + await proc.on_message(msg, MagicMock(), flow) + + entities = _all_entities_flat(ecs_pub) + for e in entities: + assert e.chunk_id == "chunk-42" diff --git a/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py new file mode 100644 index 00000000..cf3b1fb0 --- /dev/null +++ b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py @@ -0,0 +1,408 @@ +""" +Tests for streaming triple batching in the 
relationships KG extractor. + +Covers: batch size configuration, output splitting, metadata preservation, +provenance inclusion, empty/null filtering, and error propagation. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from trustgraph.extract.kg.relationships.extract import ( + Processor, default_triples_batch_size, +) +from trustgraph.schema import ( + Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_processor(triples_batch_size=default_triples_batch_size): + """Create a Processor without triggering FlowProcessor.__init__.""" + proc = Processor.__new__(Processor) + proc.triples_batch_size = triples_batch_size + return proc + + +def _make_rel(subject, predicate, obj, object_entity=True): + """Build a relationship dict as returned by the prompt client.""" + return { + "subject": subject, + "predicate": predicate, + "object": obj, + "object-entity": object_entity, + } + + +def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", + user="user-1", collection="col-1", document_id=""): + """Build a mock message wrapping a Chunk.""" + chunk = Chunk( + metadata=Metadata( + id=meta_id, root=root, user=user, collection=collection, + ), + chunk=text.encode("utf-8"), + document_id=document_id, + ) + msg = MagicMock() + msg.value.return_value = chunk + return msg + + +def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"): + """Build a mock flow callable that provides prompt client, triples + producer, and parameter specs.""" + mock_triples_pub = AsyncMock() + mock_prompt_client = AsyncMock() + mock_prompt_client.extract_relationships = AsyncMock( + return_value=prompt_result + ) + + def flow(name): + if name == "prompt-request": + return mock_prompt_client + if name == "triples": + return mock_triples_pub + if name == 
"llm-model": + return llm_model + if name == "ontology": + return ontology_uri + return MagicMock() + + return flow, mock_triples_pub, mock_prompt_client + + +def _sent_triples(mock_pub): + """Collect all Triples objects sent to a mock publisher.""" + return [call.args[0] for call in mock_pub.send.call_args_list] + + +def _all_triples_flat(mock_pub): + """Flatten all batches into one list of Triple objects.""" + result = [] + for triples_msg in _sent_triples(mock_pub): + result.extend(triples_msg.triples) + return result + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestDefaultBatchSize: + + def test_default_is_50(self): + assert default_triples_batch_size == 50 + + def test_processor_uses_default(self): + proc = _make_processor() + assert proc.triples_batch_size == 50 + + +class TestBatchSplitting: + + @pytest.mark.asyncio + async def test_single_batch_when_under_limit(self): + """Few triples → single send call.""" + proc = _make_processor(triples_batch_size=50) + rels = [_make_rel("A", "knows", "B")] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("some text") + + await proc.on_message(msg, MagicMock(), flow) + + # One relationship produces: rel triple + 3 labels + provenance + # All should fit in one batch of 50 + assert pub.send.call_count == 1 + + @pytest.mark.asyncio + async def test_multiple_batches_with_small_batch_size(self): + """With batch_size=3 and many triples, multiple batches are sent.""" + proc = _make_processor(triples_batch_size=3) + # 2 relationships → 2 rel triples + 6 labels = 8 triples + provenance + rels = [ + _make_rel("A", "knows", "B"), + _make_rel("C", "likes", "D"), + ] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("some text") + + await proc.on_message(msg, MagicMock(), flow) + + # Should have more than one batch + assert pub.send.call_count > 1 + + @pytest.mark.asyncio + async def 
test_batch_sizes_respect_limit(self): + """No batch should exceed the configured batch size.""" + batch_size = 3 + proc = _make_processor(triples_batch_size=batch_size) + rels = [ + _make_rel("A", "knows", "B"), + _make_rel("C", "likes", "D"), + _make_rel("E", "has", "F"), + ] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + for triples_msg in _sent_triples(pub): + assert len(triples_msg.triples) <= batch_size + + @pytest.mark.asyncio + async def test_all_triples_present_across_batches(self): + """Total triples across batches equals expected count.""" + proc = _make_processor(triples_batch_size=2) + # 1 relationship with object-entity=True → 1 rel + 3 labels = 4 triples + # + provenance triples + rels = [_make_rel("A", "knows", "B", object_entity=True)] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(pub) + # At minimum: 1 rel + 3 labels = 4 content triples + assert len(all_t) >= 4 + + @pytest.mark.asyncio + async def test_custom_batch_size(self): + """Processor respects custom triples_batch_size parameter.""" + proc = _make_processor(triples_batch_size=100) + assert proc.triples_batch_size == 100 + + +class TestMetadataPreservation: + + @pytest.mark.asyncio + async def test_metadata_forwarded_to_all_batches(self): + """Every batch should carry the original chunk metadata.""" + proc = _make_processor(triples_batch_size=2) + rels = [_make_rel("X", "rel", "Y")] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg( + "text", meta_id="c-1", root="r-1", + user="u-1", collection="coll-1", + ) + + await proc.on_message(msg, MagicMock(), flow) + + for triples_msg in _sent_triples(pub): + assert triples_msg.metadata.id == "c-1" + assert triples_msg.metadata.root == "r-1" + assert triples_msg.metadata.user == "u-1" + assert triples_msg.metadata.collection == "coll-1" + + +class TestRelationshipTriples: + 
+ @pytest.mark.asyncio + async def test_entity_object_produces_iri(self): + """object-entity=True → object is an IRI, with label triple.""" + proc = _make_processor(triples_batch_size=200) + rels = [_make_rel("Alice", "knows", "Bob", object_entity=True)] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(pub) + # Find the relationship triple (not a label) + rel_triples = [ + t for t in all_t + if t.o.type == IRI and "bob" in t.o.iri + ] + assert len(rel_triples) >= 1 + + @pytest.mark.asyncio + async def test_literal_object_produces_literal(self): + """object-entity=False → object is a LITERAL, no label for object.""" + proc = _make_processor(triples_batch_size=200) + rels = [_make_rel("Alice", "age", "30", object_entity=False)] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(pub) + # Find the relationship triple with literal object + lit_triples = [ + t for t in all_t + if t.o.type == LITERAL and t.o.value == "30" + ] + assert len(lit_triples) == 1 + + @pytest.mark.asyncio + async def test_labels_emitted_for_subject_and_predicate(self): + """Every relationship should produce label triples for s and p.""" + proc = _make_processor(triples_batch_size=200) + rels = [_make_rel("Alice", "knows", "Bob")] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(pub) + label_triples = [ + t for t in all_t + if t.p.type == IRI and "label" in t.p.iri.lower() + ] + labels = {t.o.value for t in label_triples} + assert "Alice" in labels + assert "knows" in labels + assert "Bob" in labels # object-entity default is True + + +class TestEmptyAndNullFiltering: + + @pytest.mark.asyncio + async def test_empty_string_fields_skipped(self): + """Relationships with empty string s/p/o are skipped.""" + proc = 
_make_processor(triples_batch_size=200) + rels = [ + _make_rel("", "knows", "Bob"), + _make_rel("Alice", "", "Bob"), + _make_rel("Alice", "knows", ""), + _make_rel("Good", "triple", "Here"), + ] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(pub) + # Only the "Good triple Here" relationship should produce content triples + rel_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri") and t.s.iri} + assert any("good" in iri for iri in rel_iris) + assert not any("alice" in iri for iri in rel_iris) + + @pytest.mark.asyncio + async def test_none_fields_skipped(self): + """Relationships with None s/p/o are skipped.""" + proc = _make_processor(triples_batch_size=200) + rels = [ + _make_rel(None, "knows", "Bob"), + _make_rel("Alice", None, "Bob"), + _make_rel("Alice", "knows", None), + _make_rel("Valid", "rel", "Here"), + ] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(pub) + rel_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri") and t.s.iri} + assert any("valid" in iri for iri in rel_iris) + assert not any("alice" in iri for iri in rel_iris) + + @pytest.mark.asyncio + async def test_all_filtered_produces_no_output(self): + """If all relationships are empty/null, nothing is emitted.""" + proc = _make_processor(triples_batch_size=200) + rels = [ + _make_rel("", "", ""), + _make_rel(None, None, None), + ] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert pub.send.call_count == 0 + + @pytest.mark.asyncio + async def test_empty_prompt_response_produces_no_output(self): + """Empty relationship list from prompt → no triples emitted.""" + proc = _make_processor() + flow, pub, _ = _make_flow([]) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert pub.send.call_count == 
0 + + +class TestProvenanceInclusion: + + @pytest.mark.asyncio + async def test_provenance_triples_present(self): + """Extracted relationships should include provenance triples.""" + proc = _make_processor(triples_batch_size=200) + rels = [_make_rel("A", "knows", "B")] + flow, pub, _ = _make_flow(rels) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + all_t = _all_triples_flat(pub) + # Provenance triples use GRAPH_SOURCE graph context + # They contain terms referencing prov: namespace or subgraph URIs + # We just check that total count > 4 (1 rel + 3 labels) + assert len(all_t) > 4 + + @pytest.mark.asyncio + async def test_no_provenance_when_no_extracted_triples(self): + """Empty relationships → no provenance generated.""" + proc = _make_processor() + flow, pub, _ = _make_flow([_make_rel("", "x", "y")]) + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert pub.send.call_count == 0 + + +class TestErrorPropagation: + + @pytest.mark.asyncio + async def test_prompt_error_is_caught(self): + """Errors from the prompt client are caught (logged, not raised).""" + proc = _make_processor() + flow, pub, prompt = _make_flow([]) + prompt.extract_relationships = AsyncMock( + side_effect=RuntimeError("LLM unavailable") + ) + msg = _make_chunk_msg("text") + + # The outer try/except in on_message catches and logs + await proc.on_message(msg, MagicMock(), flow) + + assert pub.send.call_count == 0 + + @pytest.mark.asyncio + async def test_non_list_response_is_caught(self): + """Non-list prompt response triggers RuntimeError, caught by handler.""" + proc = _make_processor() + flow, pub, prompt = _make_flow("not a list") + msg = _make_chunk_msg("text") + + await proc.on_message(msg, MagicMock(), flow) + + assert pub.send.call_count == 0 + + +class TestToUri: + + def test_spaces_replaced_with_hyphens(self): + proc = _make_processor() + uri = proc.to_uri("hello world") + assert "hello-world" in uri + + def 
test_lowercased(self): + proc = _make_processor() + uri = proc.to_uri("Hello World") + assert "hello-world" in uri + + def test_special_chars_encoded(self): + proc = _make_processor() + # urllib.parse.quote keeps / as safe by default + uri = proc.to_uri("a/b") + assert "a/b" in uri + # Characters like spaces are encoded (handled via replace → hyphen) + uri2 = proc.to_uri("hello world") + assert " " not in uri2 diff --git a/tests/unit/test_librarian/__init__.py b/tests/unit/test_librarian/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_librarian/test_chunked_upload.py b/tests/unit/test_librarian/test_chunked_upload.py new file mode 100644 index 00000000..2d09bcd4 --- /dev/null +++ b/tests/unit/test_librarian/test_chunked_upload.py @@ -0,0 +1,716 @@ +""" +Tests for librarian chunked upload operations: +begin_upload, upload_chunk, complete_upload, abort_upload, get_upload_status, +list_uploads, and stream_document. +""" + +import base64 +import json +import math +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from trustgraph.librarian.librarian import Librarian, DEFAULT_CHUNK_SIZE +from trustgraph.exceptions import RequestError + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_librarian(min_chunk_size=1): + """Create a Librarian with mocked blob_store and table_store.""" + lib = Librarian.__new__(Librarian) + lib.blob_store = MagicMock() + lib.table_store = AsyncMock() + lib.load_document = AsyncMock() + lib.min_chunk_size = min_chunk_size + return lib + + +def _make_doc_metadata( + doc_id="doc-1", kind="application/pdf", user="alice", title="Test Doc" +): + meta = MagicMock() + meta.id = doc_id + meta.kind = kind + meta.user = user + meta.title = title + meta.time = 1700000000 + meta.comments = "" + meta.tags = [] + return meta + + +def _make_begin_request( + 
doc_id="doc-1", kind="application/pdf", user="alice", + total_size=10_000_000, chunk_size=0 +): + req = MagicMock() + req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, user=user) + req.total_size = total_size + req.chunk_size = chunk_size + return req + + +def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, user="alice", content=b"data"): + req = MagicMock() + req.upload_id = upload_id + req.chunk_index = chunk_index + req.user = user + req.content = base64.b64encode(content) + return req + + +def _make_session( + user="alice", total_chunks=5, chunk_size=2_000_000, + total_size=10_000_000, chunks_received=None, object_id="obj-1", + s3_upload_id="s3-up-1", document_metadata=None, document_id="doc-1", +): + if chunks_received is None: + chunks_received = {} + if document_metadata is None: + document_metadata = json.dumps({ + "id": document_id, "kind": "application/pdf", + "user": user, "title": "Test", "time": 1700000000, + "comments": "", "tags": [], + }) + return { + "user": user, + "total_chunks": total_chunks, + "chunk_size": chunk_size, + "total_size": total_size, + "chunks_received": chunks_received, + "object_id": object_id, + "s3_upload_id": s3_upload_id, + "document_metadata": document_metadata, + "document_id": document_id, + } + + +# --------------------------------------------------------------------------- +# begin_upload +# --------------------------------------------------------------------------- + +class TestBeginUpload: + + @pytest.mark.asyncio + async def test_creates_session(self): + lib = _make_librarian() + lib.table_store.document_exists.return_value = False + lib.blob_store.create_multipart_upload.return_value = "s3-upload-id" + + req = _make_begin_request(total_size=10_000_000) + resp = await lib.begin_upload(req) + + assert resp.error is None + assert resp.upload_id is not None + assert resp.total_chunks == math.ceil(10_000_000 / DEFAULT_CHUNK_SIZE) + assert resp.chunk_size == DEFAULT_CHUNK_SIZE + + 
    @pytest.mark.asyncio
    async def test_custom_chunk_size(self):
        lib = _make_librarian()
        lib.table_store.document_exists.return_value = False
        lib.blob_store.create_multipart_upload.return_value = "s3-id"

        req = _make_begin_request(total_size=10_000, chunk_size=3000)
        resp = await lib.begin_upload(req)

        # Caller-supplied chunk size is honoured, not overridden by default.
        assert resp.chunk_size == 3000
        assert resp.total_chunks == math.ceil(10_000 / 3000)

    @pytest.mark.asyncio
    async def test_rejects_invalid_kind(self):
        lib = _make_librarian()
        req = _make_begin_request(kind="image/png")

        with pytest.raises(RequestError, match="Invalid document kind"):
            await lib.begin_upload(req)

    @pytest.mark.asyncio
    async def test_rejects_duplicate_document(self):
        lib = _make_librarian()
        lib.table_store.document_exists.return_value = True

        req = _make_begin_request()
        with pytest.raises(RequestError, match="already exists"):
            await lib.begin_upload(req)

    @pytest.mark.asyncio
    async def test_rejects_zero_size(self):
        lib = _make_librarian()
        lib.table_store.document_exists.return_value = False

        req = _make_begin_request(total_size=0)
        with pytest.raises(RequestError, match="positive"):
            await lib.begin_upload(req)

    @pytest.mark.asyncio
    async def test_rejects_chunk_below_minimum(self):
        lib = _make_librarian(min_chunk_size=1024)
        lib.table_store.document_exists.return_value = False

        req = _make_begin_request(total_size=10_000, chunk_size=512)
        with pytest.raises(RequestError, match="below minimum"):
            await lib.begin_upload(req)

    @pytest.mark.asyncio
    async def test_calls_s3_create_multipart(self):
        lib = _make_librarian()
        lib.table_store.document_exists.return_value = False
        lib.blob_store.create_multipart_upload.return_value = "s3-id"

        req = _make_begin_request(kind="application/pdf")
        await lib.begin_upload(req)

        lib.blob_store.create_multipart_upload.assert_called_once()
        # create_multipart_upload(object_id, kind) — positional args
        args = lib.blob_store.create_multipart_upload.call_args[0]
        assert args[1] == "application/pdf"

    @pytest.mark.asyncio
    async def test_stores_session_in_cassandra(self):
        lib = _make_librarian()
        lib.table_store.document_exists.return_value = False
        lib.blob_store.create_multipart_upload.return_value = "s3-id"

        req = _make_begin_request(total_size=5_000_000)
        resp = await lib.begin_upload(req)

        lib.table_store.create_upload_session.assert_called_once()
        # Session row must carry the same identifiers the response reports.
        kwargs = lib.table_store.create_upload_session.call_args[1]
        assert kwargs["upload_id"] == resp.upload_id
        assert kwargs["total_size"] == 5_000_000
        assert kwargs["total_chunks"] == resp.total_chunks

    @pytest.mark.asyncio
    async def test_accepts_text_plain(self):
        lib = _make_librarian()
        lib.table_store.document_exists.return_value = False
        lib.blob_store.create_multipart_upload.return_value = "s3-id"

        req = _make_begin_request(kind="text/plain", total_size=1000)
        resp = await lib.begin_upload(req)
        assert resp.error is None


# ---------------------------------------------------------------------------
# upload_chunk
# ---------------------------------------------------------------------------

class TestUploadChunk:
    """upload_chunk: S3 part indexing, auth, bounds, and progress."""

    @pytest.mark.asyncio
    async def test_successful_chunk_upload(self):
        lib = _make_librarian()
        session = _make_session(total_chunks=5, chunks_received={})
        lib.table_store.get_upload_session.return_value = session
        lib.blob_store.upload_part.return_value = "etag-1"

        req = _make_upload_chunk_request(chunk_index=0, content=b"chunk data")
        resp = await lib.upload_chunk(req)

        assert resp.error is None
        assert resp.chunk_index == 0
        assert resp.total_chunks == 5
        # The chunk is added to the dict (len=1), then +1 applied => 2
        # NOTE(review): this pins an apparent double-count — first chunk
        # reporting chunks_received == 2 looks off by one; confirm against
        # the Librarian.upload_chunk implementation.
        assert resp.chunks_received == 2

    @pytest.mark.asyncio
    async def test_s3_part_number_is_1_indexed(self):
        lib = _make_librarian()
        session = _make_session()
        lib.table_store.get_upload_session.return_value = session
        lib.blob_store.upload_part.return_value = "etag"

        req = _make_upload_chunk_request(chunk_index=0)
        await lib.upload_chunk(req)

        kwargs = lib.blob_store.upload_part.call_args[1]
        assert kwargs["part_number"] == 1  # 0-indexed chunk → 1-indexed part

    @pytest.mark.asyncio
    async def test_chunk_index_3_becomes_part_4(self):
        lib = _make_librarian()
        session = _make_session()
        lib.table_store.get_upload_session.return_value = session
        lib.blob_store.upload_part.return_value = "etag"

        req = _make_upload_chunk_request(chunk_index=3)
        await lib.upload_chunk(req)

        kwargs = lib.blob_store.upload_part.call_args[1]
        assert kwargs["part_number"] == 4

    @pytest.mark.asyncio
    async def test_rejects_expired_session(self):
        # get_upload_session returning None models an expired/unknown session.
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = None

        req = _make_upload_chunk_request()
        with pytest.raises(RequestError, match="not found"):
            await lib.upload_chunk(req)

    @pytest.mark.asyncio
    async def test_rejects_wrong_user(self):
        lib = _make_librarian()
        session = _make_session(user="alice")
        lib.table_store.get_upload_session.return_value = session

        req = _make_upload_chunk_request(user="bob")
        with pytest.raises(RequestError, match="Not authorized"):
            await lib.upload_chunk(req)

    @pytest.mark.asyncio
    async def test_rejects_negative_chunk_index(self):
        lib = _make_librarian()
        session = _make_session(total_chunks=5)
        lib.table_store.get_upload_session.return_value = session

        req = _make_upload_chunk_request(chunk_index=-1)
        with pytest.raises(RequestError, match="Invalid chunk index"):
            await lib.upload_chunk(req)

    @pytest.mark.asyncio
    async def test_rejects_out_of_range_chunk_index(self):
        lib = _make_librarian()
        session = _make_session(total_chunks=5)
        lib.table_store.get_upload_session.return_value = session

        # Valid indices are 0..total_chunks-1; index 5 is one past the end.
        req = _make_upload_chunk_request(chunk_index=5)
        with pytest.raises(RequestError, match="Invalid chunk index"):
            await lib.upload_chunk(req)

    @pytest.mark.asyncio
    async def test_progress_tracking(self):
        lib = _make_librarian()
        session = _make_session(
            total_chunks=4, chunk_size=1000, total_size=3500,
            chunks_received={0: "e1", 1: "e2"},
        )
        lib.table_store.get_upload_session.return_value = session
        lib.blob_store.upload_part.return_value = "e3"

        req = _make_upload_chunk_request(chunk_index=2)
        resp = await lib.upload_chunk(req)

        # Dict gets chunk 2 added (len=3), then +1 => 4
        # NOTE(review): same double-count pattern as
        # test_successful_chunk_upload — verify intended semantics.
        assert resp.chunks_received == 4
        assert resp.total_chunks == 4
        assert resp.total_bytes == 3500

    @pytest.mark.asyncio
    async def test_bytes_capped_at_total_size(self):
        """bytes_received should not exceed total_size for the final chunk."""
        lib = _make_librarian()
        session = _make_session(
            total_chunks=2, chunk_size=3000, total_size=5000,
            chunks_received={0: "e1"},
        )
        lib.table_store.get_upload_session.return_value = session
        lib.blob_store.upload_part.return_value = "e2"

        req = _make_upload_chunk_request(chunk_index=1)
        resp = await lib.upload_chunk(req)

        # 3 chunks × 3000 = 9000 > 5000, so capped
        assert resp.bytes_received <= 5000

    @pytest.mark.asyncio
    async def test_base64_decodes_content(self):
        lib = _make_librarian()
        session = _make_session()
        lib.table_store.get_upload_session.return_value = session
        lib.blob_store.upload_part.return_value = "etag"

        raw = b"hello world binary data"
        req = _make_upload_chunk_request(content=raw)
        await lib.upload_chunk(req)

        # The service must hand the *decoded* bytes to S3.
        kwargs = lib.blob_store.upload_part.call_args[1]
        assert kwargs["data"] == raw


# ---------------------------------------------------------------------------
# complete_upload
# ---------------------------------------------------------------------------

class TestCompleteUpload:
    """complete_upload: part assembly, validation, and cleanup."""

    @pytest.mark.asyncio
    async def test_successful_completion(self):
        lib = _make_librarian()
        session = _make_session(
            total_chunks=3,
            chunks_received={0: "e1", 1: "e2", 2: "e3"},
        )
        lib.table_store.get_upload_session.return_value = session

        req = MagicMock()
        req.upload_id = "up-1"
        req.user = "alice"

        resp = await lib.complete_upload(req)

        assert resp.error is None
        assert resp.document_id == "doc-1"
        # Completion must finalize S3, register the document, and drop the
        # transient session row.
        lib.blob_store.complete_multipart_upload.assert_called_once()
        lib.table_store.add_document.assert_called_once()
        lib.table_store.delete_upload_session.assert_called_once_with("up-1")

    @pytest.mark.asyncio
    async def test_parts_sorted_by_index(self):
        lib = _make_librarian()
        # Chunks received out of order
        session = _make_session(
            total_chunks=3,
            chunks_received={2: "e3", 0: "e1", 1: "e2"},
        )
        lib.table_store.get_upload_session.return_value = session

        req = MagicMock()
        req.upload_id = "up-1"
        req.user = "alice"

        await lib.complete_upload(req)

        parts = lib.blob_store.complete_multipart_upload.call_args[1]["parts"]
        part_numbers = [p[0] for p in parts]
        assert part_numbers == [1, 2, 3]  # Sorted, 1-indexed

    @pytest.mark.asyncio
    async def test_rejects_missing_chunks(self):
        lib = _make_librarian()
        session = _make_session(
            total_chunks=3,
            chunks_received={0: "e1", 2: "e3"},  # chunk 1 missing
        )
        lib.table_store.get_upload_session.return_value = session

        req = MagicMock()
        req.upload_id = "up-1"
        req.user = "alice"

        with pytest.raises(RequestError, match="Missing chunks"):
            await lib.complete_upload(req)

    @pytest.mark.asyncio
    async def test_rejects_expired_session(self):
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = None

        req = MagicMock()
        req.upload_id = "up-gone"
        req.user = "alice"

        with pytest.raises(RequestError, match="not found"):
            await lib.complete_upload(req)

    @pytest.mark.asyncio
    async def test_rejects_wrong_user(self):
        lib = _make_librarian()
        session = _make_session(user="alice")
        lib.table_store.get_upload_session.return_value = session

        req = MagicMock()
        req.upload_id = "up-1"
        req.user = "bob"

        with pytest.raises(RequestError, match="Not authorized"):
            await lib.complete_upload(req)


# ---------------------------------------------------------------------------
# abort_upload
# ---------------------------------------------------------------------------

class TestAbortUpload:
    """abort_upload: S3 abort plus session-row cleanup."""

    @pytest.mark.asyncio
    async def test_aborts_and_cleans_up(self):
        lib = _make_librarian()
        session = _make_session()
        lib.table_store.get_upload_session.return_value = session

        req = MagicMock()
        req.upload_id = "up-1"
        req.user = "alice"

        resp = await lib.abort_upload(req)

        assert resp.error is None
        # Abort targets the S3-side multipart upload id, not the logical one.
        lib.blob_store.abort_multipart_upload.assert_called_once_with(
            object_id="obj-1", upload_id="s3-up-1"
        )
        lib.table_store.delete_upload_session.assert_called_once_with("up-1")

    @pytest.mark.asyncio
    async def test_rejects_expired_session(self):
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = None

        req = MagicMock()
        req.upload_id = "up-gone"
        req.user = "alice"

        with pytest.raises(RequestError, match="not found"):
            await lib.abort_upload(req)

    @pytest.mark.asyncio
    async def test_rejects_wrong_user(self):
        lib = _make_librarian()
        session = _make_session(user="alice")
        lib.table_store.get_upload_session.return_value = session

        req = MagicMock()
        req.upload_id = "up-1"
        req.user = "bob"

        with pytest.raises(RequestError, match="Not authorized"):
            await lib.abort_upload(req)


# ---------------------------------------------------------------------------
# get_upload_status
# ---------------------------------------------------------------------------

class TestGetUploadStatus:
    """get_upload_status: progress reporting and missing-chunk detection."""

    @pytest.mark.asyncio
    async def test_in_progress_status(self):
        lib = _make_librarian()
        session = _make_session(
            total_chunks=5, chunk_size=2000, total_size=10_000,
            chunks_received={0: "e1", 2: "e3", 4: "e5"},
        )
        lib.table_store.get_upload_session.return_value = session

        req = MagicMock()
        req.upload_id = "up-1"
        req.user = "alice"

        resp = await lib.get_upload_status(req)

        assert resp.upload_state == "in-progress"
        assert resp.chunks_received == 3
        assert resp.total_chunks == 5
        # Gap detection: chunks 1 and 3 were never uploaded.
        assert sorted(resp.received_chunks) == [0, 2, 4]
        assert sorted(resp.missing_chunks) == [1, 3]
        assert resp.total_bytes == 10_000

    @pytest.mark.asyncio
    async def test_expired_session(self):
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = None

        req = MagicMock()
        req.upload_id = "up-expired"
        req.user = "alice"

        resp = await lib.get_upload_status(req)

        # Unlike the mutating operations, status reports "expired" rather
        # than raising.
        assert resp.upload_state == "expired"

    @pytest.mark.asyncio
    async def test_all_chunks_received(self):
        lib = _make_librarian()
        session = _make_session(
            total_chunks=3, chunk_size=1000, total_size=2500,
            chunks_received={0: "e1", 1: "e2", 2: "e3"},
        )
        lib.table_store.get_upload_session.return_value = session

        req = MagicMock()
        req.upload_id = "up-1"
        req.user = "alice"

        resp = await lib.get_upload_status(req)

        assert resp.missing_chunks == []
        assert resp.chunks_received == 3
        # 3 * 1000 = 3000 > 2500, so capped
        assert resp.bytes_received <= 2500

    @pytest.mark.asyncio
    async def test_rejects_wrong_user(self):
        lib = _make_librarian()
        session = _make_session(user="alice")
        lib.table_store.get_upload_session.return_value = session

        req = MagicMock()
        req.upload_id = "up-1"
        req.user = "bob"

        with pytest.raises(RequestError, match="Not authorized"):
            await lib.get_upload_status(req)


# ---------------------------------------------------------------------------
# stream_document
# ---------------------------------------------------------------------------

class TestStreamDocument:
    """stream_document: chunked download streaming from the blob store."""

    @pytest.mark.asyncio
    async def test_streams_chunks_with_progress(self):
        lib = _make_librarian()
        lib.table_store.get_document_object_id.return_value = "obj-1"
        lib.blob_store.get_size = AsyncMock(return_value=5000)
        lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000)

        req = MagicMock()
        req.user = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 2000

        chunks = []
        async for resp in lib.stream_document(req):
            chunks.append(resp)

        assert len(chunks) == 3  # ceil(5000/2000)
        assert chunks[0].chunk_index == 0
        assert chunks[0].total_chunks == 3
        assert chunks[0].is_final is False
        assert chunks[-1].is_final is True
        assert chunks[-1].chunk_index == 2

    @pytest.mark.asyncio
    async def test_single_chunk_document(self):
        lib = _make_librarian()
        lib.table_store.get_document_object_id.return_value = "obj-1"
        lib.blob_store.get_size = AsyncMock(return_value=500)
        lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500)

        req = MagicMock()
        req.user = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 2000

        chunks = []
        async for resp in lib.stream_document(req):
            chunks.append(resp)

        # Whole document fits in one chunk → single final message.
        assert len(chunks) == 1
        assert chunks[0].is_final is True
        assert chunks[0].bytes_received == 500
        assert chunks[0].total_bytes == 500

    @pytest.mark.asyncio
    async def test_byte_ranges_correct(self):
        lib = _make_librarian()
        lib.table_store.get_document_object_id.return_value = "obj-1"
        lib.blob_store.get_size = AsyncMock(return_value=5000)
        lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100)

        req = MagicMock()
        req.user = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 2000

        chunks = []
        async for resp in lib.stream_document(req):
            chunks.append(resp)

        # Verify the byte ranges passed to get_range
        calls = lib.blob_store.get_range.call_args_list
        assert calls[0][0] == ("obj-1", 0, 2000)
        assert calls[1][0] == ("obj-1", 2000, 2000)
        assert calls[2][0] == ("obj-1", 4000, 1000)  # Last chunk: 5000-4000

    @pytest.mark.asyncio
    async def test_default_chunk_size(self):
        lib = _make_librarian()
        lib.table_store.get_document_object_id.return_value = "obj-1"
        lib.blob_store.get_size = AsyncMock(return_value=2_000_000)
        lib.blob_store.get_range = AsyncMock(return_value=b"x")

        req = MagicMock()
        req.user = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 0  # Should use default 1MB

        chunks = []
        async for resp in lib.stream_document(req):
            chunks.append(resp)

        assert len(chunks) == 2  # ceil(2MB / 1MB)

    @pytest.mark.asyncio
    async def test_content_is_base64_encoded(self):
        lib = _make_librarian()
        lib.table_store.get_document_object_id.return_value = "obj-1"
        lib.blob_store.get_size = AsyncMock(return_value=100)
        raw = b"hello world"
        lib.blob_store.get_range = AsyncMock(return_value=raw)

        req = MagicMock()
        req.user = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 1000

        chunks = []
        async for resp in lib.stream_document(req):
            chunks.append(resp)

        # Outbound chunk bodies are base64-encoded, mirroring upload_chunk.
        assert chunks[0].content == base64.b64encode(raw)

    @pytest.mark.asyncio
    async def test_rejects_chunk_below_minimum(self):
        lib = _make_librarian(min_chunk_size=1024)
        lib.table_store.get_document_object_id.return_value = "obj-1"
        lib.blob_store.get_size = AsyncMock(return_value=5000)

        req = MagicMock()
        req.user = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 512

        # The generator raises on first iteration, hence the drain loop.
        with pytest.raises(RequestError, match="below minimum"):
            async for _ in lib.stream_document(req):
                pass


# ---------------------------------------------------------------------------
# list_uploads
# ---------------------------------------------------------------------------

class TestListUploads:
    """list_uploads: enumeration of a user's in-flight upload sessions."""

    @pytest.mark.asyncio
    async def test_returns_sessions(self):
        lib = _make_librarian()
        lib.table_store.list_upload_sessions.return_value = [
            {
                "upload_id": "up-1",
                "document_id": "doc-1",
                "document_metadata": '{"id":"doc-1"}',
                "total_size": 10000,
                "chunk_size": 2000,
                "total_chunks": 5,
                "chunks_received": {0: "e1", 1: "e2"},
                "created_at": "2024-01-01",
            },
        ]

        req = MagicMock()
        req.user = "alice"

        resp = await lib.list_uploads(req)

        assert resp.error is None
        assert len(resp.upload_sessions) == 1
        assert resp.upload_sessions[0].upload_id == "up-1"
        assert resp.upload_sessions[0].total_chunks == 5

    @pytest.mark.asyncio
    async def test_empty_uploads(self):
        lib = _make_librarian()
        lib.table_store.list_upload_sessions.return_value = []

        req = MagicMock()
        req.user = "alice"

        resp = await lib.list_uploads(req)

        assert resp.upload_sessions == []


# === new file: tests/unit/test_provenance/__init__.py (empty) ===
# === new file: tests/unit/test_provenance/test_agent_provenance.py ===

"""
Tests for agent provenance triple builder functions.
"""

import json
import pytest

from trustgraph.schema import Triple, Term, IRI, LITERAL

from trustgraph.provenance.agent import (
    agent_session_triples,
    agent_iteration_triples,
    agent_final_triples,
)

from trustgraph.provenance.namespaces import (
    RDF_TYPE, RDFS_LABEL,
    PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME,
    TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
    TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
    TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
    TG_AGENT_QUESTION,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def find_triple(triples, predicate, subject=None):
    # Return the first triple matching predicate (and subject, if given),
    # else None.
    for t in triples:
        if t.p.iri == predicate:
            if subject is None or t.s.iri == subject:
                return t
    return None


def find_triples(triples, predicate, subject=None):
    # All triples matching predicate (and subject, if given).
    return [
        t for t in triples
        if t.p.iri == predicate and (subject is None or t.s.iri == subject)
    ]


def has_type(triples, subject, rdf_type):
    # True if (subject, rdf:type, rdf_type) appears with an IRI object.
    for t in triples:
        if (t.s.iri == subject and t.p.iri == RDF_TYPE
                and t.o.type == IRI and t.o.iri == rdf_type):
            return True
    return False


# ---------------------------------------------------------------------------
# agent_session_triples
# ---------------------------------------------------------------------------

class TestAgentSessionTriples:
    """agent_session_triples: PROV/TG typing of an agent question session."""

    SESSION_URI = "urn:trustgraph:agent:test-session"

    def test_session_types(self):
        triples = agent_session_triples(
            self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z"
        )
        assert has_type(triples, self.SESSION_URI, PROV_ACTIVITY)
        assert has_type(triples, self.SESSION_URI, TG_QUESTION)
        assert has_type(triples, self.SESSION_URI, TG_AGENT_QUESTION)

    def test_session_query_text(self):
        triples = agent_session_triples(
            self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z"
        )
        query = find_triple(triples, TG_QUERY, self.SESSION_URI)
        assert query is not None
        assert query.o.value == "What is X?"

    def test_session_timestamp(self):
        triples = agent_session_triples(
            self.SESSION_URI, "Q", "2024-06-15T10:00:00Z"
        )
        ts = find_triple(triples, PROV_STARTED_AT_TIME, self.SESSION_URI)
        assert ts is not None
        assert ts.o.value == "2024-06-15T10:00:00Z"

    def test_session_default_timestamp(self):
        # No timestamp argument → builder supplies one (non-empty).
        triples = agent_session_triples(self.SESSION_URI, "Q")
        ts = find_triple(triples, PROV_STARTED_AT_TIME, self.SESSION_URI)
        assert ts is not None
        assert len(ts.o.value) > 0

    def test_session_label(self):
        triples = agent_session_triples(
            self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
        )
        label = find_triple(triples, RDFS_LABEL, self.SESSION_URI)
        assert label is not None
        assert label.o.value == "Agent Question"

    def test_session_triple_count(self):
        # 3 types + query + timestamp + label = 6 triples.
        triples = agent_session_triples(
            self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
        )
        assert len(triples) == 6


# ---------------------------------------------------------------------------
# agent_iteration_triples
# ---------------------------------------------------------------------------

class TestAgentIterationTriples:
    """agent_iteration_triples: per-iteration thought/action/observation."""

    ITER_URI = "urn:trustgraph:agent:test-session/i1"
    PARENT_URI = "urn:trustgraph:agent:test-session"

    def test_iteration_types(self):
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            thought="thinking", action="search", observation="found it",
        )
        assert has_type(triples, self.ITER_URI, PROV_ENTITY)
        assert has_type(triples, self.ITER_URI, TG_ANALYSIS)

    def test_iteration_derived_from_parent(self):
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="search",
        )
        derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
        assert derived is not None
        assert derived.o.iri == self.PARENT_URI

    def test_iteration_label_includes_action(self):
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="graph-rag-query",
        )
        label = find_triple(triples, RDFS_LABEL, self.ITER_URI)
        assert label is not None
        assert "graph-rag-query" in label.o.value

    def test_iteration_thought_inline(self):
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            thought="I need to search for info",
            action="search",
        )
        thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
        assert thought is not None
        assert thought.o.value == "I need to search for info"

    def test_iteration_thought_document_preferred(self):
        """When thought_document_id is provided, inline thought is not stored."""
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            thought="inline thought",
            action="search",
            thought_document_id="urn:doc:thought-1",
        )
        thought_doc = find_triple(triples, TG_THOUGHT_DOCUMENT, self.ITER_URI)
        assert thought_doc is not None
        assert thought_doc.o.iri == "urn:doc:thought-1"
        thought_inline = find_triple(triples, TG_THOUGHT, self.ITER_URI)
        assert thought_inline is None

    def test_iteration_observation_inline(self):
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="search",
            observation="Found 3 results",
        )
        obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
        assert obs is not None
        assert obs.o.value == "Found 3 results"

    def test_iteration_observation_document_preferred(self):
        # Document reference wins over the inline observation text.
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="search",
            observation="inline obs",
            observation_document_id="urn:doc:obs-1",
        )
        obs_doc = find_triple(triples, TG_OBSERVATION_DOCUMENT, self.ITER_URI)
        assert obs_doc is not None
        assert obs_doc.o.iri == "urn:doc:obs-1"
        obs_inline = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
        assert obs_inline is None

    def test_iteration_action_recorded(self):
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="graph-rag-query",
        )
        action = find_triple(triples, TG_ACTION, self.ITER_URI)
        assert action is not None
        assert action.o.value == "graph-rag-query"

    def test_iteration_arguments_json_encoded(self):
        args = {"query": "test query", "limit": 10}
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="search",
            arguments=args,
        )
        arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI)
        assert arguments is not None
        # Arguments are stored as a JSON literal.
        parsed = json.loads(arguments.o.value)
        assert parsed == args

    def test_iteration_default_arguments_empty_dict(self):
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="search",
        )
        arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI)
        assert arguments is not None
        parsed = json.loads(arguments.o.value)
        assert parsed == {}

    def test_iteration_no_thought_or_observation(self):
        """Minimal iteration with just action — no thought or observation triples."""
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="noop",
        )
        thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
        obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
        assert thought is None
        assert obs is None

    def test_iteration_chaining(self):
        """Second iteration derives from first iteration, not session."""
        iter1_uri = "urn:trustgraph:agent:sess/i1"
        iter2_uri = "urn:trustgraph:agent:sess/i2"

        triples1 = agent_iteration_triples(
            iter1_uri, self.PARENT_URI, action="step1",
        )
        triples2 = agent_iteration_triples(
            iter2_uri, iter1_uri, action="step2",
        )

        derived1 = find_triple(triples1, PROV_WAS_DERIVED_FROM, iter1_uri)
        assert derived1.o.iri == self.PARENT_URI

        derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri)
        assert derived2.o.iri == iter1_uri


# ---------------------------------------------------------------------------
# agent_final_triples
# ---------------------------------------------------------------------------

class TestAgentFinalTriples:
    """agent_final_triples: the concluding answer entity of a session."""

    FINAL_URI = "urn:trustgraph:agent:test-session/final"
    PARENT_URI = "urn:trustgraph:agent:test-session/i3"

    def test_final_types(self):
        triples = agent_final_triples(
            self.FINAL_URI, self.PARENT_URI, answer="42"
        )
        assert has_type(triples, self.FINAL_URI, PROV_ENTITY)
        assert has_type(triples, self.FINAL_URI, TG_CONCLUSION)

    def test_final_derived_from_parent(self):
        triples = agent_final_triples(
            self.FINAL_URI, self.PARENT_URI, answer="42"
        )
        derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
        assert derived is not None
        assert derived.o.iri == self.PARENT_URI

    def test_final_label(self):
        triples = agent_final_triples(
            self.FINAL_URI, self.PARENT_URI, answer="42"
        )
        label = find_triple(triples, RDFS_LABEL, self.FINAL_URI)
        assert label is not None
        assert label.o.value == "Conclusion"

    def test_final_inline_answer(self):
        triples = agent_final_triples(
            self.FINAL_URI, self.PARENT_URI, answer="The answer is 42"
        )
        answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
        assert answer is not None
        assert answer.o.value == "The answer is 42"
    def test_final_document_reference(self):
        triples = agent_final_triples(
            self.FINAL_URI, self.PARENT_URI,
            document_id="urn:trustgraph:agent:sess/answer",
        )
        doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
        assert doc is not None
        assert doc.o.type == IRI
        assert doc.o.iri == "urn:trustgraph:agent:sess/answer"

    def test_final_document_takes_precedence(self):
        # Same inline-vs-document rule as iterations: document wins.
        triples = agent_final_triples(
            self.FINAL_URI, self.PARENT_URI,
            answer="inline",
            document_id="urn:doc:123",
        )
        doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
        assert doc is not None
        answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
        assert answer is None

    def test_final_no_answer_or_document(self):
        triples = agent_final_triples(self.FINAL_URI, self.PARENT_URI)
        answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
        doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
        assert answer is None
        assert doc is None

    def test_final_derives_from_session_when_no_iterations(self):
        """When agent answers immediately, final derives from session."""
        session_uri = "urn:trustgraph:agent:test-session"
        triples = agent_final_triples(
            self.FINAL_URI, session_uri, answer="direct answer"
        )
        derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
        assert derived.o.iri == session_uri


# === new file: tests/unit/test_provenance/test_explainability.py ===

"""
Tests for the explainability API (entity parsing, wire format conversion,
and ExplainabilityClient).
"""

import pytest
from unittest.mock import MagicMock, patch

from trustgraph.api.explainability import (
    EdgeSelection,
    ExplainEntity,
    Question,
    Exploration,
    Focus,
    Synthesis,
    Analysis,
    Conclusion,
    parse_edge_selection_triples,
    extract_term_value,
    wire_triples_to_tuples,
    ExplainabilityClient,
    TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
    TG_CONTENT, TG_DOCUMENT, TG_CHUNK_COUNT,
    TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
    TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
    TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
    TG_ANALYSIS, TG_CONCLUSION,
    TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
    PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
    RDF_TYPE, RDFS_LABEL,
)


# ---------------------------------------------------------------------------
# Entity from_triples parsing
# ---------------------------------------------------------------------------

class TestExplainEntityFromTriples:
    """Test ExplainEntity.from_triples dispatches to correct subclass."""

    def test_graphrag_question(self):
        triples = [
            ("urn:q:1", RDF_TYPE, TG_QUESTION),
            ("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
            ("urn:q:1", TG_QUERY, "What is AI?"),
            ("urn:q:1", PROV_STARTED_AT_TIME, "2024-01-01T00:00:00Z"),
        ]
        entity = ExplainEntity.from_triples("urn:q:1", triples)
        assert isinstance(entity, Question)
        assert entity.query == "What is AI?"
        assert entity.timestamp == "2024-01-01T00:00:00Z"
        assert entity.question_type == "graph-rag"

    def test_docrag_question(self):
        triples = [
            ("urn:q:2", RDF_TYPE, TG_QUESTION),
            ("urn:q:2", RDF_TYPE, TG_DOC_RAG_QUESTION),
            ("urn:q:2", TG_QUERY, "Find info"),
        ]
        entity = ExplainEntity.from_triples("urn:q:2", triples)
        assert isinstance(entity, Question)
        assert entity.question_type == "document-rag"

    def test_agent_question(self):
        triples = [
            ("urn:q:3", RDF_TYPE, TG_QUESTION),
            ("urn:q:3", RDF_TYPE, TG_AGENT_QUESTION),
            ("urn:q:3", TG_QUERY, "Agent query"),
        ]
        entity = ExplainEntity.from_triples("urn:q:3", triples)
        assert isinstance(entity, Question)
        assert entity.question_type == "agent"

    def test_exploration(self):
        triples = [
            ("urn:exp:1", RDF_TYPE, TG_EXPLORATION),
            ("urn:exp:1", TG_EDGE_COUNT, "15"),
        ]
        entity = ExplainEntity.from_triples("urn:exp:1", triples)
        assert isinstance(entity, Exploration)
        # Count literals arrive as strings and are parsed to ints.
        assert entity.edge_count == 15

    def test_exploration_with_chunk_count(self):
        triples = [
            ("urn:exp:2", RDF_TYPE, TG_EXPLORATION),
            ("urn:exp:2", TG_CHUNK_COUNT, "5"),
        ]
        entity = ExplainEntity.from_triples("urn:exp:2", triples)
        assert isinstance(entity, Exploration)
        assert entity.chunk_count == 5

    def test_exploration_invalid_count(self):
        # Unparseable counts fall back to 0 rather than raising.
        triples = [
            ("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
            ("urn:exp:3", TG_EDGE_COUNT, "not-a-number"),
        ]
        entity = ExplainEntity.from_triples("urn:exp:3", triples)
        assert isinstance(entity, Exploration)
        assert entity.edge_count == 0

    def test_focus(self):
        triples = [
            ("urn:foc:1", RDF_TYPE, TG_FOCUS),
            ("urn:foc:1", TG_SELECTED_EDGE, "urn:edge:1"),
            ("urn:foc:1", TG_SELECTED_EDGE, "urn:edge:2"),
        ]
        entity = ExplainEntity.from_triples("urn:foc:1", triples)
        assert isinstance(entity, Focus)
        assert len(entity.selected_edge_uris) == 2
        assert "urn:edge:1" in entity.selected_edge_uris
        assert "urn:edge:2" in entity.selected_edge_uris

    def test_synthesis_with_content(self):
        triples = [
            ("urn:syn:1", RDF_TYPE, TG_SYNTHESIS),
            ("urn:syn:1", TG_CONTENT, "The answer is 42"),
        ]
        entity = ExplainEntity.from_triples("urn:syn:1", triples)
        assert isinstance(entity, Synthesis)
        assert entity.content == "The answer is 42"
        assert entity.document_uri == ""

    def test_synthesis_with_document(self):
        triples = [
            ("urn:syn:2", RDF_TYPE, TG_SYNTHESIS),
            ("urn:syn:2", TG_DOCUMENT, "urn:doc:answer-1"),
        ]
        entity = ExplainEntity.from_triples("urn:syn:2", triples)
        assert isinstance(entity, Synthesis)
        assert entity.document_uri == "urn:doc:answer-1"

    def test_analysis(self):
        triples = [
            ("urn:ana:1", RDF_TYPE, TG_ANALYSIS),
            ("urn:ana:1", TG_THOUGHT, "I should search"),
            ("urn:ana:1", TG_ACTION, "graph-rag-query"),
            ("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'),
            ("urn:ana:1", TG_OBSERVATION, "Found results"),
        ]
        entity = ExplainEntity.from_triples("urn:ana:1", triples)
        assert isinstance(entity, Analysis)
        assert entity.thought == "I should search"
        assert entity.action == "graph-rag-query"
        # Arguments stay as the raw JSON string; callers decode as needed.
        assert entity.arguments == '{"query": "test"}'
        assert entity.observation == "Found results"

    def test_analysis_with_document_refs(self):
        triples = [
            ("urn:ana:2", RDF_TYPE, TG_ANALYSIS),
            ("urn:ana:2", TG_ACTION, "search"),
            ("urn:ana:2", TG_THOUGHT_DOCUMENT, "urn:doc:thought-1"),
            ("urn:ana:2", TG_OBSERVATION_DOCUMENT, "urn:doc:obs-1"),
        ]
        entity = ExplainEntity.from_triples("urn:ana:2", triples)
        assert isinstance(entity, Analysis)
        assert entity.thought_document_uri == "urn:doc:thought-1"
        assert entity.observation_document_uri == "urn:doc:obs-1"

    def test_conclusion_with_answer(self):
        triples = [
            ("urn:conc:1", RDF_TYPE, TG_CONCLUSION),
            ("urn:conc:1", TG_ANSWER, "The final answer"),
        ]
        entity = ExplainEntity.from_triples("urn:conc:1", triples)
        assert isinstance(entity, Conclusion)
        assert entity.answer == "The final answer"
test_conclusion_with_document(self): + triples = [ + ("urn:conc:2", RDF_TYPE, TG_CONCLUSION), + ("urn:conc:2", TG_DOCUMENT, "urn:doc:final"), + ] + entity = ExplainEntity.from_triples("urn:conc:2", triples) + assert isinstance(entity, Conclusion) + assert entity.document_uri == "urn:doc:final" + + def test_unknown_type(self): + triples = [ + ("urn:x:1", RDF_TYPE, "http://example.com/UnknownType"), + ] + entity = ExplainEntity.from_triples("urn:x:1", triples) + assert isinstance(entity, ExplainEntity) + assert entity.entity_type == "unknown" + + +# --------------------------------------------------------------------------- +# parse_edge_selection_triples +# --------------------------------------------------------------------------- + +class TestParseEdgeSelectionTriples: + + def test_with_edge_and_reasoning(self): + triples = [ + ("urn:edge:1", TG_EDGE, {"s": "Alice", "p": "knows", "o": "Bob"}), + ("urn:edge:1", TG_REASONING, "Alice and Bob are connected"), + ] + result = parse_edge_selection_triples(triples) + assert isinstance(result, EdgeSelection) + assert result.uri == "urn:edge:1" + assert result.edge == {"s": "Alice", "p": "knows", "o": "Bob"} + assert result.reasoning == "Alice and Bob are connected" + + def test_with_edge_only(self): + triples = [ + ("urn:edge:2", TG_EDGE, {"s": "A", "p": "r", "o": "B"}), + ] + result = parse_edge_selection_triples(triples) + assert result.edge is not None + assert result.reasoning == "" + + def test_with_reasoning_only(self): + triples = [ + ("urn:edge:3", TG_REASONING, "some reason"), + ] + result = parse_edge_selection_triples(triples) + assert result.edge is None + assert result.reasoning == "some reason" + + def test_empty_triples(self): + result = parse_edge_selection_triples([]) + assert result.uri == "" + assert result.edge is None + assert result.reasoning == "" + + def test_edge_must_be_dict(self): + """Non-dict values for TG_EDGE should not be treated as edges.""" + triples = [ + ("urn:edge:4", TG_EDGE, 
"not-a-dict"), + ] + result = parse_edge_selection_triples(triples) + assert result.edge is None + + +# --------------------------------------------------------------------------- +# extract_term_value +# --------------------------------------------------------------------------- + +class TestExtractTermValue: + + def test_iri_short_format(self): + assert extract_term_value({"t": "i", "i": "urn:test"}) == "urn:test" + + def test_iri_long_format(self): + assert extract_term_value({"type": "i", "iri": "urn:test"}) == "urn:test" + + def test_literal_short_format(self): + assert extract_term_value({"t": "l", "v": "hello"}) == "hello" + + def test_literal_long_format(self): + assert extract_term_value({"type": "l", "value": "hello"}) == "hello" + + def test_quoted_triple(self): + term = { + "t": "t", + "tr": { + "s": {"t": "i", "i": "urn:s"}, + "p": {"t": "i", "i": "urn:p"}, + "o": {"t": "i", "i": "urn:o"}, + } + } + result = extract_term_value(term) + assert result == {"s": "urn:s", "p": "urn:p", "o": "urn:o"} + + def test_quoted_triple_long_format(self): + term = { + "type": "t", + "triple": { + "s": {"type": "i", "iri": "urn:s"}, + "p": {"type": "i", "iri": "urn:p"}, + "o": {"type": "l", "value": "val"}, + } + } + result = extract_term_value(term) + assert result == {"s": "urn:s", "p": "urn:p", "o": "val"} + + def test_unknown_type_fallback(self): + result = extract_term_value({"t": "x", "i": "urn:fallback"}) + assert result == "urn:fallback" + + +# --------------------------------------------------------------------------- +# wire_triples_to_tuples +# --------------------------------------------------------------------------- + +class TestWireTriplesToTuples: + + def test_basic_conversion(self): + wire = [ + { + "s": {"t": "i", "i": "urn:s1"}, + "p": {"t": "i", "i": "urn:p1"}, + "o": {"t": "l", "v": "value1"}, + }, + ] + result = wire_triples_to_tuples(wire) + assert len(result) == 1 + assert result[0] == ("urn:s1", "urn:p1", "value1") + + def 
test_multiple_triples(self): + wire = [ + { + "s": {"t": "i", "i": "urn:s1"}, + "p": {"t": "i", "i": "urn:p1"}, + "o": {"t": "l", "v": "v1"}, + }, + { + "s": {"t": "i", "i": "urn:s2"}, + "p": {"t": "i", "i": "urn:p2"}, + "o": {"t": "i", "i": "urn:o2"}, + }, + ] + result = wire_triples_to_tuples(wire) + assert len(result) == 2 + assert result[0] == ("urn:s1", "urn:p1", "v1") + assert result[1] == ("urn:s2", "urn:p2", "urn:o2") + + def test_empty_list(self): + assert wire_triples_to_tuples([]) == [] + + def test_missing_fields(self): + wire = [{"s": {}, "p": {}, "o": {}}] + result = wire_triples_to_tuples(wire) + assert len(result) == 1 + + +# --------------------------------------------------------------------------- +# ExplainabilityClient +# --------------------------------------------------------------------------- + +def _make_wire_triples(tuples): + """Convert (s, p, o) tuples to wire format for mocking.""" + result = [] + for s, p, o in tuples: + entry = { + "s": {"t": "i", "i": s}, + "p": {"t": "i", "i": p}, + } + if o.startswith("urn:") or o.startswith("http"): + entry["o"] = {"t": "i", "i": o} + else: + entry["o"] = {"t": "l", "v": o} + result.append(entry) + return result + + +class TestExplainabilityClientFetchEntity: + + def test_fetch_question_entity(self): + wire = _make_wire_triples([ + ("urn:q:1", RDF_TYPE, TG_QUESTION), + ("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION), + ("urn:q:1", TG_QUERY, "What is AI?"), + ("urn:q:1", PROV_STARTED_AT_TIME, "2024-01-01T00:00:00Z"), + ]) + + mock_flow = MagicMock() + # Return same results twice for quiescence + mock_flow.triples_query.side_effect = [wire, wire] + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + entity = client.fetch_entity("urn:q:1", graph="urn:graph:retrieval") + + assert isinstance(entity, Question) + assert entity.query == "What is AI?" 
+ assert entity.question_type == "graph-rag" + + def test_fetch_returns_none_when_no_data(self): + mock_flow = MagicMock() + mock_flow.triples_query.return_value = [] + + client = ExplainabilityClient(mock_flow, retry_delay=0.0, max_retries=2) + entity = client.fetch_entity("urn:nonexistent") + + assert entity is None + + def test_fetch_retries_on_empty_results(self): + wire = _make_wire_triples([ + ("urn:q:1", RDF_TYPE, TG_QUESTION), + ("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION), + ("urn:q:1", TG_QUERY, "Q"), + ]) + + mock_flow = MagicMock() + # Empty, then data, then same data (stable) + mock_flow.triples_query.side_effect = [[], wire, wire] + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + entity = client.fetch_entity("urn:q:1") + + assert isinstance(entity, Question) + assert mock_flow.triples_query.call_count == 3 + + +class TestExplainabilityClientResolveLabel: + + def test_resolve_label_found(self): + mock_flow = MagicMock() + mock_flow.triples_query.return_value = _make_wire_triples([ + ("urn:entity:1", RDFS_LABEL, "Entity One"), + ]) + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + label = client.resolve_label("urn:entity:1") + assert label == "Entity One" + + def test_resolve_label_not_found(self): + mock_flow = MagicMock() + mock_flow.triples_query.return_value = [] + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + label = client.resolve_label("urn:entity:1") + assert label == "urn:entity:1" + + def test_resolve_label_cached(self): + mock_flow = MagicMock() + mock_flow.triples_query.return_value = _make_wire_triples([ + ("urn:entity:1", RDFS_LABEL, "Entity One"), + ]) + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + client.resolve_label("urn:entity:1") + client.resolve_label("urn:entity:1") + + # Only one query should be made + assert mock_flow.triples_query.call_count == 1 + + def test_resolve_label_non_uri(self): + mock_flow = MagicMock() + client = ExplainabilityClient(mock_flow, 
retry_delay=0.0) + assert client.resolve_label("plain text") == "plain text" + assert client.resolve_label("") == "" + mock_flow.triples_query.assert_not_called() + + def test_resolve_edge_labels(self): + mock_flow = MagicMock() + + def mock_query(s=None, p=None, **kwargs): + labels = { + "urn:e:Alice": "Alice", + "urn:r:knows": "knows", + "urn:e:Bob": "Bob", + } + if s in labels: + return _make_wire_triples([(s, RDFS_LABEL, labels[s])]) + return [] + + mock_flow.triples_query.side_effect = mock_query + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + s, p, o = client.resolve_edge_labels( + {"s": "urn:e:Alice", "p": "urn:r:knows", "o": "urn:e:Bob"} + ) + assert s == "Alice" + assert p == "knows" + assert o == "Bob" + + +class TestExplainabilityClientContentFetching: + + def test_fetch_synthesis_inline_content(self): + mock_flow = MagicMock() + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + + synthesis = Synthesis(uri="urn:syn:1", content="inline answer") + result = client.fetch_synthesis_content(synthesis, api=None) + assert result == "inline answer" + + def test_fetch_synthesis_truncated_content(self): + mock_flow = MagicMock() + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + + long_content = "x" * 20000 + synthesis = Synthesis(uri="urn:syn:1", content=long_content) + result = client.fetch_synthesis_content(synthesis, api=None, max_content=100) + assert len(result) < 20000 + assert result.endswith("... 
[truncated]") + + def test_fetch_synthesis_from_librarian(self): + mock_flow = MagicMock() + mock_api = MagicMock() + mock_library = MagicMock() + mock_api.library.return_value = mock_library + mock_library.get_document_content.return_value = b"librarian content" + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + synthesis = Synthesis( + uri="urn:syn:1", + document_uri="urn:document:abc123" + ) + result = client.fetch_synthesis_content(synthesis, api=mock_api) + assert result == "librarian content" + + def test_fetch_synthesis_no_content_or_document(self): + mock_flow = MagicMock() + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + + synthesis = Synthesis(uri="urn:syn:1") + result = client.fetch_synthesis_content(synthesis, api=None) + assert result == "" + + def test_fetch_conclusion_inline(self): + mock_flow = MagicMock() + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + + conclusion = Conclusion(uri="urn:conc:1", answer="42") + result = client.fetch_conclusion_content(conclusion, api=None) + assert result == "42" + + def test_fetch_analysis_content_from_librarian(self): + mock_flow = MagicMock() + mock_api = MagicMock() + mock_library = MagicMock() + mock_api.library.return_value = mock_library + mock_library.get_document_content.side_effect = [ + b"thought content", + b"observation content", + ] + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + analysis = Analysis( + uri="urn:ana:1", + action="search", + thought_document_uri="urn:doc:thought", + observation_document_uri="urn:doc:obs", + ) + client.fetch_analysis_content(analysis, api=mock_api) + assert analysis.thought == "thought content" + assert analysis.observation == "observation content" + + def test_fetch_analysis_skips_when_inline_exists(self): + mock_flow = MagicMock() + mock_api = MagicMock() + + client = ExplainabilityClient(mock_flow, retry_delay=0.0) + analysis = Analysis( + uri="urn:ana:1", + action="search", + thought="already have thought", + 
            observation="already have observation",
            thought_document_uri="urn:doc:thought",
            observation_document_uri="urn:doc:obs",
        )
        client.fetch_analysis_content(analysis, api=mock_api)
        # Should not call librarian since inline content exists
        mock_api.library.assert_not_called()


class TestExplainabilityClientDetectSessionType:
    """detect_session_type should infer the session kind from the URI prefix."""

    def test_detect_agent_from_uri(self):
        mock_flow = MagicMock()
        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
        assert client.detect_session_type("urn:trustgraph:agent:abc") == "agent"

    def test_detect_graphrag_from_uri(self):
        mock_flow = MagicMock()
        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
        assert client.detect_session_type("urn:trustgraph:question:abc") == "graphrag"

    def test_detect_docrag_from_uri(self):
        mock_flow = MagicMock()
        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
        assert client.detect_session_type("urn:trustgraph:docrag:abc") == "docrag"
diff --git a/tests/unit/test_provenance/test_triples.py b/tests/unit/test_provenance/test_triples.py
new file mode 100644
index 00000000..91074097
--- /dev/null
+++ b/tests/unit/test_provenance/test_triples.py
@@ -0,0 +1,785 @@
"""
Tests for provenance triple builder functions (extraction-time and query-time).
"""

import pytest
from unittest.mock import patch

from trustgraph.schema import Triple, Term, IRI, LITERAL, TRIPLE

from trustgraph.provenance.triples import (
    set_graph,
    document_triples,
    derived_entity_triples,
    subgraph_provenance_triples,
    question_triples,
    exploration_triples,
    focus_triples,
    synthesis_triples,
    docrag_question_triples,
    docrag_exploration_triples,
    docrag_synthesis_triples,
)

from trustgraph.provenance.namespaces import (
    RDF_TYPE, RDFS_LABEL,
    PROV_ENTITY, PROV_ACTIVITY, PROV_AGENT,
    PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
    PROV_USED, PROV_WAS_ASSOCIATED_WITH, PROV_STARTED_AT_TIME,
    DC_TITLE, DC_SOURCE, DC_DATE, DC_CREATOR,
    TG_PAGE_COUNT, TG_MIME_TYPE, TG_PAGE_NUMBER,
    TG_CHUNK_INDEX, TG_CHAR_OFFSET, TG_CHAR_LENGTH,
    TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
    TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS,
    TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
    TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
    TG_CONTENT, TG_DOCUMENT,
    TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
    TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
    TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
    GRAPH_SOURCE, GRAPH_RETRIEVAL,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def find_triple(triples, predicate, subject=None):
    """Find first triple matching predicate (and optionally subject)."""
    for t in triples:
        if t.p.iri == predicate:
            if subject is None or t.s.iri == subject:
                return t
    return None


def find_triples(triples, predicate, subject=None):
    """Find all triples matching predicate (and optionally subject)."""
    return [
        t for t in triples
        if t.p.iri == predicate and (subject is None or t.s.iri == subject)
    ]


def has_type(triples, subject, rdf_type):
    """Check if subject has rdf:type rdf_type."""
    for t in triples:
(t.s.iri == subject and t.p.iri == RDF_TYPE + and t.o.type == IRI and t.o.iri == rdf_type): + return True + return False + + +# --------------------------------------------------------------------------- +# set_graph +# --------------------------------------------------------------------------- + +class TestSetGraph: + + def test_sets_graph_on_all_triples(self): + triples = [ + Triple( + s=Term(type=IRI, iri="urn:s1"), + p=Term(type=IRI, iri="urn:p1"), + o=Term(type=LITERAL, value="v1"), + ), + Triple( + s=Term(type=IRI, iri="urn:s2"), + p=Term(type=IRI, iri="urn:p2"), + o=Term(type=LITERAL, value="v2"), + ), + ] + result = set_graph(triples, GRAPH_RETRIEVAL) + assert len(result) == 2 + for t in result: + assert t.g == GRAPH_RETRIEVAL + + def test_does_not_modify_originals(self): + original = Triple( + s=Term(type=IRI, iri="urn:s"), + p=Term(type=IRI, iri="urn:p"), + o=Term(type=LITERAL, value="v"), + ) + result = set_graph([original], "urn:graph:test") + assert original.g is None + assert result[0].g == "urn:graph:test" + + def test_empty_list(self): + result = set_graph([], GRAPH_SOURCE) + assert result == [] + + def test_preserves_spo(self): + original = Triple( + s=Term(type=IRI, iri="urn:s"), + p=Term(type=IRI, iri="urn:p"), + o=Term(type=LITERAL, value="hello"), + ) + result = set_graph([original], "urn:g")[0] + assert result.s.iri == "urn:s" + assert result.p.iri == "urn:p" + assert result.o.value == "hello" + + +# --------------------------------------------------------------------------- +# document_triples +# --------------------------------------------------------------------------- + +class TestDocumentTriples: + + DOC_URI = "https://example.com/doc/abc" + + def test_minimal_document(self): + triples = document_triples(self.DOC_URI) + assert has_type(triples, self.DOC_URI, PROV_ENTITY) + assert has_type(triples, self.DOC_URI, TG_DOCUMENT_TYPE) + assert len(triples) == 2 + + def test_with_title(self): + triples = document_triples(self.DOC_URI, title="My 
Doc") + title_t = find_triple(triples, DC_TITLE) + assert title_t is not None + assert title_t.o.value == "My Doc" + # Title also creates an rdfs:label + label_t = find_triple(triples, RDFS_LABEL) + assert label_t is not None + assert label_t.o.value == "My Doc" + + def test_with_source(self): + triples = document_triples(self.DOC_URI, source="https://source.com/f.pdf") + source_t = find_triple(triples, DC_SOURCE) + assert source_t is not None + assert source_t.o.type == IRI + assert source_t.o.iri == "https://source.com/f.pdf" + + def test_with_date(self): + triples = document_triples(self.DOC_URI, date="2024-01-15") + date_t = find_triple(triples, DC_DATE) + assert date_t is not None + assert date_t.o.value == "2024-01-15" + + def test_with_creator(self): + triples = document_triples(self.DOC_URI, creator="Alice") + creator_t = find_triple(triples, DC_CREATOR) + assert creator_t is not None + assert creator_t.o.value == "Alice" + + def test_with_page_count(self): + triples = document_triples(self.DOC_URI, page_count=42) + pc_t = find_triple(triples, TG_PAGE_COUNT) + assert pc_t is not None + assert pc_t.o.value == "42" + + def test_with_page_count_zero(self): + triples = document_triples(self.DOC_URI, page_count=0) + pc_t = find_triple(triples, TG_PAGE_COUNT) + assert pc_t is not None + assert pc_t.o.value == "0" + + def test_with_mime_type(self): + triples = document_triples(self.DOC_URI, mime_type="application/pdf") + mt_t = find_triple(triples, TG_MIME_TYPE) + assert mt_t is not None + assert mt_t.o.value == "application/pdf" + + def test_all_metadata(self): + triples = document_triples( + self.DOC_URI, + title="Test", + source="https://s.com", + date="2024-01-01", + creator="Bob", + page_count=10, + mime_type="application/pdf", + ) + # 2 type triples + title + label + source + date + creator + page_count + mime_type + assert len(triples) == 9 + + def test_subject_is_doc_uri(self): + triples = document_triples(self.DOC_URI, title="T") + for t in triples: + 
            assert t.s.iri == self.DOC_URI


# ---------------------------------------------------------------------------
# derived_entity_triples
# ---------------------------------------------------------------------------

class TestDerivedEntityTriples:
    """derived_entity_triples: page/chunk entities plus PROV activity/agent."""

    ENTITY_URI = "https://example.com/doc/abc/p1"
    PARENT_URI = "https://example.com/doc/abc"

    def test_page_entity_has_page_type(self):
        triples = derived_entity_triples(
            self.ENTITY_URI, self.PARENT_URI,
            "pdf-extractor", "1.0",
            page_number=1,
            timestamp="2024-01-01T00:00:00Z",
        )
        assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
        assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)

    def test_chunk_entity_has_chunk_type(self):
        triples = derived_entity_triples(
            self.ENTITY_URI, self.PARENT_URI,
            "chunker", "1.0",
            chunk_index=0,
            timestamp="2024-01-01T00:00:00Z",
        )
        assert has_type(triples, self.ENTITY_URI, TG_CHUNK_TYPE)

    def test_no_specific_type_without_page_or_chunk(self):
        triples = derived_entity_triples(
            self.ENTITY_URI, self.PARENT_URI,
            "component", "1.0",
            timestamp="2024-01-01T00:00:00Z",
        )
        assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
        assert not has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
        assert not has_type(triples, self.ENTITY_URI, TG_CHUNK_TYPE)

    def test_was_derived_from_parent(self):
        triples = derived_entity_triples(
            self.ENTITY_URI, self.PARENT_URI,
            "pdf-extractor", "1.0",
            timestamp="2024-01-01T00:00:00Z",
        )
        derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ENTITY_URI)
        assert derived is not None
        assert derived.o.iri == self.PARENT_URI

    def test_activity_created(self):
        triples = derived_entity_triples(
            self.ENTITY_URI, self.PARENT_URI,
            "pdf-extractor", "1.0",
            timestamp="2024-01-01T00:00:00Z",
        )
        # Entity was generated by an activity
        gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ENTITY_URI)
        assert gen is not None
        act_uri = gen.o.iri

        # Activity has correct type and metadata
        assert has_type(triples, act_uri, PROV_ACTIVITY)

        # Activity used the parent
        used = find_triple(triples, PROV_USED, act_uri)
        assert used is not None
        assert used.o.iri == self.PARENT_URI

        # Activity has component version
        version = find_triple(triples, TG_COMPONENT_VERSION, act_uri)
        assert version is not None
        assert version.o.value == "1.0"

    def test_agent_created(self):
        triples = derived_entity_triples(
            self.ENTITY_URI, self.PARENT_URI,
            "pdf-extractor", "1.0",
            timestamp="2024-01-01T00:00:00Z",
        )
        # Find the agent URI via wasAssociatedWith
        gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ENTITY_URI)
        act_uri = gen.o.iri
        assoc = find_triple(triples, PROV_WAS_ASSOCIATED_WITH, act_uri)
        assert assoc is not None
        agt_uri = assoc.o.iri

        assert has_type(triples, agt_uri, PROV_AGENT)
        label = find_triple(triples, RDFS_LABEL, agt_uri)
        assert label is not None
        assert label.o.value == "pdf-extractor"

    def test_timestamp_recorded(self):
        triples = derived_entity_triples(
            self.ENTITY_URI, self.PARENT_URI,
            "pdf-extractor", "1.0",
            timestamp="2024-06-15T12:30:00Z",
        )
        ts = find_triple(triples, PROV_STARTED_AT_TIME)
        assert ts is not None
        assert ts.o.value == "2024-06-15T12:30:00Z"

    def test_default_timestamp_generated(self):
        triples = derived_entity_triples(
            self.ENTITY_URI, self.PARENT_URI,
            "pdf-extractor", "1.0",
        )
        ts = find_triple(triples, PROV_STARTED_AT_TIME)
        assert ts is not None
        assert len(ts.o.value) > 0

    def test_optional_label(self):
        triples = derived_entity_triples(
            self.ENTITY_URI, self.PARENT_URI,
            "pdf-extractor", "1.0",
            label="Page 1",
            timestamp="2024-01-01T00:00:00Z",
        )
        label = find_triple(triples, RDFS_LABEL, self.ENTITY_URI)
        assert label is not None
        assert label.o.value == "Page 1"

    def test_page_number_recorded(self):
        triples = derived_entity_triples(
            self.ENTITY_URI, self.PARENT_URI,
            "pdf-extractor", "1.0",
            page_number=3,
            timestamp="2024-01-01T00:00:00Z",
        )
        pn = find_triple(triples, TG_PAGE_NUMBER, self.ENTITY_URI)
        assert pn is not None
        assert pn.o.value == "3"

    def test_chunk_metadata_recorded(self):
        triples = derived_entity_triples(
            self.ENTITY_URI, self.PARENT_URI,
            "chunker", "2.0",
            chunk_index=5,
            char_offset=1000,
            char_length=500,
            chunk_size=512,
            chunk_overlap=64,
            timestamp="2024-01-01T00:00:00Z",
        )
        ci = find_triple(triples, TG_CHUNK_INDEX, self.ENTITY_URI)
        assert ci is not None and ci.o.value == "5"

        co = find_triple(triples, TG_CHAR_OFFSET, self.ENTITY_URI)
        assert co is not None and co.o.value == "1000"

        cl = find_triple(triples, TG_CHAR_LENGTH, self.ENTITY_URI)
        assert cl is not None and cl.o.value == "500"

        # chunk_size and chunk_overlap are on the activity, not the entity
        cs = find_triple(triples, TG_CHUNK_SIZE)
        assert cs is not None and cs.o.value == "512"

        ov = find_triple(triples, TG_CHUNK_OVERLAP)
        assert ov is not None and ov.o.value == "64"


# ---------------------------------------------------------------------------
# subgraph_provenance_triples
# ---------------------------------------------------------------------------

class TestSubgraphProvenanceTriples:
    """subgraph_provenance_triples: quoted triples plus PROV metadata."""

    SG_URI = "https://trustgraph.ai/subgraph/test-sg"
    CHUNK_URI = "https://example.com/doc/abc/p1/c0"

    def _make_extracted_triple(self, s="urn:e:Alice", p="urn:r:knows", o="urn:e:Bob"):
        # Helper to build one plain extracted triple for the subgraph.
        return Triple(
            s=Term(type=IRI, iri=s),
            p=Term(type=IRI, iri=p),
            o=Term(type=IRI, iri=o),
        )

    def test_contains_quoted_triples(self):
        extracted = [self._make_extracted_triple()]
        triples = subgraph_provenance_triples(
            self.SG_URI, extracted, self.CHUNK_URI,
            "kg-extractor", "1.0",
            timestamp="2024-01-01T00:00:00Z",
        )
        contains = find_triples(triples, TG_CONTAINS, self.SG_URI)
        assert len(contains) == 1
        assert contains[0].o.type == TRIPLE
        assert contains[0].o.triple.s.iri == "urn:e:Alice"
        assert contains[0].o.triple.p.iri == "urn:r:knows"
        assert contains[0].o.triple.o.iri == "urn:e:Bob"

    def test_multiple_extracted_triples(self):
        extracted = [
            self._make_extracted_triple("urn:e:A", "urn:r:x", "urn:e:B"),
            self._make_extracted_triple("urn:e:C", "urn:r:y", "urn:e:D"),
            self._make_extracted_triple("urn:e:E", "urn:r:z", "urn:e:F"),
        ]
        triples = subgraph_provenance_triples(
            self.SG_URI, extracted, self.CHUNK_URI,
            "kg-extractor", "1.0",
            timestamp="2024-01-01T00:00:00Z",
        )
        contains = find_triples(triples, TG_CONTAINS, self.SG_URI)
        assert len(contains) == 3

    def test_empty_extracted_triples(self):
        triples = subgraph_provenance_triples(
            self.SG_URI, [], self.CHUNK_URI,
            "kg-extractor", "1.0",
            timestamp="2024-01-01T00:00:00Z",
        )
        contains = find_triples(triples, TG_CONTAINS, self.SG_URI)
        assert len(contains) == 0
        # Should still have subgraph provenance metadata
        assert has_type(triples, self.SG_URI, TG_SUBGRAPH_TYPE)

    def test_subgraph_has_correct_types(self):
        triples = subgraph_provenance_triples(
            self.SG_URI, [], self.CHUNK_URI,
            "kg-extractor", "1.0",
            timestamp="2024-01-01T00:00:00Z",
        )
        assert has_type(triples, self.SG_URI, PROV_ENTITY)
        assert has_type(triples, self.SG_URI, TG_SUBGRAPH_TYPE)

    def test_derived_from_chunk(self):
        triples = subgraph_provenance_triples(
            self.SG_URI, [], self.CHUNK_URI,
            "kg-extractor", "1.0",
            timestamp="2024-01-01T00:00:00Z",
        )
        derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SG_URI)
        assert derived is not None
        assert derived.o.iri == self.CHUNK_URI

    def test_activity_and_agent(self):
        triples = subgraph_provenance_triples(
            self.SG_URI, [], self.CHUNK_URI,
            "kg-extractor", "1.0",
            timestamp="2024-01-01T00:00:00Z",
        )
        gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.SG_URI)
        assert gen is not None
        act_uri = gen.o.iri

        assert has_type(triples, act_uri, PROV_ACTIVITY)

        used = find_triple(triples, PROV_USED, act_uri)
        assert used is not None
        assert used.o.iri == self.CHUNK_URI

        version = find_triple(triples, TG_COMPONENT_VERSION, act_uri)
        assert version is not None
        assert version.o.value == "1.0"

    def test_optional_llm_model(self):
        triples = subgraph_provenance_triples(
            self.SG_URI, [], self.CHUNK_URI,
            "kg-extractor", "1.0",
            llm_model="claude-3-opus",
            timestamp="2024-01-01T00:00:00Z",
        )
        llm = find_triple(triples, TG_LLM_MODEL)
        assert llm is not None
        assert llm.o.value == "claude-3-opus"

    def test_no_llm_model_when_omitted(self):
        triples = subgraph_provenance_triples(
            self.SG_URI, [], self.CHUNK_URI,
            "kg-extractor", "1.0",
            timestamp="2024-01-01T00:00:00Z",
        )
        llm = find_triple(triples, TG_LLM_MODEL)
        assert llm is None

    def test_optional_ontology(self):
        triples = subgraph_provenance_triples(
            self.SG_URI, [], self.CHUNK_URI,
            "kg-extractor", "1.0",
            ontology_uri="https://example.com/ontology/v1",
            timestamp="2024-01-01T00:00:00Z",
        )
        ont = find_triple(triples, TG_ONTOLOGY)
        assert ont is not None
        assert ont.o.type == IRI
        assert ont.o.iri == "https://example.com/ontology/v1"


# ---------------------------------------------------------------------------
# GraphRAG query-time triples
# ---------------------------------------------------------------------------

class TestQuestionTriples:
    """question_triples: GraphRAG question activity with query/timestamp/label."""

    Q_URI = "urn:trustgraph:question:test-session"

    def test_question_types(self):
        triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z")
        assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
        assert has_type(triples, self.Q_URI, TG_QUESTION)
        assert has_type(triples, self.Q_URI, TG_GRAPH_RAG_QUESTION)

    def test_question_query_text(self):
        triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z")
        query = find_triple(triples, TG_QUERY, self.Q_URI)
        assert query is not None
        assert query.o.value == "What is AI?"

    def test_question_timestamp(self):
        triples = question_triples(self.Q_URI, "Q", "2024-06-15T10:00:00Z")
        ts = find_triple(triples, PROV_STARTED_AT_TIME, self.Q_URI)
        assert ts is not None
        assert ts.o.value == "2024-06-15T10:00:00Z"

    def test_question_default_timestamp(self):
        # Omitting the timestamp argument should still record a non-empty one.
        triples = question_triples(self.Q_URI, "Q")
        ts = find_triple(triples, PROV_STARTED_AT_TIME, self.Q_URI)
        assert ts is not None
        assert len(ts.o.value) > 0

    def test_question_label(self):
        triples = question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z")
        label = find_triple(triples, RDFS_LABEL, self.Q_URI)
        assert label is not None
        assert label.o.value == "GraphRAG Question"

    def test_question_triple_count(self):
        triples = question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z")
        assert len(triples) == 6


class TestExplorationTriples:
    """exploration_triples: exploration entity generated by a question."""

    EXP_URI = "urn:trustgraph:prov:exploration:test-session"
    Q_URI = "urn:trustgraph:question:test-session"

    def test_exploration_types(self):
        triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
        assert has_type(triples, self.EXP_URI, PROV_ENTITY)
        assert has_type(triples, self.EXP_URI, TG_EXPLORATION)

    def test_exploration_generated_by_question(self):
        triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
        gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI)
        assert gen is not None
        assert gen.o.iri == self.Q_URI

    def test_exploration_edge_count(self):
        triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
        ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
        assert ec is not None
        assert ec.o.value == "15"

    def test_exploration_zero_edges(self):
        # A zero edge count must still be recorded explicitly.
        triples = exploration_triples(self.EXP_URI, self.Q_URI, 0)
        ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
        assert ec is not None
        assert ec.o.value == "0"

    def test_exploration_triple_count(self):
        triples = exploration_triples(self.EXP_URI, self.Q_URI, 10)
        assert len(triples) == 5
TestFocusTriples: + + FOC_URI = "urn:trustgraph:prov:focus:test-session" + EXP_URI = "urn:trustgraph:prov:exploration:test-session" + SESSION_ID = "test-session" + + def test_focus_types(self): + triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID) + assert has_type(triples, self.FOC_URI, PROV_ENTITY) + assert has_type(triples, self.FOC_URI, TG_FOCUS) + + def test_focus_derived_from_exploration(self): + triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FOC_URI) + assert derived is not None + assert derived.o.iri == self.EXP_URI + + def test_focus_no_edges(self): + triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID) + selected = find_triples(triples, TG_SELECTED_EDGE) + assert len(selected) == 0 + + def test_focus_with_edges_and_reasoning(self): + edges = [ + { + "edge": ("urn:e:Alice", "urn:r:knows", "urn:e:Bob"), + "reasoning": "Alice is connected to Bob", + }, + { + "edge": ("urn:e:Bob", "urn:r:worksAt", "urn:e:Acme"), + "reasoning": "Bob works at Acme", + }, + ] + triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID) + + # Two selectedEdge links + selected = find_triples(triples, TG_SELECTED_EDGE, self.FOC_URI) + assert len(selected) == 2 + + # Each edge selection has a quoted triple + edge_triples = find_triples(triples, TG_EDGE) + assert len(edge_triples) == 2 + for et in edge_triples: + assert et.o.type == TRIPLE + + # Each edge selection has reasoning + reasoning_triples = find_triples(triples, TG_REASONING) + assert len(reasoning_triples) == 2 + + def test_focus_edge_without_reasoning(self): + edges = [ + {"edge": ("urn:e:A", "urn:r:x", "urn:e:B"), "reasoning": ""}, + ] + triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID) + reasoning = find_triples(triples, TG_REASONING) + assert len(reasoning) == 0 + + def test_focus_edge_without_edge_data(self): + edges = [ + {"edge": None, "reasoning": 
"some reasoning"}, + ] + triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID) + selected = find_triples(triples, TG_SELECTED_EDGE) + assert len(selected) == 0 + + def test_focus_quoted_triple_content(self): + edges = [ + { + "edge": ("urn:e:Alice", "urn:r:knows", "urn:e:Bob"), + "reasoning": "test", + }, + ] + triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID) + edge_t = find_triple(triples, TG_EDGE) + qt = edge_t.o.triple + assert qt.s.iri == "urn:e:Alice" + assert qt.p.iri == "urn:r:knows" + assert qt.o.iri == "urn:e:Bob" + + +class TestSynthesisTriples: + + SYN_URI = "urn:trustgraph:prov:synthesis:test-session" + FOC_URI = "urn:trustgraph:prov:focus:test-session" + + def test_synthesis_types(self): + triples = synthesis_triples(self.SYN_URI, self.FOC_URI) + assert has_type(triples, self.SYN_URI, PROV_ENTITY) + assert has_type(triples, self.SYN_URI, TG_SYNTHESIS) + + def test_synthesis_derived_from_focus(self): + triples = synthesis_triples(self.SYN_URI, self.FOC_URI) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI) + assert derived is not None + assert derived.o.iri == self.FOC_URI + + def test_synthesis_with_inline_content(self): + triples = synthesis_triples(self.SYN_URI, self.FOC_URI, answer_text="The answer is 42") + content = find_triple(triples, TG_CONTENT, self.SYN_URI) + assert content is not None + assert content.o.value == "The answer is 42" + + def test_synthesis_with_document_reference(self): + triples = synthesis_triples( + self.SYN_URI, self.FOC_URI, + document_id="urn:trustgraph:question:abc/answer", + ) + doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI) + assert doc is not None + assert doc.o.type == IRI + assert doc.o.iri == "urn:trustgraph:question:abc/answer" + + def test_synthesis_document_takes_precedence(self): + """When both document_id and answer_text are provided, document_id wins.""" + triples = synthesis_triples( + self.SYN_URI, self.FOC_URI, + 
answer_text="inline", + document_id="urn:doc:123", + ) + doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI) + assert doc is not None + content = find_triple(triples, TG_CONTENT, self.SYN_URI) + assert content is None + + def test_synthesis_no_content_or_document(self): + triples = synthesis_triples(self.SYN_URI, self.FOC_URI) + content = find_triple(triples, TG_CONTENT, self.SYN_URI) + doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI) + assert content is None + assert doc is None + + +# --------------------------------------------------------------------------- +# DocumentRAG query-time triples +# --------------------------------------------------------------------------- + +class TestDocRagQuestionTriples: + + Q_URI = "urn:trustgraph:docrag:test-session" + + def test_docrag_question_types(self): + triples = docrag_question_triples(self.Q_URI, "Find info", "2024-01-01T00:00:00Z") + assert has_type(triples, self.Q_URI, PROV_ACTIVITY) + assert has_type(triples, self.Q_URI, TG_QUESTION) + assert has_type(triples, self.Q_URI, TG_DOC_RAG_QUESTION) + + def test_docrag_question_label(self): + triples = docrag_question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z") + label = find_triple(triples, RDFS_LABEL, self.Q_URI) + assert label.o.value == "DocumentRAG Question" + + def test_docrag_question_query_text(self): + triples = docrag_question_triples(self.Q_URI, "search query", "2024-01-01T00:00:00Z") + query = find_triple(triples, TG_QUERY, self.Q_URI) + assert query.o.value == "search query" + + +class TestDocRagExplorationTriples: + + EXP_URI = "urn:trustgraph:docrag:test/exploration" + Q_URI = "urn:trustgraph:docrag:test" + + def test_docrag_exploration_types(self): + triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5) + assert has_type(triples, self.EXP_URI, PROV_ENTITY) + assert has_type(triples, self.EXP_URI, TG_EXPLORATION) + + def test_docrag_exploration_generated_by(self): + triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5) + 
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI) + assert gen.o.iri == self.Q_URI + + def test_docrag_exploration_chunk_count(self): + triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 7) + cc = find_triple(triples, TG_CHUNK_COUNT, self.EXP_URI) + assert cc.o.value == "7" + + def test_docrag_exploration_without_chunk_ids(self): + triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3) + chunks = find_triples(triples, TG_SELECTED_CHUNK) + assert len(chunks) == 0 + + def test_docrag_exploration_with_chunk_ids(self): + chunk_ids = ["urn:chunk:1", "urn:chunk:2", "urn:chunk:3"] + triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3, chunk_ids) + chunks = find_triples(triples, TG_SELECTED_CHUNK, self.EXP_URI) + assert len(chunks) == 3 + chunk_uris = {t.o.iri for t in chunks} + assert chunk_uris == set(chunk_ids) + + +class TestDocRagSynthesisTriples: + + SYN_URI = "urn:trustgraph:docrag:test/synthesis" + EXP_URI = "urn:trustgraph:docrag:test/exploration" + + def test_docrag_synthesis_types(self): + triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI) + assert has_type(triples, self.SYN_URI, PROV_ENTITY) + assert has_type(triples, self.SYN_URI, TG_SYNTHESIS) + + def test_docrag_synthesis_derived_from_exploration(self): + """DocRAG skips the focus step — synthesis derives from exploration.""" + triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI) + assert derived.o.iri == self.EXP_URI + + def test_docrag_synthesis_with_inline(self): + triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI, answer_text="answer") + content = find_triple(triples, TG_CONTENT, self.SYN_URI) + assert content.o.value == "answer" + + def test_docrag_synthesis_with_document(self): + triples = docrag_synthesis_triples( + self.SYN_URI, self.EXP_URI, document_id="urn:doc:ans" + ) + doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI) + assert doc.o.iri == 
"urn:doc:ans" + content = find_triple(triples, TG_CONTENT, self.SYN_URI) + assert content is None diff --git a/tests/unit/test_provenance/test_uris.py b/tests/unit/test_provenance/test_uris.py new file mode 100644 index 00000000..0e69734c --- /dev/null +++ b/tests/unit/test_provenance/test_uris.py @@ -0,0 +1,292 @@ +""" +Tests for provenance URI generation functions. +""" + +import pytest +from unittest.mock import patch + +from trustgraph.provenance.uris import ( + TRUSTGRAPH_BASE, + _encode_id, + document_uri, + page_uri, + chunk_uri_from_page, + chunk_uri_from_doc, + activity_uri, + subgraph_uri, + agent_uri, + question_uri, + exploration_uri, + focus_uri, + synthesis_uri, + edge_selection_uri, + agent_session_uri, + agent_iteration_uri, + agent_final_uri, + docrag_question_uri, + docrag_exploration_uri, + docrag_synthesis_uri, +) + + +class TestEncodeId: + """Tests for the _encode_id helper.""" + + def test_plain_string(self): + assert _encode_id("abc123") == "abc123" + + def test_string_with_spaces(self): + assert _encode_id("hello world") == "hello%20world" + + def test_string_with_slashes(self): + assert _encode_id("a/b/c") == "a%2Fb%2Fc" + + def test_integer_input(self): + assert _encode_id(42) == "42" + + def test_empty_string(self): + assert _encode_id("") == "" + + def test_special_characters(self): + result = _encode_id("name@domain.com") + assert "@" not in result or result == "name%40domain.com" + + +class TestDocumentUris: + """Tests for document, page, and chunk URI generation.""" + + def test_document_uri_passthrough(self): + iri = "https://example.com/doc/123" + assert document_uri(iri) == iri + + def test_page_uri_format(self): + result = page_uri("https://example.com/doc/123", 5) + assert result == "https://example.com/doc/123/p5" + + def test_page_uri_page_zero(self): + result = page_uri("https://example.com/doc/123", 0) + assert result == "https://example.com/doc/123/p0" + + def test_chunk_uri_from_page_format(self): + result = 
chunk_uri_from_page("https://example.com/doc/123", 2, 3) + assert result == "https://example.com/doc/123/p2/c3" + + def test_chunk_uri_from_doc_format(self): + result = chunk_uri_from_doc("https://example.com/doc/123", 7) + assert result == "https://example.com/doc/123/c7" + + def test_page_uri_preserves_doc_iri(self): + doc = "urn:isbn:978-3-16-148410-0" + result = page_uri(doc, 1) + assert result.startswith(doc) + + def test_chunk_from_page_hierarchy(self): + """Chunk URI should contain both page and chunk identifiers.""" + result = chunk_uri_from_page("https://example.com/doc", 3, 5) + assert "/p3/" in result + assert result.endswith("/c5") + + +class TestActivityAndSubgraphUris: + """Tests for activity_uri, subgraph_uri, and agent_uri.""" + + def test_activity_uri_with_id(self): + result = activity_uri("my-activity-id") + assert result == f"{TRUSTGRAPH_BASE}/activity/my-activity-id" + + def test_activity_uri_auto_generates_uuid(self): + result = activity_uri() + assert result.startswith(f"{TRUSTGRAPH_BASE}/activity/") + # UUID part should be non-empty + uuid_part = result.split("/activity/")[1] + assert len(uuid_part) > 0 + + def test_activity_uri_unique_uuids(self): + r1 = activity_uri() + r2 = activity_uri() + assert r1 != r2 + + def test_activity_uri_encodes_special_chars(self): + result = activity_uri("id with spaces") + assert "id%20with%20spaces" in result + + def test_subgraph_uri_with_id(self): + result = subgraph_uri("sg-123") + assert result == f"{TRUSTGRAPH_BASE}/subgraph/sg-123" + + def test_subgraph_uri_auto_generates_uuid(self): + result = subgraph_uri() + assert result.startswith(f"{TRUSTGRAPH_BASE}/subgraph/") + uuid_part = result.split("/subgraph/")[1] + assert len(uuid_part) > 0 + + def test_subgraph_uri_unique_uuids(self): + r1 = subgraph_uri() + r2 = subgraph_uri() + assert r1 != r2 + + def test_agent_uri_format(self): + result = agent_uri("pdf-extractor") + assert result == f"{TRUSTGRAPH_BASE}/agent/pdf-extractor" + + def 
test_agent_uri_encodes_special_chars(self): + result = agent_uri("my component") + assert "my%20component" in result + + +class TestGraphRagQueryUris: + """Tests for GraphRAG query-time provenance URIs.""" + + FIXED_UUID = "550e8400-e29b-41d4-a716-446655440000" + + def test_question_uri_with_session_id(self): + result = question_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:question:{self.FIXED_UUID}" + + def test_question_uri_auto_generates(self): + result = question_uri() + assert result.startswith("urn:trustgraph:question:") + uuid_part = result.split("urn:trustgraph:question:")[1] + assert len(uuid_part) > 0 + + def test_question_uri_unique(self): + r1 = question_uri() + r2 = question_uri() + assert r1 != r2 + + def test_exploration_uri_format(self): + result = exploration_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:prov:exploration:{self.FIXED_UUID}" + + def test_focus_uri_format(self): + result = focus_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:prov:focus:{self.FIXED_UUID}" + + def test_synthesis_uri_format(self): + result = synthesis_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:prov:synthesis:{self.FIXED_UUID}" + + def test_edge_selection_uri_format(self): + result = edge_selection_uri(self.FIXED_UUID, 3) + assert result == f"urn:trustgraph:prov:edge:{self.FIXED_UUID}:3" + + def test_edge_selection_uri_zero_index(self): + result = edge_selection_uri(self.FIXED_UUID, 0) + assert result.endswith(":0") + + def test_session_uris_share_session_id(self): + """All URIs for a session should contain the same session ID.""" + sid = self.FIXED_UUID + q = question_uri(sid) + e = exploration_uri(sid) + f = focus_uri(sid) + s = synthesis_uri(sid) + for uri in [q, e, f, s]: + assert sid in uri + + +class TestAgentProvenanceUris: + """Tests for agent provenance URIs.""" + + FIXED_UUID = "661e8400-e29b-41d4-a716-446655440000" + + def test_agent_session_uri_with_id(self): + result = agent_session_uri(self.FIXED_UUID) + assert 
result == f"urn:trustgraph:agent:{self.FIXED_UUID}" + + def test_agent_session_uri_auto_generates(self): + result = agent_session_uri() + assert result.startswith("urn:trustgraph:agent:") + + def test_agent_session_uri_unique(self): + r1 = agent_session_uri() + r2 = agent_session_uri() + assert r1 != r2 + + def test_agent_iteration_uri_format(self): + result = agent_iteration_uri(self.FIXED_UUID, 1) + assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}/i1" + + def test_agent_iteration_uri_numbering(self): + r1 = agent_iteration_uri(self.FIXED_UUID, 1) + r2 = agent_iteration_uri(self.FIXED_UUID, 2) + assert r1 != r2 + assert r1.endswith("/i1") + assert r2.endswith("/i2") + + def test_agent_final_uri_format(self): + result = agent_final_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}/final" + + def test_agent_uris_share_session_id(self): + sid = self.FIXED_UUID + session = agent_session_uri(sid) + iteration = agent_iteration_uri(sid, 1) + final = agent_final_uri(sid) + for uri in [session, iteration, final]: + assert sid in uri + + +class TestDocRagProvenanceUris: + """Tests for Document RAG provenance URIs.""" + + FIXED_UUID = "772e8400-e29b-41d4-a716-446655440000" + + def test_docrag_question_uri_with_id(self): + result = docrag_question_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}" + + def test_docrag_question_uri_auto_generates(self): + result = docrag_question_uri() + assert result.startswith("urn:trustgraph:docrag:") + + def test_docrag_question_uri_unique(self): + r1 = docrag_question_uri() + r2 = docrag_question_uri() + assert r1 != r2 + + def test_docrag_exploration_uri_format(self): + result = docrag_exploration_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}/exploration" + + def test_docrag_synthesis_uri_format(self): + result = docrag_synthesis_uri(self.FIXED_UUID) + assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}/synthesis" + + def 
test_docrag_uris_share_session_id(self): + sid = self.FIXED_UUID + q = docrag_question_uri(sid) + e = docrag_exploration_uri(sid) + s = docrag_synthesis_uri(sid) + for uri in [q, e, s]: + assert sid in uri + + +class TestUriNamespaceIsolation: + """Verify that different provenance types use distinct URI namespaces.""" + + FIXED_UUID = "883e8400-e29b-41d4-a716-446655440000" + + def test_graphrag_vs_agent_namespace(self): + graphrag = question_uri(self.FIXED_UUID) + agent = agent_session_uri(self.FIXED_UUID) + assert graphrag != agent + assert "question" in graphrag + assert "agent" in agent + + def test_graphrag_vs_docrag_namespace(self): + graphrag = question_uri(self.FIXED_UUID) + docrag = docrag_question_uri(self.FIXED_UUID) + assert graphrag != docrag + + def test_agent_vs_docrag_namespace(self): + agent = agent_session_uri(self.FIXED_UUID) + docrag = docrag_question_uri(self.FIXED_UUID) + assert agent != docrag + + def test_extraction_vs_query_namespace(self): + """Extraction URIs use https://, query URIs use urn:.""" + ext = activity_uri(self.FIXED_UUID) + query = question_uri(self.FIXED_UUID) + assert ext.startswith("https://") + assert query.startswith("urn:") diff --git a/tests/unit/test_provenance/test_vocabulary.py b/tests/unit/test_provenance/test_vocabulary.py new file mode 100644 index 00000000..a3c644e8 --- /dev/null +++ b/tests/unit/test_provenance/test_vocabulary.py @@ -0,0 +1,124 @@ +""" +Tests for provenance vocabulary bootstrap. 
+""" + +import pytest + +from trustgraph.schema import Triple, Term, IRI, LITERAL + +from trustgraph.provenance.vocabulary import ( + get_vocabulary_triples, + PROV_CLASS_LABELS, + PROV_PREDICATE_LABELS, + DC_PREDICATE_LABELS, + SCHEMA_LABELS, + SKOS_LABELS, + TG_CLASS_LABELS, + TG_PREDICATE_LABELS, +) + +from trustgraph.provenance.namespaces import ( + RDFS_LABEL, + PROV_ENTITY, PROV_ACTIVITY, PROV_AGENT, + PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY, + PROV_USED, PROV_WAS_ASSOCIATED_WITH, PROV_STARTED_AT_TIME, + DC_TITLE, DC_SOURCE, DC_DATE, DC_CREATOR, + TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE, +) + + +class TestVocabularyTriples: + """Tests for the vocabulary bootstrap function.""" + + def test_returns_list_of_triples(self): + result = get_vocabulary_triples() + assert isinstance(result, list) + assert len(result) > 0 + for t in result: + assert isinstance(t, Triple) + + def test_all_triples_are_label_triples(self): + """Every vocabulary triple should use rdfs:label as predicate.""" + for t in get_vocabulary_triples(): + assert t.p.type == IRI + assert t.p.iri == RDFS_LABEL + + def test_all_subjects_are_iris(self): + for t in get_vocabulary_triples(): + assert t.s.type == IRI + assert len(t.s.iri) > 0 + + def test_all_objects_are_literals(self): + for t in get_vocabulary_triples(): + assert t.o.type == LITERAL + assert len(t.o.value) > 0 + + def test_no_duplicate_subjects(self): + subjects = [t.s.iri for t in get_vocabulary_triples()] + assert len(subjects) == len(set(subjects)) + + def test_includes_prov_classes(self): + subjects = {t.s.iri for t in get_vocabulary_triples()} + assert PROV_ENTITY in subjects + assert PROV_ACTIVITY in subjects + assert PROV_AGENT in subjects + + def test_includes_prov_predicates(self): + subjects = {t.s.iri for t in get_vocabulary_triples()} + assert PROV_WAS_DERIVED_FROM in subjects + assert PROV_WAS_GENERATED_BY in subjects + assert PROV_USED in subjects + assert PROV_WAS_ASSOCIATED_WITH in subjects + 
assert PROV_STARTED_AT_TIME in subjects + + def test_includes_dc_predicates(self): + subjects = {t.s.iri for t in get_vocabulary_triples()} + assert DC_TITLE in subjects + assert DC_SOURCE in subjects + assert DC_DATE in subjects + assert DC_CREATOR in subjects + + def test_includes_tg_classes(self): + subjects = {t.s.iri for t in get_vocabulary_triples()} + assert TG_DOCUMENT_TYPE in subjects + assert TG_PAGE_TYPE in subjects + assert TG_CHUNK_TYPE in subjects + assert TG_SUBGRAPH_TYPE in subjects + + def test_component_lists_sum_to_total(self): + total = get_vocabulary_triples() + components = ( + PROV_CLASS_LABELS + + PROV_PREDICATE_LABELS + + DC_PREDICATE_LABELS + + SCHEMA_LABELS + + SKOS_LABELS + + TG_CLASS_LABELS + + TG_PREDICATE_LABELS + ) + assert len(total) == len(components) + + def test_idempotent(self): + """Calling twice should return equivalent triples.""" + r1 = get_vocabulary_triples() + r2 = get_vocabulary_triples() + assert len(r1) == len(r2) + for t1, t2 in zip(r1, r2): + assert t1.s.iri == t2.s.iri + assert t1.o.value == t2.o.value + + +class TestNamespaceConstants: + """Verify namespace constants are well-formed IRIs.""" + + def test_prov_namespace_prefix(self): + assert PROV_ENTITY.startswith("http://www.w3.org/ns/prov#") + + def test_dc_namespace_prefix(self): + assert DC_TITLE.startswith("http://purl.org/dc/elements/1.1/") + + def test_tg_namespace_prefix(self): + assert TG_DOCUMENT_TYPE.startswith("https://trustgraph.ai/ns/") + + def test_rdfs_label_iri(self): + assert RDFS_LABEL == "http://www.w3.org/2000/01/rdf-schema#label" diff --git a/tests/unit/test_rdf/__init__.py b/tests/unit/test_rdf/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/unit/test_rdf/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/test_rdf/test_rdf_primitives.py b/tests/unit/test_rdf/test_rdf_primitives.py new file mode 100644 index 00000000..2498677b --- /dev/null +++ b/tests/unit/test_rdf/test_rdf_primitives.py @@ -0,0 +1,309 @@ 
+""" +Tests for RDF 1.2 type system primitives: Term dataclass (IRI, blank node, +typed literal, language-tagged literal, quoted triple), Triple/Quad dataclass +with named graph support, and the knowledge/defs helper types. +""" + +import pytest + +from trustgraph.schema import Term, Triple, IRI, BLANK, LITERAL, TRIPLE + + +# --------------------------------------------------------------------------- +# Type constants +# --------------------------------------------------------------------------- + +class TestTypeConstants: + + def test_iri_constant(self): + assert IRI == "i" + + def test_blank_constant(self): + assert BLANK == "b" + + def test_literal_constant(self): + assert LITERAL == "l" + + def test_triple_constant(self): + assert TRIPLE == "t" + + def test_constants_are_distinct(self): + vals = {IRI, BLANK, LITERAL, TRIPLE} + assert len(vals) == 4 + + +# --------------------------------------------------------------------------- +# IRI terms +# --------------------------------------------------------------------------- + +class TestIriTerm: + + def test_create_iri(self): + t = Term(type=IRI, iri="http://example.org/Alice") + assert t.type == IRI + assert t.iri == "http://example.org/Alice" + + def test_iri_defaults_empty(self): + t = Term(type=IRI) + assert t.iri == "" + + def test_iri_with_fragment(self): + t = Term(type=IRI, iri="http://example.org/ontology#Person") + assert "#Person" in t.iri + + def test_iri_with_unicode(self): + t = Term(type=IRI, iri="http://example.org/概念") + assert "概念" in t.iri + + def test_iri_other_fields_default(self): + t = Term(type=IRI, iri="http://example.org/x") + assert t.id == "" + assert t.value == "" + assert t.datatype == "" + assert t.language == "" + assert t.triple is None + + +# --------------------------------------------------------------------------- +# Blank node terms +# --------------------------------------------------------------------------- + +class TestBlankNodeTerm: + + def test_create_blank_node(self): + 
t = Term(type=BLANK, id="_:b0") + assert t.type == BLANK + assert t.id == "_:b0" + + def test_blank_node_defaults_empty(self): + t = Term(type=BLANK) + assert t.id == "" + + def test_blank_node_arbitrary_id(self): + t = Term(type=BLANK, id="node-abc-123") + assert t.id == "node-abc-123" + + +# --------------------------------------------------------------------------- +# Typed literals (XSD datatypes) +# --------------------------------------------------------------------------- + +class TestTypedLiteral: + + def test_plain_literal(self): + t = Term(type=LITERAL, value="hello") + assert t.type == LITERAL + assert t.value == "hello" + assert t.datatype == "" + assert t.language == "" + + def test_xsd_integer(self): + t = Term( + type=LITERAL, value="42", + datatype="http://www.w3.org/2001/XMLSchema#integer", + ) + assert t.value == "42" + assert "integer" in t.datatype + + def test_xsd_boolean(self): + t = Term( + type=LITERAL, value="true", + datatype="http://www.w3.org/2001/XMLSchema#boolean", + ) + assert t.datatype.endswith("#boolean") + + def test_xsd_date(self): + t = Term( + type=LITERAL, value="2026-03-13", + datatype="http://www.w3.org/2001/XMLSchema#date", + ) + assert t.value == "2026-03-13" + assert t.datatype.endswith("#date") + + def test_xsd_double(self): + t = Term( + type=LITERAL, value="3.14", + datatype="http://www.w3.org/2001/XMLSchema#double", + ) + assert t.datatype.endswith("#double") + + def test_empty_value_literal(self): + t = Term(type=LITERAL, value="") + assert t.value == "" + + +# --------------------------------------------------------------------------- +# Language-tagged literals +# --------------------------------------------------------------------------- + +class TestLanguageTaggedLiteral: + + def test_english_tag(self): + t = Term(type=LITERAL, value="hello", language="en") + assert t.language == "en" + assert t.datatype == "" + + def test_french_tag(self): + t = Term(type=LITERAL, value="bonjour", language="fr") + assert 
t.language == "fr" + + def test_bcp47_subtag(self): + t = Term(type=LITERAL, value="colour", language="en-GB") + assert t.language == "en-GB" + + def test_language_and_datatype_mutually_exclusive(self): + """Both can be set on the dataclass, but semantically only one should be used.""" + t = Term(type=LITERAL, value="x", language="en", + datatype="http://www.w3.org/2001/XMLSchema#string") + # Dataclass allows both — translators should respect mutual exclusivity + assert t.language == "en" + assert t.datatype != "" + + +# --------------------------------------------------------------------------- +# Quoted triples (RDF-star) +# --------------------------------------------------------------------------- + +class TestQuotedTriple: + + def test_term_with_nested_triple(self): + inner = Triple( + s=Term(type=IRI, iri="http://example.org/Alice"), + p=Term(type=IRI, iri="http://xmlns.com/foaf/0.1/knows"), + o=Term(type=IRI, iri="http://example.org/Bob"), + ) + qt = Term(type=TRIPLE, triple=inner) + assert qt.type == TRIPLE + assert qt.triple is inner + assert qt.triple.s.iri == "http://example.org/Alice" + + def test_quoted_triple_as_object(self): + """A triple whose object is a quoted triple (RDF-star).""" + inner = Triple( + s=Term(type=IRI, iri="http://example.org/Hope"), + p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"), + o=Term(type=LITERAL, value="A feeling of expectation"), + ) + outer = Triple( + s=Term(type=IRI, iri="urn:subgraph:123"), + p=Term(type=IRI, iri="http://trustgraph.ai/tg/contains"), + o=Term(type=TRIPLE, triple=inner), + ) + assert outer.o.type == TRIPLE + assert outer.o.triple.o.value == "A feeling of expectation" + + def test_quoted_triple_none(self): + t = Term(type=TRIPLE, triple=None) + assert t.triple is None + + +# --------------------------------------------------------------------------- +# Triple / Quad (named graph) +# --------------------------------------------------------------------------- + +class 
TestTripleQuad: + + def test_default_graph_is_none(self): + t = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="val"), + ) + assert t.g is None + + def test_named_graph(self): + t = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="val"), + g="urn:graph:source", + ) + assert t.g == "urn:graph:source" + + def test_empty_string_graph(self): + t = Triple(g="") + assert t.g == "" + + def test_triple_with_all_none_terms(self): + t = Triple() + assert t.s is None + assert t.p is None + assert t.o is None + assert t.g is None + + def test_triple_equality(self): + """Dataclass equality based on field values.""" + t1 = Triple( + s=Term(type=IRI, iri="http://example.org/A"), + p=Term(type=IRI, iri="http://example.org/B"), + o=Term(type=LITERAL, value="C"), + ) + t2 = Triple( + s=Term(type=IRI, iri="http://example.org/A"), + p=Term(type=IRI, iri="http://example.org/B"), + o=Term(type=LITERAL, value="C"), + ) + assert t1 == t2 + + +# --------------------------------------------------------------------------- +# knowledge/defs helper types +# --------------------------------------------------------------------------- + +class TestKnowledgeDefs: + + def test_uri_type(self): + from trustgraph.knowledge.defs import Uri + u = Uri("http://example.org/x") + assert u.is_uri() is True + assert u.is_literal() is False + assert u.is_triple() is False + assert str(u) == "http://example.org/x" + + def test_literal_type(self): + from trustgraph.knowledge.defs import Literal + l = Literal("hello world") + assert l.is_uri() is False + assert l.is_literal() is True + assert l.is_triple() is False + assert str(l) == "hello world" + + def test_quoted_triple_type(self): + from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal + qt = QuotedTriple( + s=Uri("http://example.org/s"), + p=Uri("http://example.org/p"), + 
o=Literal("val"), + ) + assert qt.is_uri() is False + assert qt.is_literal() is False + assert qt.is_triple() is True + assert qt.s == "http://example.org/s" + assert qt.o == "val" + + def test_quoted_triple_repr(self): + from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal + qt = QuotedTriple( + s=Uri("http://example.org/A"), + p=Uri("http://example.org/B"), + o=Literal("C"), + ) + r = repr(qt) + assert "<<" in r + assert ">>" in r + assert "http://example.org/A" in r + + def test_quoted_triple_nested(self): + """QuotedTriple can contain another QuotedTriple as object.""" + from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal + inner = QuotedTriple( + s=Uri("http://example.org/s"), + p=Uri("http://example.org/p"), + o=Literal("v"), + ) + outer = QuotedTriple( + s=Uri("http://example.org/s2"), + p=Uri("http://example.org/p2"), + o=inner, + ) + assert outer.o.is_triple() is True diff --git a/tests/unit/test_rdf/test_rdf_storage_helpers.py b/tests/unit/test_rdf/test_rdf_storage_helpers.py new file mode 100644 index 00000000..7bd1807a --- /dev/null +++ b/tests/unit/test_rdf/test_rdf_storage_helpers.py @@ -0,0 +1,217 @@ +""" +Tests for RDF storage helper functions used by the Cassandra triple writer: +serialize_triple, get_term_value, get_term_otype, get_term_dtype, get_term_lang. 
+""" + +import json +import pytest + +from trustgraph.schema import Term, Triple, IRI, BLANK, LITERAL, TRIPLE +from trustgraph.storage.triples.cassandra.write import ( + serialize_triple, + get_term_value, + get_term_otype, + get_term_dtype, + get_term_lang, +) + + +# --------------------------------------------------------------------------- +# get_term_otype — maps Term.type to storage object type code +# --------------------------------------------------------------------------- + +class TestGetTermOtype: + + def test_iri_maps_to_u(self): + assert get_term_otype(Term(type=IRI, iri="http://x")) == "u" + + def test_blank_maps_to_u(self): + assert get_term_otype(Term(type=BLANK, id="_:b0")) == "u" + + def test_literal_maps_to_l(self): + assert get_term_otype(Term(type=LITERAL, value="x")) == "l" + + def test_triple_maps_to_t(self): + assert get_term_otype(Term(type=TRIPLE)) == "t" + + def test_none_defaults_to_u(self): + assert get_term_otype(None) == "u" + + def test_unknown_type_defaults_to_u(self): + assert get_term_otype(Term(type="z")) == "u" + + +# --------------------------------------------------------------------------- +# get_term_dtype — extracts XSD datatype from literals +# --------------------------------------------------------------------------- + +class TestGetTermDtype: + + def test_literal_with_datatype(self): + t = Term(type=LITERAL, value="42", + datatype="http://www.w3.org/2001/XMLSchema#integer") + assert get_term_dtype(t) == "http://www.w3.org/2001/XMLSchema#integer" + + def test_literal_without_datatype(self): + t = Term(type=LITERAL, value="hello") + assert get_term_dtype(t) == "" + + def test_iri_returns_empty(self): + assert get_term_dtype(Term(type=IRI, iri="http://x")) == "" + + def test_none_returns_empty(self): + assert get_term_dtype(None) == "" + + +# --------------------------------------------------------------------------- +# get_term_lang — extracts language tag from literals +# 
--------------------------------------------------------------------------- + +class TestGetTermLang: + + def test_literal_with_language(self): + t = Term(type=LITERAL, value="bonjour", language="fr") + assert get_term_lang(t) == "fr" + + def test_literal_without_language(self): + t = Term(type=LITERAL, value="hello") + assert get_term_lang(t) == "" + + def test_iri_returns_empty(self): + assert get_term_lang(Term(type=IRI, iri="http://x")) == "" + + def test_none_returns_empty(self): + assert get_term_lang(None) == "" + + def test_bcp47_subtag_preserved(self): + t = Term(type=LITERAL, value="colour", language="en-GB") + assert get_term_lang(t) == "en-GB" + + +# --------------------------------------------------------------------------- +# get_term_value — extracts string value from any Term +# --------------------------------------------------------------------------- + +class TestGetTermValue: + + def test_iri_returns_iri(self): + t = Term(type=IRI, iri="http://example.org/Alice") + assert get_term_value(t) == "http://example.org/Alice" + + def test_literal_returns_value(self): + t = Term(type=LITERAL, value="hello") + assert get_term_value(t) == "hello" + + def test_blank_returns_id(self): + t = Term(type=BLANK, id="_:b0") + assert get_term_value(t) == "_:b0" + + def test_none_returns_none(self): + assert get_term_value(None) is None + + def test_triple_returns_serialized_json(self): + inner = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="val"), + ) + t = Term(type=TRIPLE, triple=inner) + result = get_term_value(t) + parsed = json.loads(result) + assert parsed["s"]["type"] == "i" + assert parsed["s"]["iri"] == "http://example.org/s" + assert parsed["o"]["value"] == "val" + + +# --------------------------------------------------------------------------- +# serialize_triple — full Triple → JSON serialization +# 
--------------------------------------------------------------------------- + +class TestSerializeTriple: + + def test_serialize_iri_triple(self): + t = Triple( + s=Term(type=IRI, iri="http://example.org/A"), + p=Term(type=IRI, iri="http://example.org/rel"), + o=Term(type=IRI, iri="http://example.org/B"), + ) + result = json.loads(serialize_triple(t)) + assert result["s"]["type"] == "i" + assert result["s"]["iri"] == "http://example.org/A" + assert result["p"]["iri"] == "http://example.org/rel" + assert result["o"]["type"] == "i" + + def test_serialize_literal_object(self): + t = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="hello"), + ) + result = json.loads(serialize_triple(t)) + assert result["o"]["type"] == "l" + assert result["o"]["value"] == "hello" + + def test_serialize_typed_literal(self): + t = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="42", + datatype="http://www.w3.org/2001/XMLSchema#integer"), + ) + result = json.loads(serialize_triple(t)) + assert result["o"]["datatype"] == "http://www.w3.org/2001/XMLSchema#integer" + + def test_serialize_language_tagged_literal(self): + t = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="bonjour", language="fr"), + ) + result = json.loads(serialize_triple(t)) + assert result["o"]["language"] == "fr" + + def test_serialize_blank_node(self): + t = Triple( + s=Term(type=BLANK, id="_:b0"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="v"), + ) + result = json.loads(serialize_triple(t)) + assert result["s"]["type"] == "b" + assert result["s"]["id"] == "_:b0" + + def test_serialize_nested_quoted_triple(self): + inner = Triple( + s=Term(type=IRI, iri="http://example.org/inner-s"), + p=Term(type=IRI, 
iri="http://example.org/inner-p"), + o=Term(type=LITERAL, value="inner-val"), + ) + outer = Triple( + s=Term(type=IRI, iri="http://example.org/outer-s"), + p=Term(type=IRI, iri="http://example.org/outer-p"), + o=Term(type=TRIPLE, triple=inner), + ) + result = json.loads(serialize_triple(outer)) + nested = json.loads(result["o"]["triple"]) + assert nested["s"]["iri"] == "http://example.org/inner-s" + assert nested["o"]["value"] == "inner-val" + + def test_serialize_none_returns_none(self): + assert serialize_triple(None) is None + + def test_serialize_none_terms(self): + t = Triple(s=None, p=None, o=None) + result = json.loads(serialize_triple(t)) + assert result["s"] is None + assert result["p"] is None + assert result["o"] is None + + def test_serialize_plain_literal_omits_datatype_and_language(self): + t = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="plain"), + ) + result = json.loads(serialize_triple(t)) + assert "datatype" not in result["o"] + assert "language" not in result["o"] diff --git a/tests/unit/test_rdf/test_rdf_wire_format.py b/tests/unit/test_rdf/test_rdf_wire_format.py new file mode 100644 index 00000000..a0bbd27a --- /dev/null +++ b/tests/unit/test_rdf/test_rdf_wire_format.py @@ -0,0 +1,357 @@ +""" +Tests for RDF wire format translators: TermTranslator and TripleTranslator +round-trip encoding for all RDF 1.2 term types (IRI, blank node, typed literal, +language-tagged literal, quoted triple) and named graph quads. 
+""" + +import pytest + +from trustgraph.schema import Term, Triple, IRI, BLANK, LITERAL, TRIPLE +from trustgraph.messaging.translators.primitives import ( + TermTranslator, TripleTranslator, SubgraphTranslator, +) + + +@pytest.fixture +def term_tx(): + return TermTranslator() + + +@pytest.fixture +def triple_tx(): + return TripleTranslator() + + +# --------------------------------------------------------------------------- +# TermTranslator — IRI +# --------------------------------------------------------------------------- + +class TestTermTranslatorIri: + + def test_iri_to_pulsar(self, term_tx): + data = {"t": "i", "i": "http://example.org/Alice"} + term = term_tx.to_pulsar(data) + assert term.type == IRI + assert term.iri == "http://example.org/Alice" + + def test_iri_from_pulsar(self, term_tx): + term = Term(type=IRI, iri="http://example.org/Bob") + wire = term_tx.from_pulsar(term) + assert wire == {"t": "i", "i": "http://example.org/Bob"} + + def test_iri_round_trip(self, term_tx): + original = Term(type=IRI, iri="http://example.org/round") + wire = term_tx.from_pulsar(original) + restored = term_tx.to_pulsar(wire) + assert restored == original + + +# --------------------------------------------------------------------------- +# TermTranslator — Blank node +# --------------------------------------------------------------------------- + +class TestTermTranslatorBlank: + + def test_blank_to_pulsar(self, term_tx): + data = {"t": "b", "d": "_:b42"} + term = term_tx.to_pulsar(data) + assert term.type == BLANK + assert term.id == "_:b42" + + def test_blank_from_pulsar(self, term_tx): + term = Term(type=BLANK, id="_:node1") + wire = term_tx.from_pulsar(term) + assert wire == {"t": "b", "d": "_:node1"} + + def test_blank_round_trip(self, term_tx): + original = Term(type=BLANK, id="_:x") + wire = term_tx.from_pulsar(original) + restored = term_tx.to_pulsar(wire) + assert restored == original + + +# 
--------------------------------------------------------------------------- +# TermTranslator — Typed literal (XSD) +# --------------------------------------------------------------------------- + +class TestTermTranslatorTypedLiteral: + + def test_plain_literal_to_pulsar(self, term_tx): + data = {"t": "l", "v": "hello"} + term = term_tx.to_pulsar(data) + assert term.type == LITERAL + assert term.value == "hello" + assert term.datatype == "" + assert term.language == "" + + def test_xsd_integer_to_pulsar(self, term_tx): + data = { + "t": "l", "v": "42", + "dt": "http://www.w3.org/2001/XMLSchema#integer", + } + term = term_tx.to_pulsar(data) + assert term.value == "42" + assert term.datatype.endswith("#integer") + + def test_typed_literal_from_pulsar(self, term_tx): + term = Term( + type=LITERAL, value="3.14", + datatype="http://www.w3.org/2001/XMLSchema#double", + ) + wire = term_tx.from_pulsar(term) + assert wire["t"] == "l" + assert wire["v"] == "3.14" + assert wire["dt"] == "http://www.w3.org/2001/XMLSchema#double" + assert "ln" not in wire # No language tag + + def test_typed_literal_round_trip(self, term_tx): + original = Term( + type=LITERAL, value="true", + datatype="http://www.w3.org/2001/XMLSchema#boolean", + ) + wire = term_tx.from_pulsar(original) + restored = term_tx.to_pulsar(wire) + assert restored == original + + def test_plain_literal_omits_dt_and_ln(self, term_tx): + term = Term(type=LITERAL, value="x") + wire = term_tx.from_pulsar(term) + assert "dt" not in wire + assert "ln" not in wire + + +# --------------------------------------------------------------------------- +# TermTranslator — Language-tagged literal +# --------------------------------------------------------------------------- + +class TestTermTranslatorLangLiteral: + + def test_language_tag_to_pulsar(self, term_tx): + data = {"t": "l", "v": "bonjour", "ln": "fr"} + term = term_tx.to_pulsar(data) + assert term.value == "bonjour" + assert term.language == "fr" + + def 
test_language_tag_from_pulsar(self, term_tx): + term = Term(type=LITERAL, value="colour", language="en-GB") + wire = term_tx.from_pulsar(term) + assert wire["ln"] == "en-GB" + assert "dt" not in wire # No datatype + + def test_language_tag_round_trip(self, term_tx): + original = Term(type=LITERAL, value="hola", language="es") + wire = term_tx.from_pulsar(original) + restored = term_tx.to_pulsar(wire) + assert restored == original + + +# --------------------------------------------------------------------------- +# TermTranslator — Quoted triple (RDF-star) +# --------------------------------------------------------------------------- + +class TestTermTranslatorQuotedTriple: + + def test_quoted_triple_to_pulsar(self, term_tx): + data = { + "t": "t", + "tr": { + "s": {"t": "i", "i": "http://example.org/Alice"}, + "p": {"t": "i", "i": "http://xmlns.com/foaf/0.1/knows"}, + "o": {"t": "i", "i": "http://example.org/Bob"}, + }, + } + term = term_tx.to_pulsar(data) + assert term.type == TRIPLE + assert term.triple is not None + assert term.triple.s.iri == "http://example.org/Alice" + assert term.triple.o.iri == "http://example.org/Bob" + + def test_quoted_triple_from_pulsar(self, term_tx): + inner = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="val"), + ) + term = Term(type=TRIPLE, triple=inner) + wire = term_tx.from_pulsar(term) + assert wire["t"] == "t" + assert "tr" in wire + assert wire["tr"]["s"]["i"] == "http://example.org/s" + assert wire["tr"]["o"]["v"] == "val" + + def test_quoted_triple_round_trip(self, term_tx): + inner = Triple( + s=Term(type=IRI, iri="http://example.org/A"), + p=Term(type=IRI, iri="http://example.org/B"), + o=Term(type=LITERAL, value="C", language="en"), + ) + original = Term(type=TRIPLE, triple=inner) + wire = term_tx.from_pulsar(original) + restored = term_tx.to_pulsar(wire) + assert restored.type == TRIPLE + assert restored.triple.s == original.triple.s + 
assert restored.triple.o == original.triple.o + + def test_quoted_triple_none_triple(self, term_tx): + term = Term(type=TRIPLE, triple=None) + wire = term_tx.from_pulsar(term) + assert wire == {"t": "t"} + # And back + restored = term_tx.to_pulsar(wire) + assert restored.type == TRIPLE + assert restored.triple is None + + def test_quoted_triple_with_literal_object(self, term_tx): + data = { + "t": "t", + "tr": { + "s": {"t": "i", "i": "http://example.org/Hope"}, + "p": {"t": "i", "i": "http://www.w3.org/2004/02/skos/core#definition"}, + "o": {"t": "l", "v": "A feeling of expectation"}, + }, + } + term = term_tx.to_pulsar(data) + assert term.triple.o.type == LITERAL + assert term.triple.o.value == "A feeling of expectation" + + +# --------------------------------------------------------------------------- +# TermTranslator — Edge cases +# --------------------------------------------------------------------------- + +class TestTermTranslatorEdgeCases: + + def test_unknown_type(self, term_tx): + data = {"t": "z"} + term = term_tx.to_pulsar(data) + assert term.type == "z" + + def test_empty_type(self, term_tx): + data = {} + term = term_tx.to_pulsar(data) + assert term.type == "" + + def test_missing_iri_field(self, term_tx): + data = {"t": "i"} + term = term_tx.to_pulsar(data) + assert term.iri == "" + + def test_missing_literal_fields(self, term_tx): + data = {"t": "l"} + term = term_tx.to_pulsar(data) + assert term.value == "" + assert term.datatype == "" + assert term.language == "" + + +# --------------------------------------------------------------------------- +# TripleTranslator +# --------------------------------------------------------------------------- + +class TestTripleTranslator: + + def test_triple_to_pulsar(self, triple_tx): + data = { + "s": {"t": "i", "i": "http://example.org/s"}, + "p": {"t": "i", "i": "http://example.org/p"}, + "o": {"t": "l", "v": "object"}, + } + triple = triple_tx.to_pulsar(data) + assert triple.s.iri == "http://example.org/s" 
+ assert triple.o.value == "object" + assert triple.g is None + + def test_triple_from_pulsar(self, triple_tx): + triple = Triple( + s=Term(type=IRI, iri="http://example.org/A"), + p=Term(type=IRI, iri="http://example.org/B"), + o=Term(type=LITERAL, value="C"), + ) + wire = triple_tx.from_pulsar(triple) + assert wire["s"]["t"] == "i" + assert wire["o"]["v"] == "C" + assert "g" not in wire + + def test_quad_with_named_graph(self, triple_tx): + data = { + "s": {"t": "i", "i": "http://example.org/s"}, + "p": {"t": "i", "i": "http://example.org/p"}, + "o": {"t": "l", "v": "val"}, + "g": "urn:graph:source", + } + quad = triple_tx.to_pulsar(data) + assert quad.g == "urn:graph:source" + + def test_quad_from_pulsar_includes_graph(self, triple_tx): + quad = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="v"), + g="urn:graph:retrieval", + ) + wire = triple_tx.from_pulsar(quad) + assert wire["g"] == "urn:graph:retrieval" + + def test_quad_round_trip(self, triple_tx): + original = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="v"), + g="urn:graph:source", + ) + wire = triple_tx.from_pulsar(original) + restored = triple_tx.to_pulsar(wire) + assert restored == original + + def test_none_graph_omitted_from_wire(self, triple_tx): + triple = Triple( + s=Term(type=IRI, iri="http://example.org/s"), + p=Term(type=IRI, iri="http://example.org/p"), + o=Term(type=LITERAL, value="v"), + g=None, + ) + wire = triple_tx.from_pulsar(triple) + assert "g" not in wire + + def test_missing_terms_handled(self, triple_tx): + data = {} + triple = triple_tx.to_pulsar(data) + assert triple.s is None + assert triple.p is None + assert triple.o is None + + +# --------------------------------------------------------------------------- +# SubgraphTranslator +# --------------------------------------------------------------------------- 
+ +class TestSubgraphTranslator: + + def test_subgraph_round_trip(self): + tx = SubgraphTranslator() + triples = [ + Triple( + s=Term(type=IRI, iri="http://example.org/A"), + p=Term(type=IRI, iri="http://example.org/rel"), + o=Term(type=LITERAL, value="v1"), + ), + Triple( + s=Term(type=IRI, iri="http://example.org/B"), + p=Term(type=IRI, iri="http://example.org/rel"), + o=Term(type=IRI, iri="http://example.org/C"), + g="urn:graph:source", + ), + ] + wire_list = tx.from_pulsar(triples) + assert len(wire_list) == 2 + assert wire_list[1]["g"] == "urn:graph:source" + + restored = tx.to_pulsar(wire_list) + assert len(restored) == 2 + assert restored[0] == triples[0] + assert restored[1] == triples[1] + + def test_empty_subgraph(self): + tx = SubgraphTranslator() + assert tx.to_pulsar([]) == [] + assert tx.from_pulsar([]) == [] diff --git a/tests/unit/test_reliability/__init__.py b/tests/unit/test_reliability/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/unit/test_reliability/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/test_reliability/test_metadata_preservation.py b/tests/unit/test_reliability/test_metadata_preservation.py new file mode 100644 index 00000000..2fabed58 --- /dev/null +++ b/tests/unit/test_reliability/test_metadata_preservation.py @@ -0,0 +1,144 @@ +""" +Tests for pipeline metadata preservation: DocumentMetadata and +ProcessingMetadata round-trip through translators, field preservation, +and default handling. 
+""" + +import pytest + +from trustgraph.schema import DocumentMetadata, ProcessingMetadata, Triple, Term, IRI +from trustgraph.messaging.translators.metadata import ( + DocumentMetadataTranslator, + ProcessingMetadataTranslator, +) + + +# --------------------------------------------------------------------------- +# DocumentMetadata translator +# --------------------------------------------------------------------------- + +class TestDocumentMetadataTranslator: + + def setup_method(self): + self.tx = DocumentMetadataTranslator() + + def test_full_round_trip(self): + data = { + "id": "doc-123", + "time": 1710000000, + "kind": "application/pdf", + "title": "Test Document", + "comments": "No comments", + "metadata": [], + "user": "alice", + "tags": ["finance", "q4"], + "parent-id": "doc-100", + "document-type": "page", + } + obj = self.tx.to_pulsar(data) + assert obj.id == "doc-123" + assert obj.time == 1710000000 + assert obj.kind == "application/pdf" + assert obj.title == "Test Document" + assert obj.user == "alice" + assert obj.tags == ["finance", "q4"] + assert obj.parent_id == "doc-100" + assert obj.document_type == "page" + + wire = self.tx.from_pulsar(obj) + assert wire["id"] == "doc-123" + assert wire["user"] == "alice" + assert wire["parent-id"] == "doc-100" + assert wire["document-type"] == "page" + + def test_defaults_for_missing_fields(self): + obj = self.tx.to_pulsar({}) + assert obj.parent_id == "" + assert obj.document_type == "source" + + def test_metadata_triples_preserved(self): + triple_wire = [{ + "s": {"t": "i", "i": "http://example.org/s"}, + "p": {"t": "i", "i": "http://example.org/p"}, + "o": {"t": "i", "i": "http://example.org/o"}, + }] + data = {"metadata": triple_wire} + obj = self.tx.to_pulsar(data) + assert len(obj.metadata) == 1 + assert obj.metadata[0].s.iri == "http://example.org/s" + + def test_none_metadata_handled(self): + data = {"metadata": None} + obj = self.tx.to_pulsar(data) + assert obj.metadata == [] + + def 
test_empty_tags_preserved(self): + data = {"tags": []} + obj = self.tx.to_pulsar(data) + wire = self.tx.from_pulsar(obj) + assert wire["tags"] == [] + + def test_falsy_fields_omitted_from_wire(self): + """Empty string fields should be omitted from wire format.""" + obj = DocumentMetadata(id="", time=0, user="") + wire = self.tx.from_pulsar(obj) + assert "id" not in wire + assert "user" not in wire + + +# --------------------------------------------------------------------------- +# ProcessingMetadata translator +# --------------------------------------------------------------------------- + +class TestProcessingMetadataTranslator: + + def setup_method(self): + self.tx = ProcessingMetadataTranslator() + + def test_full_round_trip(self): + data = { + "id": "proc-1", + "document-id": "doc-123", + "time": 1710000000, + "flow": "default", + "user": "alice", + "collection": "my-collection", + "tags": ["tag1"], + } + obj = self.tx.to_pulsar(data) + assert obj.id == "proc-1" + assert obj.document_id == "doc-123" + assert obj.flow == "default" + assert obj.user == "alice" + assert obj.collection == "my-collection" + assert obj.tags == ["tag1"] + + wire = self.tx.from_pulsar(obj) + assert wire["id"] == "proc-1" + assert wire["document-id"] == "doc-123" + assert wire["user"] == "alice" + assert wire["collection"] == "my-collection" + + def test_missing_fields_use_defaults(self): + obj = self.tx.to_pulsar({}) + assert obj.id is None + assert obj.user is None + assert obj.collection is None + + def test_tags_none_omitted(self): + obj = ProcessingMetadata(tags=None) + wire = self.tx.from_pulsar(obj) + assert "tags" not in wire + + def test_tags_empty_list_preserved(self): + obj = ProcessingMetadata(tags=[]) + wire = self.tx.from_pulsar(obj) + assert wire["tags"] == [] + + def test_user_and_collection_preserved(self): + """Core pipeline routing fields must survive round-trip.""" + data = {"user": "bob", "collection": "research"} + obj = self.tx.to_pulsar(data) + wire = 
self.tx.from_pulsar(obj) + assert wire["user"] == "bob" + assert wire["collection"] == "research" diff --git a/tests/unit/test_reliability/test_null_embedding_protection.py b/tests/unit/test_reliability/test_null_embedding_protection.py new file mode 100644 index 00000000..41a5c621 --- /dev/null +++ b/tests/unit/test_reliability/test_null_embedding_protection.py @@ -0,0 +1,314 @@ +""" +Tests for null embedding protection: empty/None vector skipping, entity +validation, dimension-aware collection creation, and query-time empty +vector handling. + +Tests the pure functions and logic without Qdrant connections. +""" + +import pytest +from unittest.mock import MagicMock, patch, AsyncMock + +from trustgraph.schema import Term, IRI, LITERAL, BLANK + + +# --------------------------------------------------------------------------- +# Graph embeddings: get_term_value +# --------------------------------------------------------------------------- + +class TestGraphEmbeddingsGetTermValue: + + def test_iri_returns_iri(self): + from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value + t = Term(type=IRI, iri="http://example.org/x") + assert get_term_value(t) == "http://example.org/x" + + def test_literal_returns_value(self): + from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value + t = Term(type=LITERAL, value="hello") + assert get_term_value(t) == "hello" + + def test_blank_returns_id(self): + from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value + t = Term(type=BLANK, id="_:b0") + assert get_term_value(t) == "_:b0" + + def test_none_returns_none(self): + from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value + assert get_term_value(None) is None + + def test_blank_with_value_fallback(self): + from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value + t = Term(type=BLANK, id="", value="fallback") + assert get_term_value(t) == "fallback" + + +# 
--------------------------------------------------------------------------- +# Document embeddings: null vector protection +# --------------------------------------------------------------------------- + +class TestDocEmbeddingsNullProtection: + + @pytest.mark.asyncio + async def test_empty_vector_skipped(self): + """Embeddings with empty vectors should be silently skipped.""" + from trustgraph.storage.doc_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + + # Mock collection_exists for config check + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + emb = MagicMock() + emb.chunk_id = "chunk-1" + emb.vector = [] # Empty vector + msg.chunks = [emb] + + await proc.store_document_embeddings(msg) + + # No upsert should be called + proc.qdrant.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_none_vector_skipped(self): + from trustgraph.storage.doc_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + emb = MagicMock() + emb.chunk_id = "chunk-1" + emb.vector = None # None vector + msg.chunks = [emb] + + await proc.store_document_embeddings(msg) + proc.qdrant.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_empty_chunk_id_skipped(self): + from trustgraph.storage.doc_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + emb = MagicMock() + emb.chunk_id = "" # Empty chunk ID + emb.vector = [0.1, 0.2, 0.3] + msg.chunks = [emb] + + await proc.store_document_embeddings(msg) + 
proc.qdrant.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_valid_embedding_upserted(self): + from trustgraph.storage.doc_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.qdrant.collection_exists.return_value = True + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + emb = MagicMock() + emb.chunk_id = "chunk-1" + emb.vector = [0.1, 0.2, 0.3] + msg.chunks = [emb] + + await proc.store_document_embeddings(msg) + proc.qdrant.upsert.assert_called_once() + + @pytest.mark.asyncio + async def test_dimension_in_collection_name(self): + """Collection name should include vector dimension.""" + from trustgraph.storage.doc_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.qdrant.collection_exists.return_value = True + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "alice" + msg.metadata.collection = "docs" + + emb = MagicMock() + emb.chunk_id = "c1" + emb.vector = [0.0] * 384 # 384-dim vector + msg.chunks = [emb] + + await proc.store_document_embeddings(msg) + + call_args = proc.qdrant.upsert.call_args + assert "d_alice_docs_384" in call_args[1]["collection_name"] + + +# --------------------------------------------------------------------------- +# Graph embeddings: null entity and vector protection +# --------------------------------------------------------------------------- + +class TestGraphEmbeddingsNullProtection: + + @pytest.mark.asyncio + async def test_empty_entity_skipped(self): + from trustgraph.storage.graph_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + entity = 
MagicMock() + entity.entity = Term(type=IRI, iri="") # Empty IRI + entity.vector = [0.1, 0.2, 0.3] + msg.entities = [entity] + + await proc.store_graph_embeddings(msg) + proc.qdrant.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_none_entity_skipped(self): + from trustgraph.storage.graph_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + entity = MagicMock() + entity.entity = None # Null entity + entity.vector = [0.1, 0.2, 0.3] + msg.entities = [entity] + + await proc.store_graph_embeddings(msg) + proc.qdrant.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_empty_vector_skipped(self): + from trustgraph.storage.graph_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + entity = MagicMock() + entity.entity = Term(type=IRI, iri="http://example.org/x") + entity.vector = [] # Empty vector + msg.entities = [entity] + + await proc.store_graph_embeddings(msg) + proc.qdrant.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_valid_entity_and_vector_upserted(self): + from trustgraph.storage.graph_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.qdrant.collection_exists.return_value = True + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "col1" + + entity = MagicMock() + entity.entity = Term(type=IRI, iri="http://example.org/Alice") + entity.vector = [0.1, 0.2, 0.3] + entity.chunk_id = "c1" + msg.entities = [entity] + + await proc.store_graph_embeddings(msg) + 
proc.qdrant.upsert.assert_called_once() + + @pytest.mark.asyncio + async def test_lazy_collection_creation_on_new_dimension(self): + from trustgraph.storage.graph_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.qdrant.collection_exists.return_value = False + proc.collection_exists = MagicMock(return_value=True) + + msg = MagicMock() + msg.metadata.user = "alice" + msg.metadata.collection = "graphs" + + entity = MagicMock() + entity.entity = Term(type=IRI, iri="http://example.org/x") + entity.vector = [0.0] * 768 + entity.chunk_id = "" + msg.entities = [entity] + + await proc.store_graph_embeddings(msg) + + # Collection should be created with correct dimension + proc.qdrant.create_collection.assert_called_once() + create_args = proc.qdrant.create_collection.call_args + assert create_args[1]["collection_name"] == "t_alice_graphs_768" + + +# --------------------------------------------------------------------------- +# Collection validation — deleted-while-in-flight protection +# --------------------------------------------------------------------------- + +class TestCollectionValidation: + + @pytest.mark.asyncio + async def test_doc_embeddings_dropped_for_deleted_collection(self): + from trustgraph.storage.doc_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.collection_exists = MagicMock(return_value=False) + + msg = MagicMock() + msg.metadata.user = "user1" + msg.metadata.collection = "deleted-col" + msg.chunks = [MagicMock()] + + await proc.store_document_embeddings(msg) + proc.qdrant.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_graph_embeddings_dropped_for_deleted_collection(self): + from trustgraph.storage.graph_embeddings.qdrant.write import Processor + + proc = Processor.__new__(Processor) + proc.qdrant = MagicMock() + proc.collection_exists = MagicMock(return_value=False) + + msg = MagicMock() + 
msg.metadata.user = "user1" + msg.metadata.collection = "deleted-col" + msg.entities = [MagicMock()] + + await proc.store_graph_embeddings(msg) + proc.qdrant.upsert.assert_not_called() diff --git a/tests/unit/test_reliability/test_retry_backoff.py b/tests/unit/test_reliability/test_retry_backoff.py new file mode 100644 index 00000000..94a3e806 --- /dev/null +++ b/tests/unit/test_reliability/test_retry_backoff.py @@ -0,0 +1,153 @@ +""" +Tests for retry and backoff strategies: Consumer rate-limit retry loop, +timeout expiry, TooManyRequests exception propagation, and configurable +retry parameters. +""" + +import asyncio +import time +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.exceptions import TooManyRequests +from trustgraph.base.consumer import Consumer + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_consumer(rate_limit_retry_time=10, rate_limit_timeout=7200): + """Create a Consumer with minimal mocking.""" + consumer = Consumer.__new__(Consumer) + consumer.rate_limit_retry_time = rate_limit_retry_time + consumer.rate_limit_timeout = rate_limit_timeout + consumer.metrics = None + consumer.consumer = MagicMock() + return consumer + + +# --------------------------------------------------------------------------- +# TooManyRequests exception +# --------------------------------------------------------------------------- + +class TestTooManyRequestsException: + + def test_is_exception(self): + assert issubclass(TooManyRequests, Exception) + + def test_with_message(self): + err = TooManyRequests("rate limited") + assert "rate limited" in str(err) + + def test_without_message(self): + err = TooManyRequests() + assert isinstance(err, TooManyRequests) + + +# --------------------------------------------------------------------------- +# Consumer retry configuration +# 
--------------------------------------------------------------------------- + +class TestConsumerRetryConfig: + + def test_default_retry_time(self): + consumer = _make_consumer() + assert consumer.rate_limit_retry_time == 10 + + def test_default_timeout(self): + consumer = _make_consumer() + assert consumer.rate_limit_timeout == 7200 + + def test_custom_retry_time(self): + consumer = _make_consumer(rate_limit_retry_time=5) + assert consumer.rate_limit_retry_time == 5 + + def test_custom_timeout(self): + consumer = _make_consumer(rate_limit_timeout=300) + assert consumer.rate_limit_timeout == 300 + + +# --------------------------------------------------------------------------- +# Rate limit metrics +# --------------------------------------------------------------------------- + +class TestRateLimitMetrics: + + def test_metrics_rate_limit_called(self): + """Metrics should record rate limit events when available.""" + consumer = _make_consumer() + consumer.metrics = MagicMock() + + # Simulate what the consumer does on rate limit + consumer.metrics.rate_limit() + + consumer.metrics.rate_limit.assert_called_once() + + +# --------------------------------------------------------------------------- +# Message acknowledgment on error +# --------------------------------------------------------------------------- + +class TestMessageAckOnError: + + def test_consumer_has_negative_acknowledge(self): + """Consumer backend should support negative acknowledgment.""" + consumer = _make_consumer() + msg = MagicMock() + + # Simulate negative ack (what happens on timeout expiry) + consumer.consumer.negative_acknowledge(msg) + consumer.consumer.negative_acknowledge.assert_called_once_with(msg) + + +# --------------------------------------------------------------------------- +# TooManyRequests propagation across services +# --------------------------------------------------------------------------- + +class TestTooManyRequestsPropagation: + + def test_llm_service_propagates(self): + 
"""LLM services should re-raise TooManyRequests for consumer retry.""" + with pytest.raises(TooManyRequests): + raise TooManyRequests() + + def test_embeddings_service_propagates(self): + """Embeddings services should re-raise TooManyRequests for consumer retry.""" + with pytest.raises(TooManyRequests): + try: + raise TooManyRequests("rate limited") + except TooManyRequests as e: + # Re-raise pattern used in services + assert isinstance(e, TooManyRequests) + raise + + def test_too_many_requests_not_caught_by_generic(self): + """TooManyRequests should be distinguishable from generic exceptions.""" + caught_specific = False + try: + raise TooManyRequests("rate limited") + except TooManyRequests: + caught_specific = True + except Exception: + pass + assert caught_specific + + +# --------------------------------------------------------------------------- +# Client-side error type mapping +# --------------------------------------------------------------------------- + +class TestClientErrorTypeMapping: + + def test_too_many_requests_wire_type(self): + """The wire format error type for rate limiting is 'too-many-requests'.""" + from trustgraph.schema import Error + err = Error(type="too-many-requests", message="slow down") + assert err.type == "too-many-requests" + + def test_generic_error_wire_type(self): + from trustgraph.schema import Error + err = Error(type="internal-error", message="something broke") + assert err.type == "internal-error" + assert err.type != "too-many-requests" diff --git a/tests/unit/test_reliability/test_subscriber_resilience.py b/tests/unit/test_reliability/test_subscriber_resilience.py new file mode 100644 index 00000000..4aac1161 --- /dev/null +++ b/tests/unit/test_reliability/test_subscriber_resilience.py @@ -0,0 +1,233 @@ +""" +Tests for message queue subscriber resilience: unexpected message handling, +orphaned message detection, backpressure strategies, graceful draining, +and timeout recovery. 
+""" + +import asyncio +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.base.subscriber import Subscriber + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_subscriber(max_size=10, backpressure_strategy="block", + drain_timeout=5.0): + """Create a Subscriber without connecting to any backend.""" + backend = MagicMock() + sub = Subscriber( + backend=backend, + topic="test-topic", + subscription="test-sub", + consumer_name="test", + max_size=max_size, + backpressure_strategy=backpressure_strategy, + drain_timeout=drain_timeout, + ) + sub.consumer = MagicMock() + return sub + + +def _make_msg(id=None, value="test-value"): + """Create a mock message with optional properties.""" + msg = MagicMock() + if id is not None: + msg.properties.return_value = {"id": id} + else: + msg.properties.side_effect = KeyError("id") + msg.value.return_value = value + return msg + + +# --------------------------------------------------------------------------- +# Message property extraction resilience +# --------------------------------------------------------------------------- + +class TestMessagePropertyResilience: + + @pytest.mark.asyncio + async def test_missing_id_property_handled(self): + """Messages without 'id' property should not crash.""" + sub = _make_subscriber() + msg = MagicMock() + msg.properties.side_effect = Exception("no properties") + msg.value.return_value = "some-value" + + # Should not raise + await sub._process_message(msg) + + # Message should still be acknowledged + sub.consumer.acknowledge.assert_called_once_with(msg) + + @pytest.mark.asyncio + async def test_message_with_valid_id_delivered(self): + """Messages with matching subscriber ID should be delivered.""" + sub = _make_subscriber() + q = await sub.subscribe("req-1") + + msg = _make_msg(id="req-1", value="response-data") + await 
sub._process_message(msg) + + assert not q.empty() + assert q.get_nowait() == "response-data" + sub.consumer.acknowledge.assert_called_once() + + +# --------------------------------------------------------------------------- +# Orphaned message handling +# --------------------------------------------------------------------------- + +class TestOrphanedMessages: + + @pytest.mark.asyncio + async def test_orphaned_message_acknowledged(self): + """Messages with no matching waiter should still be acknowledged.""" + sub = _make_subscriber() + msg = _make_msg(id="unknown-id", value="orphan") + + await sub._process_message(msg) + + # Orphaned message is acknowledged (not negative-acknowledged) + sub.consumer.acknowledge.assert_called_once_with(msg) + + @pytest.mark.asyncio + async def test_orphaned_message_not_queued(self): + """Orphaned messages should not appear in any subscriber queue.""" + sub = _make_subscriber() + q = await sub.subscribe("req-1") + + msg = _make_msg(id="different-id", value="orphan") + await sub._process_message(msg) + + assert q.empty() + + +# --------------------------------------------------------------------------- +# Backpressure strategies +# --------------------------------------------------------------------------- + +class TestBackpressureStrategies: + + @pytest.mark.asyncio + async def test_drop_new_rejects_when_full(self): + """drop_new strategy should reject new messages when queue is full.""" + sub = _make_subscriber(max_size=1, backpressure_strategy="drop_new") + q = await sub.subscribe("req-1") + + # Fill the queue + msg1 = _make_msg(id="req-1", value="first") + await sub._process_message(msg1) + assert q.qsize() == 1 + + # Second message should be dropped + msg2 = _make_msg(id="req-1", value="second") + await sub._process_message(msg2) + + # Queue still has only the first message + assert q.qsize() == 1 + assert q.get_nowait() == "first" + + @pytest.mark.asyncio + async def test_drop_oldest_evicts_when_full(self): + """drop_oldest 
strategy should evict oldest message when full.""" + sub = _make_subscriber(max_size=1, backpressure_strategy="drop_oldest") + q = await sub.subscribe("req-1") + + msg1 = _make_msg(id="req-1", value="first") + await sub._process_message(msg1) + + msg2 = _make_msg(id="req-1", value="second") + await sub._process_message(msg2) + + # Queue should have the newer message + assert q.qsize() == 1 + assert q.get_nowait() == "second" + + @pytest.mark.asyncio + async def test_block_strategy_delivers(self): + """block strategy should deliver messages normally.""" + sub = _make_subscriber(max_size=10, backpressure_strategy="block") + q = await sub.subscribe("req-1") + + msg = _make_msg(id="req-1", value="data") + await sub._process_message(msg) + + assert q.get_nowait() == "data" + + +# --------------------------------------------------------------------------- +# Full subscribers (subscribe_all) +# --------------------------------------------------------------------------- + +class TestFullSubscribers: + + @pytest.mark.asyncio + async def test_subscribe_all_receives_all_messages(self): + sub = _make_subscriber() + q = await sub.subscribe_all("listener-1") + + msg = _make_msg(id="any-id", value="broadcast") + await sub._process_message(msg) + + assert q.get_nowait() == "broadcast" + + @pytest.mark.asyncio + async def test_multiple_full_subscribers_all_receive(self): + sub = _make_subscriber() + q1 = await sub.subscribe_all("l1") + q2 = await sub.subscribe_all("l2") + + msg = _make_msg(id="any", value="data") + await sub._process_message(msg) + + assert q1.get_nowait() == "data" + assert q2.get_nowait() == "data" + + +# --------------------------------------------------------------------------- +# Subscribe / unsubscribe lifecycle +# --------------------------------------------------------------------------- + +class TestSubscribeLifecycle: + + @pytest.mark.asyncio + async def test_unsubscribe_removes_queue(self): + sub = _make_subscriber() + await sub.subscribe("req-1") + 
await sub.unsubscribe("req-1") + + assert "req-1" not in sub.q + + @pytest.mark.asyncio + async def test_unsubscribe_nonexistent_is_noop(self): + sub = _make_subscriber() + await sub.unsubscribe("nonexistent") # Should not raise + + @pytest.mark.asyncio + async def test_unsubscribe_all_removes_queue(self): + sub = _make_subscriber() + await sub.subscribe_all("l1") + await sub.unsubscribe_all("l1") + + assert "l1" not in sub.full + + +# --------------------------------------------------------------------------- +# Pending ack tracking +# --------------------------------------------------------------------------- + +class TestPendingAckTracking: + + @pytest.mark.asyncio + async def test_processed_message_cleared_from_pending(self): + sub = _make_subscriber() + msg = _make_msg(id="req-1", value="data") + + await sub._process_message(msg) + + # After processing, pending_acks should be empty + assert len(sub.pending_acks) == 0 diff --git a/tests/unit/test_structured_data/__init__.py b/tests/unit/test_structured_data/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/unit/test_structured_data/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/test_structured_data/test_row_embeddings_query.py b/tests/unit/test_structured_data/test_row_embeddings_query.py new file mode 100644 index 00000000..3222ec83 --- /dev/null +++ b/tests/unit/test_structured_data/test_row_embeddings_query.py @@ -0,0 +1,296 @@ +""" +Tests for row embeddings query service: collection naming, query execution, +index filtering, result conversion, and error handling. 
+""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.schema import ( + RowEmbeddingsRequest, RowEmbeddingsResponse, + RowIndexMatch, Error, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_processor(qdrant_client=None): + """Create a Processor without full FlowProcessor init.""" + from trustgraph.query.row_embeddings.qdrant.service import Processor + proc = Processor.__new__(Processor) + proc.qdrant = qdrant_client or MagicMock() + return proc + + +def _make_request(vector=None, user="test-user", collection="test-col", + schema_name="customers", limit=10, index_name=None): + return RowEmbeddingsRequest( + vector=vector or [0.1, 0.2, 0.3], + user=user, + collection=collection, + schema_name=schema_name, + limit=limit, + index_name=index_name or "", + ) + + +def _make_search_point(index_name, index_value, text, score): + point = MagicMock() + point.payload = { + "index_name": index_name, + "index_value": index_value, + "text": text, + } + point.score = score + return point + + +# --------------------------------------------------------------------------- +# sanitize_name +# --------------------------------------------------------------------------- + +class TestSanitizeName: + + def test_simple_name(self): + proc = _make_processor() + assert proc.sanitize_name("customers") == "customers" + + def test_special_chars_replaced(self): + proc = _make_processor() + assert proc.sanitize_name("my-schema.v2") == "my_schema_v2" + + def test_leading_digit_prefixed(self): + proc = _make_processor() + result = proc.sanitize_name("123schema") + assert result.startswith("r_") + assert "123schema" in result + + def test_uppercase_lowercased(self): + proc = _make_processor() + assert proc.sanitize_name("MySchema") == "myschema" + + def test_spaces_replaced(self): + proc = _make_processor() + assert 
proc.sanitize_name("my schema") == "my_schema" + + +# --------------------------------------------------------------------------- +# find_collection +# --------------------------------------------------------------------------- + +class TestFindCollection: + + def test_finds_matching_collection(self): + proc = _make_processor() + mock_coll = MagicMock() + mock_coll.name = "rows_test_user_test_col_customers_384" + + mock_collections = MagicMock() + mock_collections.collections = [mock_coll] + proc.qdrant.get_collections.return_value = mock_collections + + result = proc.find_collection("test-user", "test-col", "customers") + + # Prefix: rows_test_user_test_col_customers_ + assert result == "rows_test_user_test_col_customers_384" + + def test_returns_none_when_no_match(self): + proc = _make_processor() + mock_coll = MagicMock() + mock_coll.name = "rows_other_user_other_col_schema_768" + + mock_collections = MagicMock() + mock_collections.collections = [mock_coll] + proc.qdrant.get_collections.return_value = mock_collections + + result = proc.find_collection("test-user", "test-col", "customers") + assert result is None + + def test_returns_none_on_error(self): + proc = _make_processor() + proc.qdrant.get_collections.side_effect = Exception("connection error") + + result = proc.find_collection("user", "col", "schema") + assert result is None + + +# --------------------------------------------------------------------------- +# query_row_embeddings +# --------------------------------------------------------------------------- + +class TestQueryRowEmbeddings: + + @pytest.mark.asyncio + async def test_empty_vector_returns_empty(self): + proc = _make_processor() + request = _make_request(vector=[]) + + result = await proc.query_row_embeddings(request) + assert result == [] + + @pytest.mark.asyncio + async def test_no_collection_returns_empty(self): + proc = _make_processor() + proc.find_collection = MagicMock(return_value=None) + request = _make_request() + + result = await 
proc.query_row_embeddings(request) + assert result == [] + + @pytest.mark.asyncio + async def test_successful_query_returns_matches(self): + proc = _make_processor() + proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + + points = [ + _make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95), + _make_search_point("address", ["123 Main St"], "123 Main St", 0.82), + ] + mock_result = MagicMock() + mock_result.points = points + proc.qdrant.query_points.return_value = mock_result + + request = _make_request() + result = await proc.query_row_embeddings(request) + + assert len(result) == 2 + assert isinstance(result[0], RowIndexMatch) + assert result[0].index_name == "name" + assert result[0].index_value == ["Alice Smith"] + assert result[0].score == 0.95 + assert result[1].index_name == "address" + + @pytest.mark.asyncio + async def test_index_name_filter_applied(self): + """When index_name is specified, a Qdrant filter should be used.""" + proc = _make_processor() + proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + + mock_result = MagicMock() + mock_result.points = [] + proc.qdrant.query_points.return_value = mock_result + + request = _make_request(index_name="address") + await proc.query_row_embeddings(request) + + call_kwargs = proc.qdrant.query_points.call_args[1] + assert call_kwargs["query_filter"] is not None + + @pytest.mark.asyncio + async def test_no_index_name_no_filter(self): + """When index_name is empty, no filter should be applied.""" + proc = _make_processor() + proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + + mock_result = MagicMock() + mock_result.points = [] + proc.qdrant.query_points.return_value = mock_result + + request = _make_request(index_name="") + await proc.query_row_embeddings(request) + + call_kwargs = proc.qdrant.query_points.call_args[1] + assert call_kwargs["query_filter"] is None + + @pytest.mark.asyncio + async def test_missing_payload_fields_default(self): + """Points with 
missing payload fields should use defaults.""" + proc = _make_processor() + proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + + point = MagicMock() + point.payload = {} # Empty payload + point.score = 0.5 + + mock_result = MagicMock() + mock_result.points = [point] + proc.qdrant.query_points.return_value = mock_result + + request = _make_request() + result = await proc.query_row_embeddings(request) + + assert len(result) == 1 + assert result[0].index_name == "" + assert result[0].index_value == [] + assert result[0].text == "" + + @pytest.mark.asyncio + async def test_qdrant_error_propagates(self): + proc = _make_processor() + proc.find_collection = MagicMock(return_value="rows_u_c_s_384") + proc.qdrant.query_points.side_effect = Exception("qdrant down") + + request = _make_request() + + with pytest.raises(Exception, match="qdrant down"): + await proc.query_row_embeddings(request) + + +# --------------------------------------------------------------------------- +# on_message handler +# --------------------------------------------------------------------------- + +class TestOnMessage: + + @pytest.mark.asyncio + async def test_successful_message_sends_response(self): + proc = _make_processor() + proc.query_row_embeddings = AsyncMock(return_value=[ + RowIndexMatch(index_name="name", index_value=["Alice"], + text="Alice", score=0.9), + ]) + + mock_pub = AsyncMock() + flow = lambda name: mock_pub + + msg = MagicMock() + msg.value.return_value = _make_request() + msg.properties.return_value = {"id": "req-1"} + + await proc.on_message(msg, MagicMock(), flow) + + sent = mock_pub.send.call_args[0][0] + assert isinstance(sent, RowEmbeddingsResponse) + assert sent.error is None + assert len(sent.matches) == 1 + + @pytest.mark.asyncio + async def test_error_sends_error_response(self): + proc = _make_processor() + proc.query_row_embeddings = AsyncMock( + side_effect=Exception("query failed") + ) + + mock_pub = AsyncMock() + flow = lambda name: mock_pub + + msg 
= MagicMock() + msg.value.return_value = _make_request() + msg.properties.return_value = {"id": "req-2"} + + await proc.on_message(msg, MagicMock(), flow) + + sent = mock_pub.send.call_args[0][0] + assert sent.error is not None + assert sent.error.type == "row-embeddings-query-error" + assert "query failed" in sent.error.message + assert sent.matches == [] + + @pytest.mark.asyncio + async def test_message_id_preserved(self): + proc = _make_processor() + proc.query_row_embeddings = AsyncMock(return_value=[]) + + mock_pub = AsyncMock() + flow = lambda name: mock_pub + + msg = MagicMock() + msg.value.return_value = _make_request() + msg.properties.return_value = {"id": "unique-42"} + + await proc.on_message(msg, MagicMock(), flow) + + props = mock_pub.send.call_args[1]["properties"] + assert props["id"] == "unique-42" diff --git a/tests/unit/test_structured_data/test_type_detector.py b/tests/unit/test_structured_data/test_type_detector.py new file mode 100644 index 00000000..7ce7060a --- /dev/null +++ b/tests/unit/test_structured_data/test_type_detector.py @@ -0,0 +1,235 @@ +""" +Tests for structured data type detection: CSV, JSON, XML format detection, +CSV option detection (delimiter, header), and helper functions. 
+""" + +import pytest + +from trustgraph.retrieval.structured_diag.type_detector import ( + detect_data_type, + _check_json_format, + _check_xml_format, + _check_csv_format, + _check_csv_with_delimiter, + detect_csv_options, + _is_numeric, +) + + +# --------------------------------------------------------------------------- +# detect_data_type (top-level dispatcher) +# --------------------------------------------------------------------------- + +class TestDetectDataType: + + def test_empty_string_returns_none(self): + detected, confidence = detect_data_type("") + assert detected is None + assert confidence == 0.0 + + def test_whitespace_only_returns_none(self): + detected, confidence = detect_data_type(" \n \t ") + assert detected is None + assert confidence == 0.0 + + def test_none_returns_none(self): + detected, confidence = detect_data_type(None) + assert detected is None + assert confidence == 0.0 + + def test_json_object_detected(self): + detected, confidence = detect_data_type('{"name": "Alice"}') + assert detected == "json" + assert confidence > 0.5 + + def test_json_array_detected(self): + detected, confidence = detect_data_type('[{"id": 1}, {"id": 2}]') + assert detected == "json" + assert confidence > 0.5 + + def test_xml_with_declaration_detected(self): + detected, confidence = detect_data_type('') + assert detected == "xml" + assert confidence > 0.5 + + def test_xml_without_declaration_detected(self): + detected, confidence = detect_data_type('val') + assert detected == "xml" + assert confidence > 0.5 + + def test_csv_detected(self): + data = "name,age,city\nAlice,30,NYC\nBob,25,LA" + detected, confidence = detect_data_type(data) + assert detected == "csv" + assert confidence > 0.5 + + def test_plain_text_falls_through_to_csv(self): + """Non-JSON/XML text defaults to CSV detection.""" + detected, confidence = detect_data_type("just some text") + assert detected == "csv" + + +# --------------------------------------------------------------------------- 
+# _check_json_format +# --------------------------------------------------------------------------- + +class TestCheckJsonFormat: + + def test_valid_json_object(self): + assert _check_json_format('{"key": "value"}') > 0.9 + + def test_valid_json_array_of_objects(self): + assert _check_json_format('[{"id": 1}, {"id": 2}]') >= 0.9 + + def test_valid_json_array_of_primitives(self): + score = _check_json_format('[1, 2, 3]') + assert score > 0.5 + assert score < 0.9 # Lower confidence for non-object arrays + + def test_empty_json_object(self): + assert _check_json_format('{}') > 0.5 + + def test_invalid_json(self): + assert _check_json_format('{invalid json}') == 0.0 + + def test_non_json_starting_char(self): + assert _check_json_format('hello world') == 0.0 + + def test_empty_array(self): + score = _check_json_format('[]') + assert score > 0.0 # Parsed successfully but empty + + +# --------------------------------------------------------------------------- +# _check_xml_format +# --------------------------------------------------------------------------- + +class TestCheckXmlFormat: + + def test_valid_xml(self): + assert _check_xml_format('val') == 0.9 + + def test_xml_with_declaration(self): + xml = 'test' + assert _check_xml_format(xml) == 0.9 + + def test_malformed_xml(self): + score = _check_xml_format('') + # Has < and ') + # Starts with < but no closing tag + assert score <= 0.1 + + +# --------------------------------------------------------------------------- +# _check_csv_format and _check_csv_with_delimiter +# --------------------------------------------------------------------------- + +class TestCheckCsvFormat: + + def test_valid_csv_comma(self): + data = "name,age,city\nAlice,30,NYC\nBob,25,LA" + assert _check_csv_format(data) > 0.7 + + def test_valid_csv_semicolon(self): + data = "name;age;city\nAlice;30;NYC\nBob;25;LA" + assert _check_csv_format(data) > 0.7 + + def test_valid_csv_tab(self): + data = "name\tage\tcity\nAlice\t30\tNYC\nBob\t25\tLA" + assert 
_check_csv_format(data) > 0.7 + + def test_valid_csv_pipe(self): + data = "name|age|city\nAlice|30|NYC\nBob|25|LA" + assert _check_csv_format(data) > 0.7 + + def test_single_line_not_csv(self): + assert _check_csv_format("just one line") == 0.0 + + def test_single_column_not_csv(self): + data = "a\nb\nc" + assert _check_csv_with_delimiter(data, ",") == 0.0 + + def test_inconsistent_columns_low_score(self): + data = "a,b,c\n1,2\n3,4,5,6" + score = _check_csv_with_delimiter(data, ",") + assert score < 0.7 + + def test_many_rows_higher_score(self): + rows = ["name,age,city"] + [f"person{i},{20+i},city{i}" for i in range(20)] + data = "\n".join(rows) + score = _check_csv_format(data) + assert score > 0.8 + + +# --------------------------------------------------------------------------- +# detect_csv_options +# --------------------------------------------------------------------------- + +class TestDetectCsvOptions: + + def test_comma_delimiter_detected(self): + data = "name,age,city\nAlice,30,NYC\nBob,25,LA" + options = detect_csv_options(data) + assert options["delimiter"] == "," + + def test_semicolon_delimiter_detected(self): + data = "name;age;city\nAlice;30;NYC\nBob;25;LA" + options = detect_csv_options(data) + assert options["delimiter"] == ";" + + def test_tab_delimiter_detected(self): + data = "name\tage\tcity\nAlice\t30\tNYC\nBob\t25\tLA" + options = detect_csv_options(data) + assert options["delimiter"] == "\t" + + def test_header_detected_when_first_row_text(self): + data = "name,age,salary\nAlice,30,50000\nBob,25,45000" + options = detect_csv_options(data) + assert options["has_header"] is True + + def test_no_header_when_all_numeric(self): + data = "1,2,3\n4,5,6\n7,8,9" + options = detect_csv_options(data) + assert options["has_header"] is False + + def test_single_line_returns_defaults(self): + options = detect_csv_options("just one line") + assert options["delimiter"] == "," + assert options["has_header"] is True + + def test_encoding_default(self): + 
data = "a,b\n1,2" + options = detect_csv_options(data) + assert options["encoding"] == "utf-8" + + +# --------------------------------------------------------------------------- +# _is_numeric helper +# --------------------------------------------------------------------------- + +class TestIsNumeric: + + def test_integer(self): + assert _is_numeric("42") is True + + def test_float(self): + assert _is_numeric("3.14") is True + + def test_negative(self): + assert _is_numeric("-10") is True + + def test_text(self): + assert _is_numeric("hello") is False + + def test_empty(self): + assert _is_numeric("") is False + + def test_whitespace_padded(self): + assert _is_numeric(" 42 ") is True diff --git a/tests/unit/test_text_completion/test_azure_openai_streaming.py b/tests/unit/test_text_completion/test_azure_openai_streaming.py new file mode 100644 index 00000000..b2f5a003 --- /dev/null +++ b/tests/unit/test_text_completion/test_azure_openai_streaming.py @@ -0,0 +1,182 @@ +""" +Tests for Azure OpenAI streaming: model/temperature override during streaming, +RateLimitError → TooManyRequests conversion, chunk iteration, and final token +count emission. 
+""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +from trustgraph.model.text_completion.azure_openai.llm import Processor +from trustgraph.base import LlmChunk +from trustgraph.exceptions import TooManyRequests + + +def _make_processor(mock_azure_openai_class, model="gpt-4"): + """Create a Processor with mocked base classes.""" + with patch('trustgraph.base.async_processor.AsyncProcessor.__init__', + return_value=None), \ + patch('trustgraph.base.llm_service.LlmService.__init__', + return_value=None): + proc = Processor( + endpoint="https://test.openai.azure.com/", + token="test-token", + api_version="2024-12-01-preview", + model=model, + temperature=0.0, + max_output=4192, + concurrency=1, + taskgroup=AsyncMock(), + id="test-processor", + ) + return proc + + +def _make_stream_chunk(content=None, usage=None): + """Create a mock streaming chunk.""" + chunk = MagicMock() + if content: + chunk.choices = [MagicMock()] + chunk.choices[0].delta.content = content + else: + chunk.choices = [] + chunk.usage = usage + return chunk + + +class TestAzureOpenAIStreaming(IsolatedAsyncioTestCase): + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + async def test_streaming_yields_chunks(self, mock_azure_class): + mock_client = MagicMock() + mock_azure_class.return_value = mock_client + proc = _make_processor(mock_azure_class) + + usage = MagicMock() + usage.prompt_tokens = 10 + usage.completion_tokens = 5 + + stream_data = [ + _make_stream_chunk(content="Hello"), + _make_stream_chunk(content=" world"), + _make_stream_chunk(usage=usage), + ] + mock_client.chat.completions.create.return_value = iter(stream_data) + + results = [] + async for chunk in proc.generate_content_stream("sys", "user"): + results.append(chunk) + + assert len(results) == 3 # 2 content + 1 final + assert results[0].text == "Hello" + assert results[0].is_final is False + assert results[1].text == " world" + 
assert results[2].is_final is True + assert results[2].in_token == 10 + assert results[2].out_token == 5 + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + async def test_streaming_model_override(self, mock_azure_class): + mock_client = MagicMock() + mock_azure_class.return_value = mock_client + proc = _make_processor(mock_azure_class, model="gpt-4") + + usage = MagicMock() + usage.prompt_tokens = 5 + usage.completion_tokens = 2 + + stream_data = [ + _make_stream_chunk(content="ok"), + _make_stream_chunk(usage=usage), + ] + mock_client.chat.completions.create.return_value = iter(stream_data) + + results = [] + async for chunk in proc.generate_content_stream( + "sys", "user", model="gpt-4o" + ): + results.append(chunk) + + # All chunks carry overridden model + for r in results: + assert r.model == "gpt-4o" + + # Verify API call used overridden model + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["model"] == "gpt-4o" + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + async def test_streaming_temperature_override(self, mock_azure_class): + mock_client = MagicMock() + mock_azure_class.return_value = mock_client + proc = _make_processor(mock_azure_class) + + usage = MagicMock() + usage.prompt_tokens = 5 + usage.completion_tokens = 2 + + stream_data = [_make_stream_chunk(usage=usage)] + mock_client.chat.completions.create.return_value = iter(stream_data) + + async for _ in proc.generate_content_stream( + "sys", "user", temperature=0.7 + ): + pass + + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["temperature"] == 0.7 + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + async def test_streaming_rate_limit_raises_too_many_requests(self, mock_azure_class): + from openai import RateLimitError + + mock_client = MagicMock() + mock_azure_class.return_value = mock_client + proc = _make_processor(mock_azure_class) + + 
mock_client.chat.completions.create.side_effect = RateLimitError( + "Rate limit exceeded", response=MagicMock(), body=None + ) + + with pytest.raises(TooManyRequests): + async for _ in proc.generate_content_stream("sys", "user"): + pass + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + async def test_streaming_generic_exception_propagates(self, mock_azure_class): + mock_client = MagicMock() + mock_azure_class.return_value = mock_client + proc = _make_processor(mock_azure_class) + + mock_client.chat.completions.create.side_effect = Exception("API down") + + with pytest.raises(Exception, match="API down"): + async for _ in proc.generate_content_stream("sys", "user"): + pass + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + async def test_streaming_passes_stream_options(self, mock_azure_class): + mock_client = MagicMock() + mock_azure_class.return_value = mock_client + proc = _make_processor(mock_azure_class) + + usage = MagicMock() + usage.prompt_tokens = 0 + usage.completion_tokens = 0 + stream_data = [_make_stream_chunk(usage=usage)] + mock_client.chat.completions.create.return_value = iter(stream_data) + + async for _ in proc.generate_content_stream("sys", "user"): + pass + + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["stream"] is True + assert call_kwargs["stream_options"] == {"include_usage": True} + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + async def test_supports_streaming(self, mock_azure_class): + mock_client = MagicMock() + mock_azure_class.return_value = mock_client + proc = _make_processor(mock_azure_class) + assert proc.supports_streaming() is True diff --git a/tests/unit/test_text_completion/test_azure_streaming.py b/tests/unit/test_text_completion/test_azure_streaming.py new file mode 100644 index 00000000..ff32e59d --- /dev/null +++ b/tests/unit/test_text_completion/test_azure_streaming.py @@ -0,0 +1,199 @@ +""" +Tests for 
Azure serverless endpoint streaming: model override during streaming, +HTTP 429 during streaming, SSE chunk parsing, and final token count emission. +""" + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +from trustgraph.model.text_completion.azure.llm import Processor +from trustgraph.base import LlmChunk +from trustgraph.exceptions import TooManyRequests + + +def _make_processor(mock_requests, model="AzureAI", temperature=0.0): + """Create a Processor with mocked base classes.""" + with patch('trustgraph.base.async_processor.AsyncProcessor.__init__', + return_value=None), \ + patch('trustgraph.base.llm_service.LlmService.__init__', + return_value=None): + proc = Processor( + endpoint="https://test.azure.com/v1/chat/completions", + token="test-token", + temperature=temperature, + max_output=4192, + model=model, + concurrency=1, + taskgroup=AsyncMock(), + id="test-processor", + ) + return proc + + +def _sse_lines(*data_items): + """Build SSE byte lines from data items. 
'[DONE]' is appended.""" + lines = [] + for item in data_items: + if isinstance(item, dict): + lines.append(f"data: {json.dumps(item)}".encode()) + else: + lines.append(f"data: {item}".encode()) + lines.append(b"data: [DONE]") + return lines + + +class TestAzureServerlessStreaming(IsolatedAsyncioTestCase): + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_yields_chunks(self, mock_requests): + proc = _make_processor(mock_requests) + + chunks = [ + {"choices": [{"delta": {"content": "Hello"}}]}, + {"choices": [{"delta": {"content": " world"}}]}, + {"usage": {"prompt_tokens": 10, "completion_tokens": 5}}, + ] + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.iter_lines.return_value = _sse_lines(*chunks) + mock_requests.post.return_value = mock_response + + results = [] + async for chunk in proc.generate_content_stream("sys", "user"): + results.append(chunk) + + # Content chunks + final chunk + assert len(results) == 3 + assert results[0].text == "Hello" + assert results[0].is_final is False + assert results[1].text == " world" + assert results[1].is_final is False + assert results[2].is_final is True + assert results[2].in_token == 10 + assert results[2].out_token == 5 + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_model_override(self, mock_requests): + proc = _make_processor(mock_requests, model="default-model") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.iter_lines.return_value = _sse_lines( + {"choices": [{"delta": {"content": "ok"}}]}, + {"usage": {"prompt_tokens": 5, "completion_tokens": 2}}, + ) + mock_requests.post.return_value = mock_response + + results = [] + async for chunk in proc.generate_content_stream( + "sys", "user", model="override-model" + ): + results.append(chunk) + + # All chunks should carry the overridden model name + for r in results: + assert r.model == "override-model" + + # Verify the 
request body used the overridden model + call_args = mock_requests.post.call_args + body = json.loads(call_args[1]["data"]) + assert body["model"] == "override-model" + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_temperature_override(self, mock_requests): + proc = _make_processor(mock_requests, temperature=0.0) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.iter_lines.return_value = _sse_lines( + {"choices": [{"delta": {"content": "ok"}}]}, + {"usage": {"prompt_tokens": 5, "completion_tokens": 2}}, + ) + mock_requests.post.return_value = mock_response + + results = [] + async for chunk in proc.generate_content_stream( + "sys", "user", temperature=0.9 + ): + results.append(chunk) + + call_args = mock_requests.post.call_args + body = json.loads(call_args[1]["data"]) + assert body["temperature"] == 0.9 + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_429_raises_too_many_requests(self, mock_requests): + proc = _make_processor(mock_requests) + + mock_response = MagicMock() + mock_response.status_code = 429 + mock_requests.post.return_value = mock_response + + with pytest.raises(TooManyRequests): + async for _ in proc.generate_content_stream("sys", "user"): + pass + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_http_error_raises_runtime(self, mock_requests): + proc = _make_processor(mock_requests) + + mock_response = MagicMock() + mock_response.status_code = 503 + mock_response.text = "Service Unavailable" + mock_requests.post.return_value = mock_response + + with pytest.raises(RuntimeError, match="HTTP 503"): + async for _ in proc.generate_content_stream("sys", "user"): + pass + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_includes_stream_options(self, mock_requests): + """Verify stream=True and stream_options in request body.""" + proc = 
_make_processor(mock_requests) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.iter_lines.return_value = _sse_lines( + {"usage": {"prompt_tokens": 0, "completion_tokens": 0}}, + ) + mock_requests.post.return_value = mock_response + + async for _ in proc.generate_content_stream("sys", "user"): + pass + + call_args = mock_requests.post.call_args + body = json.loads(call_args[1]["data"]) + assert body["stream"] is True + assert body["stream_options"]["include_usage"] is True + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_malformed_json_skipped(self, mock_requests): + """Malformed JSON chunks should be skipped, not crash the stream.""" + proc = _make_processor(mock_requests) + + mock_response = MagicMock() + mock_response.status_code = 200 + lines = [ + b"data: {not valid json}", + f'data: {json.dumps({"choices": [{"delta": {"content": "ok"}}]})}'.encode(), + f'data: {json.dumps({"usage": {"prompt_tokens": 1, "completion_tokens": 1}})}'.encode(), + b"data: [DONE]", + ] + mock_response.iter_lines.return_value = lines + mock_requests.post.return_value = mock_response + + results = [] + async for chunk in proc.generate_content_stream("sys", "user"): + results.append(chunk) + + # Should get the valid content chunk + final chunk + assert any(r.text == "ok" for r in results) + assert results[-1].is_final is True + + @patch('trustgraph.model.text_completion.azure.llm.requests') + async def test_streaming_supports_streaming_flag(self, mock_requests): + proc = _make_processor(mock_requests) + assert proc.supports_streaming() is True diff --git a/tests/unit/test_text_completion/test_rate_limit_contract.py b/tests/unit/test_text_completion/test_rate_limit_contract.py new file mode 100644 index 00000000..c9df217b --- /dev/null +++ b/tests/unit/test_text_completion/test_rate_limit_contract.py @@ -0,0 +1,140 @@ +""" +Cross-provider rate limit contract tests: verify that every LLM provider +that handles 
rate limits converts its provider-specific exception to +TooManyRequests consistently. + +Also tests the client-side error translation in the base client. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +from trustgraph.exceptions import TooManyRequests + + +class TestAzureServerless429(IsolatedAsyncioTestCase): + """Azure serverless endpoint: HTTP 429 → TooManyRequests""" + + @patch('trustgraph.model.text_completion.azure.llm.requests') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_http_429_raises_too_many_requests(self, _llm, _async, mock_requests): + from trustgraph.model.text_completion.azure.llm import Processor + proc = Processor( + endpoint="https://test.azure.com/v1/chat", + token="t", concurrency=1, taskgroup=AsyncMock(), id="t", + ) + mock_response = MagicMock() + mock_response.status_code = 429 + mock_requests.post.return_value = mock_response + + with pytest.raises(TooManyRequests): + await proc.generate_content("sys", "prompt") + + +class TestAzureOpenAIRateLimit(IsolatedAsyncioTestCase): + """Azure OpenAI: openai.RateLimitError → TooManyRequests""" + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cls): + from openai import RateLimitError + from trustgraph.model.text_completion.azure_openai.llm import Processor + mock_client = MagicMock() + mock_cls.return_value = mock_client + proc = Processor( + endpoint="https://test.openai.azure.com/", token="t", + model="gpt-4", concurrency=1, taskgroup=AsyncMock(), id="t", + ) + mock_client.chat.completions.create.side_effect = 
RateLimitError( + "rate limited", response=MagicMock(), body=None + ) + + with pytest.raises(TooManyRequests): + await proc.generate_content("sys", "prompt") + + +class TestOpenAIRateLimit(IsolatedAsyncioTestCase): + """OpenAI: openai.RateLimitError → TooManyRequests""" + + @patch('trustgraph.model.text_completion.openai.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cls): + from openai import RateLimitError + from trustgraph.model.text_completion.openai.llm import Processor + mock_client = MagicMock() + mock_cls.return_value = mock_client + proc = Processor( + api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t", + ) + mock_client.chat.completions.create.side_effect = RateLimitError( + "rate limited", response=MagicMock(), body=None + ) + + with pytest.raises(TooManyRequests): + await proc.generate_content("sys", "prompt") + + +class TestClaudeRateLimit(IsolatedAsyncioTestCase): + """Claude/Anthropic: anthropic.RateLimitError → TooManyRequests""" + + @patch('trustgraph.model.text_completion.claude.llm.anthropic') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_anthropic): + from trustgraph.model.text_completion.claude.llm import Processor + + mock_client = MagicMock() + mock_anthropic.Anthropic.return_value = mock_client + + proc = Processor( + api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t", + ) + + mock_anthropic.RateLimitError = type("RateLimitError", (Exception,), {}) + mock_client.messages.create.side_effect = mock_anthropic.RateLimitError( + "rate limited" + ) + + with pytest.raises(TooManyRequests): + await 
proc.generate_content("sys", "prompt") + + +class TestCohereRateLimit(IsolatedAsyncioTestCase): + """Cohere: cohere.TooManyRequestsError → TooManyRequests""" + + @patch('trustgraph.model.text_completion.cohere.llm.cohere') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) + @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) + async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cohere): + from trustgraph.model.text_completion.cohere.llm import Processor + + mock_client = MagicMock() + mock_cohere.Client.return_value = mock_client + + proc = Processor( + api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t", + ) + + mock_cohere.TooManyRequestsError = type( + "TooManyRequestsError", (Exception,), {} + ) + mock_client.chat.side_effect = mock_cohere.TooManyRequestsError( + "rate limited" + ) + + with pytest.raises(TooManyRequests): + await proc.generate_content("sys", "prompt") + + +class TestClientSideRateLimitTranslation: + """Wire-format contract: pins the error type string 'too-many-requests' that clients translate into TooManyRequests.""" + + def test_error_type_mapping(self): + """The wire format error type string is 'too-many-requests'.""" + from trustgraph.schema import Error + err = Error(type="too-many-requests", message="slow down") + assert err.type == "too-many-requests"