diff --git a/tests/unit/test_agent/test_tool_service_lifecycle.py b/tests/unit/test_agent/test_tool_service_lifecycle.py
new file mode 100644
index 00000000..65cdb542
--- /dev/null
+++ b/tests/unit/test_agent/test_tool_service_lifecycle.py
@@ -0,0 +1,624 @@
+"""
+Tests for tool service lifecycle, invoke contract, streaming responses,
+multi-tenancy, and error propagation.
+
+Tests the actual DynamicToolService, ToolService, and ToolServiceClient
+classes rather than plain dicts.
+"""
+
+import json
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from trustgraph.schema import (
+ ToolServiceRequest, ToolServiceResponse, Error,
+ ToolRequest, ToolResponse,
+)
+from trustgraph.exceptions import TooManyRequests
+
+
+# ---------------------------------------------------------------------------
+# DynamicToolService tests
+# ---------------------------------------------------------------------------
+
+class TestDynamicToolServiceInvokeContract:
+
+ @pytest.mark.asyncio
+ async def test_base_invoke_raises_not_implemented(self):
+ """Base class invoke() should raise NotImplementedError."""
+ from trustgraph.base.dynamic_tool_service import DynamicToolService
+
+ svc = DynamicToolService.__new__(DynamicToolService)
+
+ with pytest.raises(NotImplementedError):
+ await svc.invoke("user", {}, {})
+
+ @pytest.mark.asyncio
+ async def test_on_request_calls_invoke_with_parsed_args(self):
+ """on_request should JSON-parse config/arguments and pass to invoke."""
+ from trustgraph.base.dynamic_tool_service import DynamicToolService
+
+ svc = DynamicToolService.__new__(DynamicToolService)
+ svc.id = "test-svc"
+ svc.producer = AsyncMock()
+
+ calls = []
+
+ async def tracking_invoke(user, config, arguments):
+ calls.append({"user": user, "config": config, "arguments": arguments})
+ return "ok"
+
+ svc.invoke = tracking_invoke
+
+ # Ensure the class-level metric exists
+ if not hasattr(DynamicToolService, "tool_service_metric"):
+ DynamicToolService.tool_service_metric = MagicMock()
+
+ msg = MagicMock()
+ msg.value.return_value = ToolServiceRequest(
+ user="alice",
+ config='{"style": "pun"}',
+ arguments='{"topic": "cats"}',
+ )
+ msg.properties.return_value = {"id": "req-1"}
+
+ await svc.on_request(msg, MagicMock(), None)
+
+ assert len(calls) == 1
+ assert calls[0]["user"] == "alice"
+ assert calls[0]["config"] == {"style": "pun"}
+ assert calls[0]["arguments"] == {"topic": "cats"}
+
+ @pytest.mark.asyncio
+ async def test_on_request_empty_user_defaults_to_trustgraph(self):
+ """Empty user field should default to 'trustgraph'."""
+ from trustgraph.base.dynamic_tool_service import DynamicToolService
+
+ svc = DynamicToolService.__new__(DynamicToolService)
+ svc.id = "test-svc"
+ svc.producer = AsyncMock()
+
+ received_user = None
+
+ async def capture_invoke(user, config, arguments):
+ nonlocal received_user
+ received_user = user
+ return "ok"
+
+ svc.invoke = capture_invoke
+
+ if not hasattr(DynamicToolService, "tool_service_metric"):
+ DynamicToolService.tool_service_metric = MagicMock()
+
+ msg = MagicMock()
+ msg.value.return_value = ToolServiceRequest(user="", config="", arguments="")
+ msg.properties.return_value = {"id": "req-2"}
+
+ await svc.on_request(msg, MagicMock(), None)
+
+ assert received_user == "trustgraph"
+
+ @pytest.mark.asyncio
+ async def test_on_request_string_response_sent_directly(self):
+ """String return from invoke → response field is the string."""
+ from trustgraph.base.dynamic_tool_service import DynamicToolService
+
+ svc = DynamicToolService.__new__(DynamicToolService)
+ svc.id = "test-svc"
+ svc.producer = AsyncMock()
+
+ async def string_invoke(user, config, arguments):
+ return "hello world"
+
+ svc.invoke = string_invoke
+
+ if not hasattr(DynamicToolService, "tool_service_metric"):
+ DynamicToolService.tool_service_metric = MagicMock()
+
+ msg = MagicMock()
+ msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
+ msg.properties.return_value = {"id": "r1"}
+
+ await svc.on_request(msg, MagicMock(), None)
+
+ sent = svc.producer.send.call_args[0][0]
+ assert isinstance(sent, ToolServiceResponse)
+ assert sent.response == "hello world"
+ assert sent.end_of_stream is True
+ assert sent.error is None
+
+ @pytest.mark.asyncio
+ async def test_on_request_dict_response_json_encoded(self):
+ """Dict return from invoke → response field is JSON-encoded."""
+ from trustgraph.base.dynamic_tool_service import DynamicToolService
+
+ svc = DynamicToolService.__new__(DynamicToolService)
+ svc.id = "test-svc"
+ svc.producer = AsyncMock()
+
+ async def dict_invoke(user, config, arguments):
+ return {"result": 42}
+
+ svc.invoke = dict_invoke
+
+ if not hasattr(DynamicToolService, "tool_service_metric"):
+ DynamicToolService.tool_service_metric = MagicMock()
+
+ msg = MagicMock()
+ msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
+ msg.properties.return_value = {"id": "r2"}
+
+ await svc.on_request(msg, MagicMock(), None)
+
+ sent = svc.producer.send.call_args[0][0]
+ assert json.loads(sent.response) == {"result": 42}
+
+ @pytest.mark.asyncio
+ async def test_on_request_error_sends_error_response(self):
+ """Exception in invoke → error response sent."""
+ from trustgraph.base.dynamic_tool_service import DynamicToolService
+
+ svc = DynamicToolService.__new__(DynamicToolService)
+ svc.id = "test-svc"
+ svc.producer = AsyncMock()
+
+ async def failing_invoke(user, config, arguments):
+ raise ValueError("bad input")
+
+ svc.invoke = failing_invoke
+
+ msg = MagicMock()
+ msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
+ msg.properties.return_value = {"id": "r3"}
+
+ await svc.on_request(msg, MagicMock(), None)
+
+ sent = svc.producer.send.call_args[0][0]
+ assert sent.error is not None
+ assert sent.error.type == "tool-service-error"
+ assert "bad input" in sent.error.message
+ assert sent.response == ""
+
+ @pytest.mark.asyncio
+ async def test_on_request_too_many_requests_propagates(self):
+ """TooManyRequests should propagate (not caught as error response)."""
+ from trustgraph.base.dynamic_tool_service import DynamicToolService
+
+ svc = DynamicToolService.__new__(DynamicToolService)
+ svc.id = "test-svc"
+ svc.producer = AsyncMock()
+
+ async def rate_limited_invoke(user, config, arguments):
+ raise TooManyRequests("rate limited")
+
+ svc.invoke = rate_limited_invoke
+
+ msg = MagicMock()
+ msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
+ msg.properties.return_value = {"id": "r4"}
+
+ with pytest.raises(TooManyRequests):
+ await svc.on_request(msg, MagicMock(), None)
+
+ @pytest.mark.asyncio
+ async def test_on_request_preserves_message_id(self):
+ """Response should include the original message id in properties."""
+ from trustgraph.base.dynamic_tool_service import DynamicToolService
+
+ svc = DynamicToolService.__new__(DynamicToolService)
+ svc.id = "test-svc"
+ svc.producer = AsyncMock()
+
+ async def ok_invoke(user, config, arguments):
+ return "ok"
+
+ svc.invoke = ok_invoke
+
+ if not hasattr(DynamicToolService, "tool_service_metric"):
+ DynamicToolService.tool_service_metric = MagicMock()
+
+ msg = MagicMock()
+ msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
+ msg.properties.return_value = {"id": "unique-42"}
+
+ await svc.on_request(msg, MagicMock(), None)
+
+ props = svc.producer.send.call_args[1]["properties"]
+ assert props["id"] == "unique-42"
+
+
+# ---------------------------------------------------------------------------
+# ToolService (flow-based) tests
+# ---------------------------------------------------------------------------
+
+class TestToolServiceOnRequest:
+
+ @pytest.mark.asyncio
+ async def test_string_response_sent_as_text(self):
+ """String return from invoke_tool → ToolResponse.text is set."""
+ from trustgraph.base.tool_service import ToolService
+
+ svc = ToolService.__new__(ToolService)
+ svc.id = "test-tool"
+
+ async def mock_invoke(name, params):
+ return "tool result"
+
+ svc.invoke_tool = mock_invoke
+
+ if not hasattr(ToolService, "tool_invocation_metric"):
+ ToolService.tool_invocation_metric = MagicMock()
+
+ mock_response_pub = AsyncMock()
+ flow = MagicMock()
+ flow.name = "test-flow"
+
+ def flow_callable(name):
+ if name == "response":
+ return mock_response_pub
+ return MagicMock()
+
+ flow_callable.producer = {"response": mock_response_pub}
+ flow_callable.name = "test-flow"
+
+ msg = MagicMock()
+ msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}')
+ msg.properties.return_value = {"id": "t1"}
+
+ await svc.on_request(msg, MagicMock(), flow_callable)
+
+ sent = mock_response_pub.send.call_args[0][0]
+ assert isinstance(sent, ToolResponse)
+ assert sent.text == "tool result"
+ assert sent.object is None
+
+ @pytest.mark.asyncio
+ async def test_dict_response_sent_as_json_object(self):
+ """Dict return from invoke_tool → ToolResponse.object is JSON."""
+ from trustgraph.base.tool_service import ToolService
+
+ svc = ToolService.__new__(ToolService)
+ svc.id = "test-tool"
+
+ async def mock_invoke(name, params):
+ return {"data": [1, 2, 3]}
+
+ svc.invoke_tool = mock_invoke
+
+ if not hasattr(ToolService, "tool_invocation_metric"):
+ ToolService.tool_invocation_metric = MagicMock()
+
+ mock_response_pub = AsyncMock()
+ flow = MagicMock()
+
+ def flow_callable(name):
+ if name == "response":
+ return mock_response_pub
+ return MagicMock()
+
+ flow_callable.producer = {"response": mock_response_pub}
+ flow_callable.name = "test-flow"
+
+ msg = MagicMock()
+ msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
+ msg.properties.return_value = {"id": "t2"}
+
+ await svc.on_request(msg, MagicMock(), flow_callable)
+
+ sent = mock_response_pub.send.call_args[0][0]
+ assert sent.text is None
+ assert json.loads(sent.object) == {"data": [1, 2, 3]}
+
+ @pytest.mark.asyncio
+ async def test_error_sends_error_response(self):
+ """Exception in invoke_tool → error response via flow producer."""
+ from trustgraph.base.tool_service import ToolService
+
+ svc = ToolService.__new__(ToolService)
+ svc.id = "test-tool"
+
+ async def failing_invoke(name, params):
+ raise RuntimeError("tool broke")
+
+ svc.invoke_tool = failing_invoke
+
+ mock_response_pub = AsyncMock()
+ flow = MagicMock()
+
+ def flow_callable(name):
+ return MagicMock()
+
+ flow_callable.producer = {"response": mock_response_pub}
+ flow_callable.name = "test-flow"
+
+ msg = MagicMock()
+ msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
+ msg.properties.return_value = {"id": "t3"}
+
+ await svc.on_request(msg, MagicMock(), flow_callable)
+
+ sent = mock_response_pub.send.call_args[0][0]
+ assert sent.error is not None
+ assert sent.error.type == "tool-error"
+ assert "tool broke" in sent.error.message
+
+ @pytest.mark.asyncio
+ async def test_too_many_requests_propagates(self):
+ """TooManyRequests should propagate from ToolService.on_request."""
+ from trustgraph.base.tool_service import ToolService
+
+ svc = ToolService.__new__(ToolService)
+ svc.id = "test-tool"
+
+ async def rate_limited(name, params):
+ raise TooManyRequests("slow down")
+
+ svc.invoke_tool = rate_limited
+
+ msg = MagicMock()
+ msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
+ msg.properties.return_value = {"id": "t4"}
+
+ flow = MagicMock()
+ flow.producer = {"response": AsyncMock()}
+ flow.name = "test-flow"
+
+ with pytest.raises(TooManyRequests):
+ await svc.on_request(msg, MagicMock(), flow)
+
+ @pytest.mark.asyncio
+ async def test_parameters_json_parsed(self):
+ """Parameters should be JSON-parsed before passing to invoke_tool."""
+ from trustgraph.base.tool_service import ToolService
+
+ svc = ToolService.__new__(ToolService)
+ svc.id = "test-tool"
+
+ received = {}
+
+ async def capture_invoke(name, params):
+ received["name"] = name
+ received["params"] = params
+ return "ok"
+
+ svc.invoke_tool = capture_invoke
+
+ if not hasattr(ToolService, "tool_invocation_metric"):
+ ToolService.tool_invocation_metric = MagicMock()
+
+ mock_pub = AsyncMock()
+ flow = lambda name: mock_pub
+ flow.producer = {"response": mock_pub}
+ flow.name = "f"
+
+ msg = MagicMock()
+ msg.value.return_value = ToolRequest(
+ name="search",
+ parameters='{"query": "test", "limit": 10}',
+ )
+ msg.properties.return_value = {"id": "t5"}
+
+ await svc.on_request(msg, MagicMock(), flow)
+
+ assert received["name"] == "search"
+ assert received["params"] == {"query": "test", "limit": 10}
+
+
+# ---------------------------------------------------------------------------
+# ToolServiceClient tests
+# ---------------------------------------------------------------------------
+
+class TestToolServiceClientCall:
+
+ @pytest.mark.asyncio
+ async def test_call_sends_request_and_returns_response(self):
+ """call() should send ToolServiceRequest and return response string."""
+ from trustgraph.base.tool_service_client import ToolServiceClient
+
+ client = ToolServiceClient.__new__(ToolServiceClient)
+ client.request = AsyncMock(return_value=ToolServiceResponse(
+ error=None, response="joke result", end_of_stream=True,
+ ))
+
+ result = await client.call(
+ user="alice",
+ config={"style": "pun"},
+ arguments={"topic": "cats"},
+ )
+
+ assert result == "joke result"
+
+ req = client.request.call_args[0][0]
+ assert isinstance(req, ToolServiceRequest)
+ assert req.user == "alice"
+ assert json.loads(req.config) == {"style": "pun"}
+ assert json.loads(req.arguments) == {"topic": "cats"}
+
+ @pytest.mark.asyncio
+ async def test_call_raises_on_error(self):
+ """call() should raise RuntimeError when response has error."""
+ from trustgraph.base.tool_service_client import ToolServiceClient
+
+ client = ToolServiceClient.__new__(ToolServiceClient)
+ client.request = AsyncMock(return_value=ToolServiceResponse(
+ error=Error(type="tool-service-error", message="service down"),
+ response="",
+ ))
+
+ with pytest.raises(RuntimeError, match="service down"):
+ await client.call(user="u", config={}, arguments={})
+
+ @pytest.mark.asyncio
+ async def test_call_empty_config_sends_empty_json(self):
+ """Empty config/arguments should be sent as '{}'."""
+ from trustgraph.base.tool_service_client import ToolServiceClient
+
+ client = ToolServiceClient.__new__(ToolServiceClient)
+ client.request = AsyncMock(return_value=ToolServiceResponse(
+ error=None, response="ok",
+ ))
+
+ await client.call(user="u", config=None, arguments=None)
+
+ req = client.request.call_args[0][0]
+ assert req.config == "{}"
+ assert req.arguments == "{}"
+
+ @pytest.mark.asyncio
+ async def test_call_passes_timeout(self):
+ """call() should forward timeout to underlying request."""
+ from trustgraph.base.tool_service_client import ToolServiceClient
+
+ client = ToolServiceClient.__new__(ToolServiceClient)
+ client.request = AsyncMock(return_value=ToolServiceResponse(
+ error=None, response="ok",
+ ))
+
+ await client.call(user="u", config={}, arguments={}, timeout=30)
+
+ _, kwargs = client.request.call_args
+ assert kwargs["timeout"] == 30
+
+
+class TestToolServiceClientStreaming:
+
+ @pytest.mark.asyncio
+ async def test_call_streaming_collects_chunks(self):
+ """call_streaming should accumulate chunks and return full result."""
+ from trustgraph.base.tool_service_client import ToolServiceClient
+
+ client = ToolServiceClient.__new__(ToolServiceClient)
+
+ # Simulate streaming: request() calls recipient with each chunk
+ chunks = [
+ ToolServiceResponse(error=None, response="chunk1", end_of_stream=False),
+ ToolServiceResponse(error=None, response="chunk2", end_of_stream=True),
+ ]
+
+ async def mock_request(req, timeout=600, recipient=None):
+ for chunk in chunks:
+ done = await recipient(chunk)
+ if done:
+ break
+
+ client.request = mock_request
+
+ received = []
+
+ async def callback(text):
+ received.append(text)
+
+ result = await client.call_streaming(
+ user="u", config={}, arguments={}, callback=callback,
+ )
+
+ assert result == "chunk1chunk2"
+ assert received == ["chunk1", "chunk2"]
+
+ @pytest.mark.asyncio
+ async def test_call_streaming_raises_on_error(self):
+ """call_streaming should raise RuntimeError on error chunk."""
+ from trustgraph.base.tool_service_client import ToolServiceClient
+
+ client = ToolServiceClient.__new__(ToolServiceClient)
+
+ async def mock_request(req, timeout=600, recipient=None):
+ error_resp = ToolServiceResponse(
+ error=Error(type="tool-service-error", message="stream failed"),
+ response="",
+ end_of_stream=True,
+ )
+ await recipient(error_resp)
+
+ client.request = mock_request
+
+ with pytest.raises(RuntimeError, match="stream failed"):
+ await client.call_streaming(
+ user="u", config={}, arguments={},
+ callback=AsyncMock(),
+ )
+
+ @pytest.mark.asyncio
+ async def test_call_streaming_skips_empty_response(self):
+ """Empty response chunks should not be added to result."""
+ from trustgraph.base.tool_service_client import ToolServiceClient
+
+ client = ToolServiceClient.__new__(ToolServiceClient)
+
+ chunks = [
+ ToolServiceResponse(error=None, response="", end_of_stream=False),
+ ToolServiceResponse(error=None, response="data", end_of_stream=True),
+ ]
+
+ async def mock_request(req, timeout=600, recipient=None):
+ for chunk in chunks:
+ done = await recipient(chunk)
+ if done:
+ break
+
+ client.request = mock_request
+
+ received = []
+
+ async def callback(text):
+ received.append(text)
+
+ result = await client.call_streaming(
+ user="u", config={}, arguments={}, callback=callback,
+ )
+
+ # Empty response is falsy, so callback shouldn't be called for it
+ assert result == "data"
+ assert received == ["data"]
+
+
+# ---------------------------------------------------------------------------
+# Multi-tenancy
+# ---------------------------------------------------------------------------
+
+class TestMultiTenancy:
+
+ @pytest.mark.asyncio
+ async def test_user_propagated_to_invoke(self):
+ """User from request should reach the invoke method."""
+ from trustgraph.base.dynamic_tool_service import DynamicToolService
+
+ svc = DynamicToolService.__new__(DynamicToolService)
+ svc.id = "test"
+ svc.producer = AsyncMock()
+
+ users_seen = []
+
+ async def tracking(user, config, arguments):
+ users_seen.append(user)
+ return "ok"
+
+ svc.invoke = tracking
+
+ if not hasattr(DynamicToolService, "tool_service_metric"):
+ DynamicToolService.tool_service_metric = MagicMock()
+
+ for u in ["tenant-a", "tenant-b", "tenant-c"]:
+ msg = MagicMock()
+ msg.value.return_value = ToolServiceRequest(
+ user=u, config="{}", arguments="{}",
+ )
+ msg.properties.return_value = {"id": f"req-{u}"}
+ await svc.on_request(msg, MagicMock(), None)
+
+ assert users_seen == ["tenant-a", "tenant-b", "tenant-c"]
+
+ @pytest.mark.asyncio
+ async def test_client_sends_user_in_request(self):
+ """ToolServiceClient.call should include user in request."""
+ from trustgraph.base.tool_service_client import ToolServiceClient
+
+ client = ToolServiceClient.__new__(ToolServiceClient)
+ client.request = AsyncMock(return_value=ToolServiceResponse(
+ error=None, response="ok",
+ ))
+
+ await client.call(user="isolated-tenant", config={}, arguments={})
+
+ req = client.request.call_args[0][0]
+ assert req.user == "isolated-tenant"
diff --git a/tests/unit/test_concurrency/__init__.py b/tests/unit/test_concurrency/__init__.py
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/tests/unit/test_concurrency/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/unit/test_concurrency/test_consumer_concurrency.py b/tests/unit/test_concurrency/test_consumer_concurrency.py
new file mode 100644
index 00000000..32a6559b
--- /dev/null
+++ b/tests/unit/test_concurrency/test_consumer_concurrency.py
@@ -0,0 +1,286 @@
+"""
+Tests for Consumer concurrency: TaskGroup-based concurrent message processing,
+rate-limit retry with backpressure, and message acknowledgement.
+"""
+
+import asyncio
+import time
+
+import pytest
+from unittest.mock import MagicMock, AsyncMock, patch
+
+from trustgraph.base.consumer import Consumer
+from trustgraph.exceptions import TooManyRequests
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _make_consumer(
+ concurrency=1,
+ handler=None,
+ rate_limit_retry_time=0.01,
+ rate_limit_timeout=1,
+):
+ """Create a Consumer with mocked infrastructure."""
+ taskgroup = MagicMock()
+ flow = MagicMock()
+ backend = MagicMock()
+ schema = MagicMock()
+ handler = handler or AsyncMock()
+
+ consumer = Consumer(
+ taskgroup=taskgroup,
+ flow=flow,
+ backend=backend,
+ topic="test-topic",
+ subscriber="test-sub",
+ schema=schema,
+ handler=handler,
+ rate_limit_retry_time=rate_limit_retry_time,
+ rate_limit_timeout=rate_limit_timeout,
+ concurrency=concurrency,
+ )
+
+ return consumer
+
+
+def _make_msg():
+ """Create a mock Pulsar message."""
+ return MagicMock()
+
+
+# ---------------------------------------------------------------------------
+# Concurrency configuration tests
+# ---------------------------------------------------------------------------
+
+class TestConcurrencyConfiguration:
+
+ def test_default_concurrency_is_1(self):
+ consumer = _make_consumer()
+ assert consumer.concurrency == 1
+
+ def test_custom_concurrency(self):
+ consumer = _make_consumer(concurrency=10)
+ assert consumer.concurrency == 10
+
+ def test_concurrency_stored(self):
+ for n in [1, 5, 20, 100]:
+ consumer = _make_consumer(concurrency=n)
+ assert consumer.concurrency == n
+
+
+class TestTaskGroupConcurrency:
+
+ @pytest.mark.asyncio
+ async def test_creates_n_concurrent_tasks(self):
+ """consumer_run should create exactly N concurrent consume_from_queue tasks."""
+ concurrency = 5
+ consumer = _make_consumer(concurrency=concurrency)
+
+ # Track how many consume_from_queue calls are made
+ call_count = 0
+ original_running = True
+
+ async def mock_consume():
+ nonlocal call_count
+ call_count += 1
+ # Wait a bit to let all tasks start, then signal stop
+ await asyncio.sleep(0.05)
+ consumer.running = False
+
+ consumer.consume_from_queue = mock_consume
+
+ # Mock the backend.create_consumer
+ consumer.backend.create_consumer = MagicMock(return_value=MagicMock())
+
+ # Run consumer_run - it will create TaskGroup with N tasks
+ consumer.running = True
+ await consumer.consumer_run()
+
+ assert call_count == concurrency
+
+ @pytest.mark.asyncio
+ async def test_single_concurrency_creates_one_task(self):
+ """With concurrency=1, only one consume_from_queue task is created."""
+ consumer = _make_consumer(concurrency=1)
+ call_count = 0
+
+ async def mock_consume():
+ nonlocal call_count
+ call_count += 1
+ await asyncio.sleep(0.01)
+ consumer.running = False
+
+ consumer.consume_from_queue = mock_consume
+ consumer.backend.create_consumer = MagicMock(return_value=MagicMock())
+
+ consumer.running = True
+ await consumer.consumer_run()
+
+ assert call_count == 1
+
+
+# ---------------------------------------------------------------------------
+# Rate-limit retry tests
+# ---------------------------------------------------------------------------
+
+class TestRateLimitRetry:
+
+ @pytest.mark.asyncio
+ async def test_rate_limit_retries_then_succeeds(self):
+ """TooManyRequests should cause retry, then succeed on next attempt."""
+ call_count = 0
+
+ async def handler_with_retry(msg, consumer_ref, flow):
+ nonlocal call_count
+ call_count += 1
+ if call_count == 1:
+ raise TooManyRequests("rate limited")
+ # Second call succeeds
+
+ consumer = _make_consumer(
+ handler=handler_with_retry,
+ rate_limit_retry_time=0.01,
+ )
+ mock_msg = _make_msg()
+ consumer.consumer = MagicMock()
+
+ await consumer.handle_one_from_queue(mock_msg)
+
+ assert call_count == 2
+ consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
+
+ @pytest.mark.asyncio
+ async def test_rate_limit_timeout_negative_acks(self):
+ """If rate limit retries exhaust the timeout, message is negative-acked."""
+ async def always_rate_limited(msg, consumer_ref, flow):
+ raise TooManyRequests("rate limited")
+
+ consumer = _make_consumer(
+ handler=always_rate_limited,
+ rate_limit_retry_time=0.01,
+ rate_limit_timeout=0.05,
+ )
+ mock_msg = _make_msg()
+ consumer.consumer = MagicMock()
+
+ await consumer.handle_one_from_queue(mock_msg)
+
+ consumer.consumer.negative_acknowledge.assert_called_with(mock_msg)
+ consumer.consumer.acknowledge.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_non_rate_limit_error_negative_acks_immediately(self):
+ """Non-TooManyRequests errors should negative-ack immediately (no retry)."""
+ call_count = 0
+
+ async def failing_handler(msg, consumer_ref, flow):
+ nonlocal call_count
+ call_count += 1
+ raise ValueError("bad data")
+
+ consumer = _make_consumer(handler=failing_handler)
+ mock_msg = _make_msg()
+ consumer.consumer = MagicMock()
+
+ await consumer.handle_one_from_queue(mock_msg)
+
+ assert call_count == 1
+ consumer.consumer.negative_acknowledge.assert_called_once_with(mock_msg)
+
+ @pytest.mark.asyncio
+ async def test_successful_message_acknowledged(self):
+ """Successfully processed messages are acknowledged."""
+ consumer = _make_consumer(handler=AsyncMock())
+ mock_msg = _make_msg()
+ consumer.consumer = MagicMock()
+
+ await consumer.handle_one_from_queue(mock_msg)
+
+ consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
+
+
+# ---------------------------------------------------------------------------
+# Metrics integration
+# ---------------------------------------------------------------------------
+
+class TestMetricsIntegration:
+
+ @pytest.mark.asyncio
+ async def test_success_metric_on_success(self):
+ consumer = _make_consumer(handler=AsyncMock())
+ mock_msg = _make_msg()
+ consumer.consumer = MagicMock()
+
+ mock_metrics = MagicMock()
+ mock_metrics.record_time.return_value.__enter__ = MagicMock()
+ mock_metrics.record_time.return_value.__exit__ = MagicMock()
+ consumer.metrics = mock_metrics
+
+ await consumer.handle_one_from_queue(mock_msg)
+
+ mock_metrics.process.assert_called_once_with("success")
+
+ @pytest.mark.asyncio
+ async def test_error_metric_on_failure(self):
+ async def failing(msg, c, f):
+ raise ValueError("fail")
+
+ consumer = _make_consumer(handler=failing)
+ mock_msg = _make_msg()
+ consumer.consumer = MagicMock()
+
+ mock_metrics = MagicMock()
+ consumer.metrics = mock_metrics
+
+ await consumer.handle_one_from_queue(mock_msg)
+
+ mock_metrics.process.assert_called_once_with("error")
+
+ @pytest.mark.asyncio
+ async def test_rate_limit_metric_on_too_many_requests(self):
+ call_count = 0
+
+ async def handler(msg, c, f):
+ nonlocal call_count
+ call_count += 1
+ if call_count == 1:
+ raise TooManyRequests("limited")
+
+ consumer = _make_consumer(
+ handler=handler,
+ rate_limit_retry_time=0.01,
+ )
+ mock_msg = _make_msg()
+ consumer.consumer = MagicMock()
+
+ mock_metrics = MagicMock()
+ mock_metrics.record_time.return_value.__enter__ = MagicMock()
+ mock_metrics.record_time.return_value.__exit__ = MagicMock(return_value=False)
+ consumer.metrics = mock_metrics
+
+ await consumer.handle_one_from_queue(mock_msg)
+
+ mock_metrics.rate_limit.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# Stop / running flag
+# ---------------------------------------------------------------------------
+
+class TestStopBehaviour:
+
+ @pytest.mark.asyncio
+ async def test_stop_sets_running_false(self):
+ consumer = _make_consumer()
+ consumer.running = True
+
+ await consumer.stop()
+
+ assert consumer.running is False
+
+ def test_initial_running_state(self):
+ consumer = _make_consumer()
+ assert consumer.running is True
diff --git a/tests/unit/test_concurrency/test_dispatcher_semaphore.py b/tests/unit/test_concurrency/test_dispatcher_semaphore.py
new file mode 100644
index 00000000..6a1ae8ab
--- /dev/null
+++ b/tests/unit/test_concurrency/test_dispatcher_semaphore.py
@@ -0,0 +1,136 @@
+"""
+Tests for MessageDispatcher semaphore-based concurrency enforcement.
+
+Verifies that the dispatcher limits concurrent message processing to
+max_workers via asyncio.Semaphore.
+"""
+
+import asyncio
+
+import pytest
+from unittest.mock import MagicMock, AsyncMock, patch
+
+from trustgraph.rev_gateway.dispatcher import MessageDispatcher
+
+
+class TestSemaphoreEnforcement:
+
+ @pytest.mark.asyncio
+ async def test_semaphore_limits_concurrent_processing(self):
+ """Only max_workers messages should be processed concurrently."""
+ max_workers = 2
+ dispatcher = MessageDispatcher(max_workers=max_workers)
+
+ concurrent_count = 0
+ max_concurrent = 0
+ processing_event = asyncio.Event()
+
+ async def slow_process(message):
+ nonlocal concurrent_count, max_concurrent
+ concurrent_count += 1
+ max_concurrent = max(max_concurrent, concurrent_count)
+ await asyncio.sleep(0.05)
+ concurrent_count -= 1
+ return {"id": message.get("id"), "response": {"ok": True}}
+
+ dispatcher._process_message = slow_process
+
+ # Launch more tasks than max_workers
+ messages = [
+ {"id": f"msg-{i}", "service": "test", "request": {}}
+ for i in range(5)
+ ]
+
+ tasks = [
+ asyncio.create_task(dispatcher.handle_message(m))
+ for m in messages
+ ]
+
+ await asyncio.gather(*tasks)
+
+ # At no point should more than max_workers have been active
+ assert max_concurrent <= max_workers
+
+ @pytest.mark.asyncio
+ async def test_semaphore_value_matches_max_workers(self):
+ for n in [1, 5, 20]:
+ dispatcher = MessageDispatcher(max_workers=n)
+ assert dispatcher.semaphore._value == n
+
+ @pytest.mark.asyncio
+ async def test_active_tasks_tracked(self):
+ """Active tasks should be added/removed during processing."""
+ dispatcher = MessageDispatcher(max_workers=5)
+
+ task_was_tracked = False
+
+ original_process = dispatcher._process_message
+
+ async def tracking_process(message):
+ nonlocal task_was_tracked
+ # During processing, our task should be in active_tasks
+ if len(dispatcher.active_tasks) > 0:
+ task_was_tracked = True
+ return {"id": message.get("id"), "response": {"ok": True}}
+
+ dispatcher._process_message = tracking_process
+
+ await dispatcher.handle_message(
+ {"id": "test", "service": "test", "request": {}}
+ )
+
+ assert task_was_tracked
+ # After completion, task should be discarded
+ assert len(dispatcher.active_tasks) == 0
+
+ @pytest.mark.asyncio
+ async def test_semaphore_released_on_error(self):
+ """Semaphore should be released even if processing raises."""
+ dispatcher = MessageDispatcher(max_workers=2)
+
+ async def failing_process(message):
+ raise RuntimeError("process failed")
+
+ dispatcher._process_message = failing_process
+
+ # Should not deadlock — semaphore must be released on error
+ with pytest.raises(RuntimeError):
+ await dispatcher.handle_message(
+ {"id": "test", "service": "test", "request": {}}
+ )
+
+ # Semaphore should be back at max
+ assert dispatcher.semaphore._value == 2
+
+ @pytest.mark.asyncio
+ async def test_single_worker_serializes_processing(self):
+ """With max_workers=1, messages are processed one at a time."""
+ dispatcher = MessageDispatcher(max_workers=1)
+
+ order = []
+
+ async def ordered_process(message):
+ msg_id = message["id"]
+ order.append(f"start-{msg_id}")
+ await asyncio.sleep(0.02)
+ order.append(f"end-{msg_id}")
+ return {"id": msg_id, "response": {"ok": True}}
+
+ dispatcher._process_message = ordered_process
+
+ messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)]
+ tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages]
+ await asyncio.gather(*tasks)
+
+ # With semaphore=1, each message should complete before next starts
+ # Check that no two "start" entries appear without an intervening "end"
+ active = 0
+ max_active = 0
+ for event in order:
+ if event.startswith("start"):
+ active += 1
+ max_active = max(max_active, active)
+ elif event.startswith("end"):
+ active -= 1
+
+ assert max_active == 1
diff --git a/tests/unit/test_concurrency/test_graph_rag_concurrency.py b/tests/unit/test_concurrency/test_graph_rag_concurrency.py
new file mode 100644
index 00000000..8287427b
--- /dev/null
+++ b/tests/unit/test_concurrency/test_graph_rag_concurrency.py
@@ -0,0 +1,268 @@
+"""
+Tests for Graph RAG concurrent query execution.
+
+Covers: execute_batch_triple_queries concurrent task spawning,
+exception handling in gather, and result aggregation.
+"""
+
+import asyncio
+
+import pytest
+from unittest.mock import MagicMock, AsyncMock
+
+from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _make_query(
+    triples_client=None,
+    entity_limit=50,
+    triple_limit=30,
+    max_subgraph_size=1000,
+    max_path_length=2,
+):
+    """Create a Query object with mocked rag dependencies."""
+    rag = MagicMock()
+    rag.triples_client = triples_client or AsyncMock()
+    rag.label_cache = LRUCacheWithTTL()  # real cache (default size/TTL), not a mock
+
+    query = Query(
+        rag=rag,
+        user="test-user",
+        collection="test-collection",
+        verbose=False,
+        entity_limit=entity_limit,
+        triple_limit=triple_limit,
+        max_subgraph_size=max_subgraph_size,
+        max_path_length=max_path_length,
+    )
+    return query
+
+
+def _make_triple(s, p, o):
+    """Create a simple mock triple with s/p/o attributes only."""
+    t = MagicMock()
+    t.s = s
+    t.p = p
+    t.o = o
+    return t
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+class TestBatchTripleQueries:
+
+    @pytest.mark.asyncio
+    async def test_three_queries_per_entity(self):
+        """Each entity should generate 3 concurrent queries (s, p, o positions)."""
+        client = AsyncMock()
+        client.query_stream = AsyncMock(return_value=[])
+        query = _make_query(triples_client=client)
+
+        entities = ["entity-1"]
+        await query.execute_batch_triple_queries(entities, limit_per_entity=10)
+
+        assert client.query_stream.call_count == 3
+
+    @pytest.mark.asyncio
+    async def test_multiple_entities_multiply_queries(self):
+        """N entities should produce N*3 concurrent queries."""
+        client = AsyncMock()
+        client.query_stream = AsyncMock(return_value=[])
+        query = _make_query(triples_client=client)
+
+        entities = ["e1", "e2", "e3"]
+        await query.execute_batch_triple_queries(entities, limit_per_entity=10)
+
+        assert client.query_stream.call_count == 9  # 3 * 3
+
+    @pytest.mark.asyncio
+    async def test_queries_executed_concurrently(self):
+        """All queries should run concurrently via asyncio.gather."""
+        concurrent_count = 0
+        max_concurrent = 0
+
+        async def tracking_query(**kwargs):
+            nonlocal concurrent_count, max_concurrent
+            concurrent_count += 1
+            max_concurrent = max(max_concurrent, concurrent_count)
+            await asyncio.sleep(0.02)  # hold the slot so overlapping queries are observable
+            concurrent_count -= 1
+            return []
+
+        client = AsyncMock()
+        client.query_stream = tracking_query
+        query = _make_query(triples_client=client)
+
+        entities = ["e1", "e2", "e3"]
+        await query.execute_batch_triple_queries(entities, limit_per_entity=5)
+
+        # All 9 queries should have run concurrently
+        # NOTE(review): assumes all tasks start before the first sleep finishes — confirm not flaky under load
+        assert max_concurrent == 9
+
+    @pytest.mark.asyncio
+    async def test_results_aggregated(self):
+        """Results from all queries should be combined into a single list."""
+        triple_a = _make_triple("a", "p", "b")
+        triple_b = _make_triple("c", "p", "d")
+
+        call_count = 0
+
+        async def alternating_results(**kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count % 2 == 0:
+                return [triple_a]
+            return [triple_b]
+
+        client = AsyncMock()
+        client.query_stream = alternating_results
+        query = _make_query(triples_client=client)
+
+        result = await query.execute_batch_triple_queries(
+            ["e1"], limit_per_entity=10
+        )
+
+        # 3 queries, alternating results; total count is order-independent
+        assert len(result) == 3
+
+    @pytest.mark.asyncio
+    async def test_exception_in_one_query_does_not_block_others(self):
+        """If one query raises, other results are still collected."""
+        good_triple = _make_triple("a", "p", "b")
+        call_count = 0
+
+        async def mixed_results(**kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 2:
+                raise RuntimeError("query failed")
+            return [good_triple]
+
+        client = AsyncMock()
+        client.query_stream = mixed_results
+        query = _make_query(triples_client=client)
+
+        result = await query.execute_batch_triple_queries(
+            ["e1"], limit_per_entity=10
+        )
+
+        # 3 queries: 2 succeed, 1 fails → 2 triples
+        assert len(result) == 2
+
+    @pytest.mark.asyncio
+    async def test_none_results_filtered(self):
+        """None results from queries should be filtered out."""
+        call_count = 0
+
+        async def sometimes_none(**kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                return None
+            return [_make_triple("a", "p", "b")]
+
+        client = AsyncMock()
+        client.query_stream = sometimes_none
+        query = _make_query(triples_client=client)
+
+        result = await query.execute_batch_triple_queries(
+            ["e1"], limit_per_entity=10
+        )
+
+        # 3 queries: 1 returns None, 2 return triples
+        assert len(result) == 2
+
+    @pytest.mark.asyncio
+    async def test_empty_entities_no_queries(self):
+        """Empty entity list should produce no queries."""
+        client = AsyncMock()
+        client.query_stream = AsyncMock(return_value=[])
+        query = _make_query(triples_client=client)
+
+        result = await query.execute_batch_triple_queries([], limit_per_entity=10)
+
+        assert result == []
+        client.query_stream.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_query_params_correct(self):
+        """Each query should use correct s/p/o positions and params."""
+        client = AsyncMock()
+        client.query_stream = AsyncMock(return_value=[])
+        query = _make_query(triples_client=client)
+
+        entities = ["ent-1"]
+        await query.execute_batch_triple_queries(entities, limit_per_entity=15)
+
+        calls = client.query_stream.call_args_list
+        assert len(calls) == 3
+
+        # First call: s=entity, p=None, o=None
+        assert calls[0].kwargs["s"] == "ent-1"
+        assert calls[0].kwargs["p"] is None
+        assert calls[0].kwargs["o"] is None
+        assert calls[0].kwargs["limit"] == 15
+        assert calls[0].kwargs["user"] == "test-user"
+        assert calls[0].kwargs["collection"] == "test-collection"
+        assert calls[0].kwargs["batch_size"] == 20  # pins the implementation's hard-coded batch size
+
+        # Second call: s=None, p=entity, o=None
+        assert calls[1].kwargs["s"] is None
+        assert calls[1].kwargs["p"] == "ent-1"
+        assert calls[1].kwargs["o"] is None
+
+        # Third call: s=None, p=None, o=entity
+        assert calls[2].kwargs["s"] is None
+        assert calls[2].kwargs["p"] is None
+        assert calls[2].kwargs["o"] == "ent-1"
+
+class TestLRUCacheWithTTL:
+
+    def test_put_and_get(self):
+        cache = LRUCacheWithTTL(max_size=10, ttl=60)
+        cache.put("key1", "value1")
+        assert cache.get("key1") == "value1"
+
+    def test_get_missing_returns_none(self):
+        cache = LRUCacheWithTTL()
+        assert cache.get("nonexistent") is None
+
+    def test_max_size_eviction(self):
+        cache = LRUCacheWithTTL(max_size=2, ttl=60)
+        cache.put("a", 1)
+        cache.put("b", 2)
+        cache.put("c", 3)  # Should evict "a"
+        assert cache.get("a") is None
+        assert cache.get("b") == 2
+        assert cache.get("c") == 3
+
+    def test_lru_order(self):
+        cache = LRUCacheWithTTL(max_size=2, ttl=60)
+        cache.put("a", 1)
+        cache.put("b", 2)
+        cache.get("a")  # Access "a" — now "b" is LRU
+        cache.put("c", 3)  # Should evict "b"
+        assert cache.get("a") == 1
+        assert cache.get("b") is None
+        assert cache.get("c") == 3
+
+    def test_ttl_expiration(self):
+        cache = LRUCacheWithTTL(max_size=10, ttl=0)  # TTL=0 means instant expiry
+        cache.put("key", "value")
+        # With TTL=0, any time check > 0 means expired
+        import time
+        time.sleep(0.01)  # ensure the clock advances past the 0s TTL
+        assert cache.get("key") is None
+
+    def test_update_existing_key(self):
+        cache = LRUCacheWithTTL(max_size=10, ttl=60)
+        cache.put("key", "v1")
+        cache.put("key", "v2")
+        assert cache.get("key") == "v2"
diff --git a/tests/unit/test_direct/test_entity_centric_write_amplification.py b/tests/unit/test_direct/test_entity_centric_write_amplification.py
new file mode 100644
index 00000000..1c9ad1a8
--- /dev/null
+++ b/tests/unit/test_direct/test_entity_centric_write_amplification.py
@@ -0,0 +1,441 @@
+"""
+Tests for entity-centric KG write amplification, delete collection batching,
+in-partition filtering, and term type metadata round-trips.
+
+Complements test_entity_centric_kg.py with deeper verification of the
+2-table schema mechanics.
+"""
+
+import pytest
+from unittest.mock import MagicMock, patch, call
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+@pytest.fixture
+def mock_cassandra():
+ """Provide mocked Cassandra cluster, session, and BatchStatement."""
+ with patch('trustgraph.direct.cassandra_kg.Cluster') as mock_cls, \
+ patch('trustgraph.direct.cassandra_kg.BatchStatement') as mock_batch_cls:
+
+ mock_cluster = MagicMock()
+ mock_session = MagicMock()
+ mock_cluster.connect.return_value = mock_session
+ mock_cls.return_value = mock_cluster
+
+ # Track batch.add calls per batch instance
+ batches = []
+
+ def make_batch():
+ batch = MagicMock()
+ batch._adds = []
+ original_add = batch.add
+
+ def tracking_add(stmt, params):
+ batch._adds.append((stmt, params))
+
+ batch.add = tracking_add
+ batches.append(batch)
+ return batch
+
+ mock_batch_cls.side_effect = make_batch
+
+ yield {
+ "cluster_cls": mock_cls,
+ "cluster": mock_cluster,
+ "session": mock_session,
+ "batch_cls": mock_batch_cls,
+ "batches": batches,
+ }
+
+
+@pytest.fixture
+def entity_kg(mock_cassandra):
+    """Create an EntityCentricKnowledgeGraph with mocked Cassandra.
+
+    Returns a (kg, mock_cassandra-context) tuple so tests can inspect
+    the recorded batches and the mock session.
+    """
+    from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph
+    kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_ks')
+    return kg, mock_cassandra
+
+
+# ---------------------------------------------------------------------------
+# Write amplification: row count verification
+# ---------------------------------------------------------------------------
+
+class TestWriteAmplification:
+
+    def test_uri_object_produces_4_entity_rows_plus_collection(self, entity_kg):
+        """URI object → S + P + O + G-if-non-default entity rows + 1 collection row."""
+        kg, ctx = entity_kg
+        ctx["batches"].clear()
+
+        kg.insert(
+            collection='col',
+            s='http://ex.org/Alice',
+            p='http://ex.org/knows',
+            o='http://ex.org/Bob',
+            g='http://ex.org/g1',
+            otype='u',
+        )
+
+        # Should be exactly one batch
+        assert len(ctx["batches"]) == 1
+        batch = ctx["batches"][0]
+
+        # 4 entity rows (S, P, O, G) + 1 collection row = 5
+        assert len(batch._adds) == 5
+
+        # Check roles assigned; entity rows are the 10-param adds, role at index 2
+        roles = [params[2] for _, params in batch._adds if len(params) == 10]
+        assert 'S' in roles
+        assert 'P' in roles
+        assert 'O' in roles
+        assert 'G' in roles
+
+    def test_literal_object_produces_3_entity_rows(self, entity_kg):
+        """Literal object → S + P entity rows (no O row) + collection row."""
+        kg, ctx = entity_kg
+        ctx["batches"].clear()
+
+        kg.insert(
+            collection='col',
+            s='http://ex.org/Alice',
+            p='http://ex.org/name',
+            o='Alice Smith',
+            g=None,  # default graph
+            otype='l',
+        )
+
+        batch = ctx["batches"][0]
+
+        # S + P entity rows + 1 collection = 3 (no O row for literal, no G for default)
+        assert len(batch._adds) == 3
+
+        roles = [params[2] for _, params in batch._adds if len(params) == 10]
+        assert 'S' in roles
+        assert 'P' in roles
+        assert 'O' not in roles
+        assert 'G' not in roles
+
+    def test_triple_otype_gets_object_entity_row(self, entity_kg):
+        """otype='t' (quoted triple) → object gets entity row like URI."""
+        kg, ctx = entity_kg
+        ctx["batches"].clear()
+
+        kg.insert(
+            collection='col',
+            s='http://ex.org/s',
+            p='http://ex.org/p',
+            o='{"s":{},"p":{},"o":{}}',
+            g=None,
+            otype='t',
+        )
+
+        batch = ctx["batches"][0]
+
+        # S + P + O entity rows + collection = 4 (no G for default graph)
+        assert len(batch._adds) == 4
+
+        roles = [params[2] for _, params in batch._adds if len(params) == 10]
+        assert 'O' in roles
+
+    def test_default_graph_no_g_row(self, entity_kg):
+        """Default graph (g=None) → no G entity row."""
+        kg, ctx = entity_kg
+        ctx["batches"].clear()
+
+        kg.insert(
+            collection='col',
+            s='http://ex.org/s',
+            p='http://ex.org/p',
+            o='http://ex.org/o',
+            g=None,
+            otype='u',
+        )
+
+        batch = ctx["batches"][0]
+
+        # S + P + O entity rows + collection = 4 (no G)
+        assert len(batch._adds) == 4
+        roles = [params[2] for _, params in batch._adds if len(params) == 10]
+        assert 'G' not in roles
+
+    def test_non_default_graph_gets_g_row(self, entity_kg):
+        """Non-default graph → gets G entity row."""
+        kg, ctx = entity_kg
+        ctx["batches"].clear()
+
+        kg.insert(
+            collection='col',
+            s='http://ex.org/s',
+            p='http://ex.org/p',
+            o='http://ex.org/o',
+            g='http://ex.org/graph1',
+            otype='u',
+        )
+
+        batch = ctx["batches"][0]
+
+        # S + P + O + G entity rows + collection = 5
+        assert len(batch._adds) == 5
+        roles = [params[2] for _, params in batch._adds if len(params) == 10]
+        assert 'G' in roles
+
+    def test_dtype_and_lang_passed_to_all_rows(self, entity_kg):
+        """dtype and lang should be stored in every entity row."""
+        kg, ctx = entity_kg
+        ctx["batches"].clear()
+
+        kg.insert(
+            collection='col',
+            s='http://ex.org/s',
+            p='http://ex.org/label',
+            o='thing',
+            g=None,
+            otype='l',
+            dtype='xsd:string',
+            lang='en',
+        )
+
+        batch = ctx["batches"][0]
+
+        # Check entity rows carry dtype and lang
+        for _, params in batch._adds:
+            if len(params) == 10:
+                # Entity row: (collection, entity, role, p, otype, s, o, d, dtype, lang)
+                assert params[8] == 'xsd:string'
+                assert params[9] == 'en'
+
+
+# ---------------------------------------------------------------------------
+# In-partition filtering: get_os, get_spo
+# ---------------------------------------------------------------------------
+
+class TestInPartitionFiltering:
+    # The mocked session returns all partition rows unfiltered, so these
+    # tests verify the KG code itself does the object/graph filtering
+    # client-side (in-partition filtering).
+
+    def test_get_os_filters_by_object(self, entity_kg):
+        """get_os should filter results by matching object value."""
+        kg, ctx = entity_kg
+
+        # Simulate rows returned from subject partition (all have same s)
+        mock_rows = [
+            MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
+                      d='', otype='u', dtype='', lang='',
+                      s='http://ex.org/Alice'),
+            MagicMock(p='http://ex.org/likes', o='http://ex.org/Charlie',
+                      d='', otype='u', dtype='', lang='',
+                      s='http://ex.org/Alice'),
+        ]
+        ctx["session"].execute.return_value = mock_rows
+
+        results = kg.get_os('col', 'http://ex.org/Bob', 'http://ex.org/Alice')
+
+        # Only the Bob row should pass the filter
+        assert len(results) == 1
+        assert results[0].o == 'http://ex.org/Bob'
+        assert results[0].p == 'http://ex.org/knows'
+
+    def test_get_os_returns_empty_when_no_match(self, entity_kg):
+        """get_os should return empty list when object doesn't match any row."""
+        kg, ctx = entity_kg
+
+        mock_rows = [
+            MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
+                      d='', otype='u', dtype='', lang='',
+                      s='http://ex.org/Alice'),
+        ]
+        ctx["session"].execute.return_value = mock_rows
+
+        results = kg.get_os('col', 'http://ex.org/Charlie', 'http://ex.org/Alice')
+
+        assert len(results) == 0
+
+    def test_get_spo_filters_by_object(self, entity_kg):
+        """get_spo should filter results by matching object value."""
+        kg, ctx = entity_kg
+
+        mock_rows = [
+            MagicMock(o='http://ex.org/Bob', d='', otype='u', dtype='', lang=''),
+            MagicMock(o='http://ex.org/Charlie', d='', otype='u', dtype='', lang=''),
+        ]
+        ctx["session"].execute.return_value = mock_rows
+
+        results = kg.get_spo(
+            'col', 'http://ex.org/Alice', 'http://ex.org/knows',
+            'http://ex.org/Bob',
+        )
+
+        assert len(results) == 1
+        assert results[0].o == 'http://ex.org/Bob'
+
+    def test_get_os_with_graph_filter(self, entity_kg):
+        """get_os with specific graph should filter both object and graph."""
+        kg, ctx = entity_kg
+
+        mock_rows = [
+            MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
+                      d='http://ex.org/g1', otype='u', dtype='', lang='',
+                      s='http://ex.org/Alice'),
+            MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
+                      d='http://ex.org/g2', otype='u', dtype='', lang='',
+                      s='http://ex.org/Alice'),
+        ]
+        ctx["session"].execute.return_value = mock_rows
+
+        results = kg.get_os(
+            'col', 'http://ex.org/Bob', 'http://ex.org/Alice',
+            g='http://ex.org/g1',
+        )
+
+        assert len(results) == 1
+        assert results[0].g == 'http://ex.org/g1'
+
+
+# ---------------------------------------------------------------------------
+# Delete collection batching
+# ---------------------------------------------------------------------------
+
+class TestDeleteCollectionBatching:
+
+ def test_extracts_unique_entities_from_quads(self, entity_kg):
+ """delete_collection should extract s, p, and URI o as entities."""
+ kg, ctx = entity_kg
+
+ mock_rows = [
+ MagicMock(d='', s='http://ex.org/A', p='http://ex.org/knows',
+ o='http://ex.org/B', otype='u', dtype='', lang=''),
+ MagicMock(d='', s='http://ex.org/A', p='http://ex.org/name',
+ o='Alice', otype='l', dtype='', lang=''),
+ ]
+ ctx["session"].execute.return_value = mock_rows
+ ctx["batches"].clear()
+
+ kg.delete_collection('col')
+
+ # Unique entities: A, knows, B, name (literal 'Alice' excluded)
+ # The batches should include entity partition deletes
+ all_adds = []
+ for batch in ctx["batches"]:
+ all_adds.extend(batch._adds)
+
+ # We expect entity deletes + collection row deletes + metadata delete
+ # Just verify the function completes and calls execute
+ assert ctx["session"].execute.called
+
+ def test_literal_objects_not_treated_as_entities(self, entity_kg):
+ """Literal objects (otype='l') should not get entity partition deletes."""
+ kg, ctx = entity_kg
+
+ mock_rows = [
+ MagicMock(d='', s='http://ex.org/A', p='http://ex.org/name',
+ o='Alice', otype='l', dtype='', lang=''),
+ ]
+ ctx["session"].execute.return_value = mock_rows
+ ctx["batches"].clear()
+
+ kg.delete_collection('col')
+
+ # Entity partition deletes should only include A and name, not Alice
+ entity_deletes = []
+ for batch in ctx["batches"]:
+ for _, params in batch._adds:
+ if len(params) == 2: # delete_entity_partition takes (collection, entity)
+ entity_deletes.append(params[1])
+
+ assert 'http://ex.org/A' in entity_deletes
+ assert 'http://ex.org/name' in entity_deletes
+ assert 'Alice' not in entity_deletes
+
+ def test_non_default_graph_treated_as_entity(self, entity_kg):
+ """Non-default graphs should get entity partition deletes."""
+ kg, ctx = entity_kg
+
+ mock_rows = [
+ MagicMock(d='http://ex.org/g1', s='http://ex.org/A',
+ p='http://ex.org/p', o='http://ex.org/B',
+ otype='u', dtype='', lang=''),
+ ]
+ ctx["session"].execute.return_value = mock_rows
+ ctx["batches"].clear()
+
+ kg.delete_collection('col')
+
+ entity_deletes = []
+ for batch in ctx["batches"]:
+ for _, params in batch._adds:
+ if len(params) == 2:
+ entity_deletes.append(params[1])
+
+ assert 'http://ex.org/g1' in entity_deletes
+
+ def test_empty_collection_delete_completes(self, entity_kg):
+ """Deleting an empty collection should not error."""
+ kg, ctx = entity_kg
+
+ ctx["session"].execute.return_value = []
+ ctx["batches"].clear()
+
+ # Should not raise
+ kg.delete_collection('empty-col')
+
+
+# ---------------------------------------------------------------------------
+# Term type metadata round-trip
+# ---------------------------------------------------------------------------
+
+class TestTermTypeMetadata:
+
+    def test_query_results_include_otype(self, entity_kg):
+        """Query results should include otype from Cassandra rows."""
+        kg, ctx = entity_kg
+        from trustgraph.direct.cassandra_kg import QuadResult  # imported for clarity; result rows are QuadResult-like
+
+        mock_rows = [
+            MagicMock(p='http://ex.org/name', o='Alice',
+                      d='', otype='l', dtype='xsd:string', lang='en',
+                      s='http://ex.org/Alice'),
+        ]
+        ctx["session"].execute.return_value = mock_rows
+
+        results = kg.get_s('col', 'http://ex.org/Alice')
+
+        assert len(results) == 1
+        assert results[0].otype == 'l'
+        assert results[0].dtype == 'xsd:string'
+        assert results[0].lang == 'en'
+
+    def test_auto_detect_otype_uri(self, entity_kg):
+        """Auto-detect should classify http:// as URI."""
+        kg, ctx = entity_kg
+        ctx["batches"].clear()
+
+        # No otype passed — insert must infer it from the object value
+        kg.insert(
+            collection='col',
+            s='http://ex.org/s',
+            p='http://ex.org/p',
+            o='http://ex.org/o',
+        )
+
+        batch = ctx["batches"][0]
+        # Check otype in entity rows (position 4)
+        for _, params in batch._adds:
+            if len(params) == 10:
+                assert params[4] == 'u'
+
+    def test_auto_detect_otype_literal(self, entity_kg):
+        """Auto-detect should classify non-http:// as literal."""
+        kg, ctx = entity_kg
+        ctx["batches"].clear()
+
+        kg.insert(
+            collection='col',
+            s='http://ex.org/s',
+            p='http://ex.org/p',
+            o='plain text',
+        )
+
+        batch = ctx["batches"][0]
+        for _, params in batch._adds:
+            if len(params) == 10:
+                assert params[4] == 'l'
diff --git a/tests/unit/test_embeddings/test_document_embeddings_processor.py b/tests/unit/test_embeddings/test_document_embeddings_processor.py
new file mode 100644
index 00000000..9cd93c4f
--- /dev/null
+++ b/tests/unit/test_embeddings/test_document_embeddings_processor.py
@@ -0,0 +1,164 @@
+"""
+Tests for document embeddings processor — single-chunk embedding via batch API.
+"""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock
+
+from trustgraph.embeddings.document_embeddings.embeddings import Processor
+from trustgraph.schema import (
+ Chunk, DocumentEmbeddings, ChunkEmbeddings,
+ EmbeddingsRequest, EmbeddingsResponse, Metadata,
+)
+
+
+@pytest.fixture
+def processor():
+    """Document-embeddings Processor with a mocked taskgroup (no Pulsar)."""
+    return Processor(
+        taskgroup=AsyncMock(),
+        id="test-doc-embeddings",
+    )
+
+
+def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1",
+                        user="test", collection="default"):
+    """Build a mock consumer message whose .value() is a Chunk."""
+    metadata = Metadata(id=doc_id, user=user, collection=collection)
+    value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id)
+    msg = MagicMock()
+    msg.value.return_value = value
+    return msg
+
+
+class TestDocumentEmbeddingsProcessor:
+
+    @pytest.mark.asyncio
+    async def test_sends_single_text_as_list(self, processor):
+        """Document embeddings should wrap single chunk in a list for the API."""
+        msg = _make_chunk_message("test chunk text")
+
+        mock_request = AsyncMock(return_value=EmbeddingsResponse(
+            error=None, vectors=[[0.1, 0.2, 0.3]]
+        ))
+        mock_output = AsyncMock()
+
+        # flow() resolves named producers/clients for the processor
+        def flow(name):
+            if name == "embeddings-request":
+                return MagicMock(request=mock_request)
+            elif name == "output":
+                return mock_output
+            return MagicMock()
+
+        await processor.on_message(msg, MagicMock(), flow)
+
+        # Should send EmbeddingsRequest with texts=[chunk]
+        mock_request.assert_called_once()
+        req = mock_request.call_args[0][0]
+        assert isinstance(req, EmbeddingsRequest)
+        assert req.texts == ["test chunk text"]
+
+    @pytest.mark.asyncio
+    async def test_extracts_first_vector(self, processor):
+        """Should use vectors[0] from the response."""
+        msg = _make_chunk_message("chunk")
+
+        mock_request = AsyncMock(return_value=EmbeddingsResponse(
+            error=None, vectors=[[1.0, 2.0, 3.0]]
+        ))
+        mock_output = AsyncMock()
+
+        def flow(name):
+            if name == "embeddings-request":
+                return MagicMock(request=mock_request)
+            elif name == "output":
+                return mock_output
+            return MagicMock()
+
+        await processor.on_message(msg, MagicMock(), flow)
+
+        result = mock_output.send.call_args[0][0]
+        assert isinstance(result, DocumentEmbeddings)
+        assert len(result.chunks) == 1
+        assert result.chunks[0].vector == [1.0, 2.0, 3.0]
+
+    @pytest.mark.asyncio
+    async def test_empty_vectors_response(self, processor):
+        """Should handle empty vectors response gracefully."""
+        # NOTE(review): asserts the processor guards vectors[0] when the
+        # response has no vectors — confirm against implementation.
+        msg = _make_chunk_message("chunk")
+
+        mock_request = AsyncMock(return_value=EmbeddingsResponse(
+            error=None, vectors=[]
+        ))
+        mock_output = AsyncMock()
+
+        def flow(name):
+            if name == "embeddings-request":
+                return MagicMock(request=mock_request)
+            elif name == "output":
+                return mock_output
+            return MagicMock()
+
+        await processor.on_message(msg, MagicMock(), flow)
+
+        result = mock_output.send.call_args[0][0]
+        assert result.chunks[0].vector == []
+
+    @pytest.mark.asyncio
+    async def test_chunk_id_is_document_id(self, processor):
+        """ChunkEmbeddings should use document_id as chunk_id."""
+        msg = _make_chunk_message(doc_id="my-doc-42")
+
+        mock_request = AsyncMock(return_value=EmbeddingsResponse(
+            error=None, vectors=[[0.0]]
+        ))
+        mock_output = AsyncMock()
+
+        def flow(name):
+            if name == "embeddings-request":
+                return MagicMock(request=mock_request)
+            elif name == "output":
+                return mock_output
+            return MagicMock()
+
+        await processor.on_message(msg, MagicMock(), flow)
+
+        result = mock_output.send.call_args[0][0]
+        assert result.chunks[0].chunk_id == "my-doc-42"
+
+    @pytest.mark.asyncio
+    async def test_metadata_preserved(self, processor):
+        """Output should carry the original metadata."""
+        msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1")
+
+        mock_request = AsyncMock(return_value=EmbeddingsResponse(
+            error=None, vectors=[[0.0]]
+        ))
+        mock_output = AsyncMock()
+
+        def flow(name):
+            if name == "embeddings-request":
+                return MagicMock(request=mock_request)
+            elif name == "output":
+                return mock_output
+            return MagicMock()
+
+        await processor.on_message(msg, MagicMock(), flow)
+
+        result = mock_output.send.call_args[0][0]
+        assert result.metadata.user == "alice"
+        assert result.metadata.collection == "reports"
+        assert result.metadata.id == "d1"
+
+    @pytest.mark.asyncio
+    async def test_error_propagates(self, processor):
+        """Embedding errors should propagate for retry."""
+        msg = _make_chunk_message()
+
+        mock_request = AsyncMock(side_effect=RuntimeError("service down"))
+
+        def flow(name):
+            if name == "embeddings-request":
+                return MagicMock(request=mock_request)
+            return MagicMock()
+
+        with pytest.raises(RuntimeError, match="service down"):
+            await processor.on_message(msg, MagicMock(), flow)
diff --git a/tests/unit/test_embeddings/test_embeddings_client.py b/tests/unit/test_embeddings/test_embeddings_client.py
new file mode 100644
index 00000000..84305d7a
--- /dev/null
+++ b/tests/unit/test_embeddings/test_embeddings_client.py
@@ -0,0 +1,109 @@
+"""
+Tests for EmbeddingsClient — the client interface for batch embeddings.
+"""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock
+
+from trustgraph.base.embeddings_client import EmbeddingsClient
+from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
+
+
+class TestEmbeddingsClient:
+    # All tests build the client via __new__ to bypass __init__ (which
+    # would require a live messaging setup) and stub .request directly.
+
+    @pytest.mark.asyncio
+    async def test_embed_sends_request_and_returns_vectors(self):
+        """embed() should send an EmbeddingsRequest and return vectors."""
+        client = EmbeddingsClient.__new__(EmbeddingsClient)
+        client.request = AsyncMock(return_value=EmbeddingsResponse(
+            error=None,
+            vectors=[[0.1, 0.2], [0.3, 0.4]],
+        ))
+
+        result = await client.embed(texts=["hello", "world"])
+
+        assert result == [[0.1, 0.2], [0.3, 0.4]]
+        client.request.assert_called_once()
+        req = client.request.call_args[0][0]
+        assert isinstance(req, EmbeddingsRequest)
+        assert req.texts == ["hello", "world"]
+
+    @pytest.mark.asyncio
+    async def test_embed_single_text(self):
+        """embed() should work with a single text."""
+        client = EmbeddingsClient.__new__(EmbeddingsClient)
+        client.request = AsyncMock(return_value=EmbeddingsResponse(
+            error=None,
+            vectors=[[1.0, 2.0, 3.0]],
+        ))
+
+        result = await client.embed(texts=["single"])
+
+        assert result == [[1.0, 2.0, 3.0]]
+
+    @pytest.mark.asyncio
+    async def test_embed_raises_on_error_response(self):
+        """embed() should raise RuntimeError when response contains an error."""
+        client = EmbeddingsClient.__new__(EmbeddingsClient)
+        client.request = AsyncMock(return_value=EmbeddingsResponse(
+            error=Error(type="embeddings-error", message="model not found"),
+            vectors=[],
+        ))
+
+        with pytest.raises(RuntimeError, match="model not found"):
+            await client.embed(texts=["test"])
+
+    @pytest.mark.asyncio
+    async def test_embed_passes_timeout(self):
+        """embed() should pass timeout to the underlying request."""
+        client = EmbeddingsClient.__new__(EmbeddingsClient)
+        client.request = AsyncMock(return_value=EmbeddingsResponse(
+            error=None, vectors=[[0.0]],
+        ))
+
+        await client.embed(texts=["test"], timeout=60)
+
+        _, kwargs = client.request.call_args
+        assert kwargs["timeout"] == 60
+
+    @pytest.mark.asyncio
+    async def test_embed_default_timeout(self):
+        """embed() should use 300s default timeout."""
+        client = EmbeddingsClient.__new__(EmbeddingsClient)
+        client.request = AsyncMock(return_value=EmbeddingsResponse(
+            error=None, vectors=[[0.0]],
+        ))
+
+        await client.embed(texts=["test"])
+
+        _, kwargs = client.request.call_args
+        assert kwargs["timeout"] == 300
+
+    @pytest.mark.asyncio
+    async def test_embed_empty_texts(self):
+        """embed() with empty list should still make the request."""
+        client = EmbeddingsClient.__new__(EmbeddingsClient)
+        client.request = AsyncMock(return_value=EmbeddingsResponse(
+            error=None, vectors=[],
+        ))
+
+        result = await client.embed(texts=[])
+
+        assert result == []
+
+    @pytest.mark.asyncio
+    async def test_embed_large_batch(self):
+        """embed() should handle large batches."""
+        client = EmbeddingsClient.__new__(EmbeddingsClient)
+        n = 100
+        vectors = [[float(i)] for i in range(n)]
+        client.request = AsyncMock(return_value=EmbeddingsResponse(
+            error=None, vectors=vectors,
+        ))
+
+        texts = [f"text {i}" for i in range(n)]
+        result = await client.embed(texts=texts)
+
+        assert len(result) == n
+        req = client.request.call_args[0][0]
+        assert len(req.texts) == n
diff --git a/tests/unit/test_embeddings/test_embeddings_service_request.py b/tests/unit/test_embeddings/test_embeddings_service_request.py
new file mode 100644
index 00000000..c57fae16
--- /dev/null
+++ b/tests/unit/test_embeddings/test_embeddings_service_request.py
@@ -0,0 +1,135 @@
+"""
+Tests for EmbeddingsService.on_request — the request handler that dispatches
+to on_embeddings and sends responses.
+"""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock
+
+from trustgraph.base import EmbeddingsService
+from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
+from trustgraph.exceptions import TooManyRequests
+
+
+class StubEmbeddingsService(EmbeddingsService):
+    """Minimal concrete implementation for testing on_request."""
+
+    def __init__(self, embed_result=None, embed_error=None):
+        # Skip super().__init__ to avoid taskgroup/registration
+        self.embed_result = embed_result or [[0.1, 0.2]]
+        self.embed_error = embed_error
+
+    async def on_embeddings(self, texts, model=None):
+        # Either fail with the configured exception or return the canned vectors
+        if self.embed_error:
+            raise self.embed_error
+        return self.embed_result
+
+
+def _make_msg(texts, msg_id="req-1"):
+    """Build a mock consumer message carrying an EmbeddingsRequest and an id property."""
+    request = EmbeddingsRequest(texts=texts)
+    msg = MagicMock()
+    msg.value.return_value = request
+    msg.properties.return_value = {"id": msg_id}
+    return msg
+
+
+def _make_flow(model="test-model"):
+    """Build a flow callable resolving 'model' and 'response'; returns (flow, response producer)."""
+    mock_response_producer = AsyncMock()
+    mock_flow = MagicMock()  # NOTE(review): unused — flow_callable is returned instead
+
+    def flow_callable(name):
+        if name == "model":
+            return model
+        if name == "response":
+            return mock_response_producer
+        return MagicMock()
+
+    flow_callable.producer = {"response": mock_response_producer}
+    return flow_callable, mock_response_producer
+
+
+class TestEmbeddingsServiceOnRequest:
+
+    @pytest.mark.asyncio
+    async def test_successful_request(self):
+        """on_request should call on_embeddings and send response."""
+        service = StubEmbeddingsService(embed_result=[[0.1, 0.2], [0.3, 0.4]])
+        msg = _make_msg(["hello", "world"], msg_id="r1")
+        flow, mock_response = _make_flow(model="my-model")
+
+        await service.on_request(msg, MagicMock(), flow)
+
+        mock_response.send.assert_called_once()
+        resp = mock_response.send.call_args[0][0]
+        assert isinstance(resp, EmbeddingsResponse)
+        assert resp.error is None
+        assert resp.vectors == [[0.1, 0.2], [0.3, 0.4]]
+
+        # Check id is passed through
+        props = mock_response.send.call_args[1]["properties"]
+        assert props["id"] == "r1"
+
+    @pytest.mark.asyncio
+    async def test_passes_model_from_flow(self):
+        """on_request should pass model parameter from flow to on_embeddings."""
+        calls = []
+
+        class TrackingService(EmbeddingsService):
+            def __init__(self):
+                pass  # skip base __init__ (taskgroup/registration)
+
+            async def on_embeddings(self, texts, model=None):
+                calls.append({"texts": texts, "model": model})
+                return [[0.0]]
+
+        service = TrackingService()
+        msg = _make_msg(["test"])
+        flow, _ = _make_flow(model="custom-model-v2")
+
+        await service.on_request(msg, MagicMock(), flow)
+
+        assert len(calls) == 1
+        assert calls[0]["model"] == "custom-model-v2"
+        assert calls[0]["texts"] == ["test"]
+
+    @pytest.mark.asyncio
+    async def test_error_sends_error_response(self):
+        """Non-rate-limit errors should send an error response."""
+        service = StubEmbeddingsService(
+            embed_error=ValueError("dimension mismatch")
+        )
+        msg = _make_msg(["test"], msg_id="r2")
+        flow, mock_response = _make_flow()
+
+        await service.on_request(msg, MagicMock(), flow)
+
+        mock_response.send.assert_called_once()
+        resp = mock_response.send.call_args[0][0]
+        assert resp.error is not None
+        assert resp.error.type == "embeddings-error"
+        assert "dimension mismatch" in resp.error.message
+        assert resp.vectors == []
+
+    @pytest.mark.asyncio
+    async def test_rate_limit_propagates(self):
+        """TooManyRequests should propagate (not caught as error response)."""
+        service = StubEmbeddingsService(
+            embed_error=TooManyRequests("rate limited")
+        )
+        msg = _make_msg(["test"])
+        flow, _ = _make_flow()
+
+        with pytest.raises(TooManyRequests):
+            await service.on_request(msg, MagicMock(), flow)
+
+    @pytest.mark.asyncio
+    async def test_message_id_preserved(self):
+        """The request message id should be forwarded in the response properties."""
+        service = StubEmbeddingsService()
+        msg = _make_msg(["test"], msg_id="unique-id-42")
+        flow, mock_response = _make_flow()
+
+        await service.on_request(msg, MagicMock(), flow)
+
+        props = mock_response.send.call_args[1]["properties"]
+        assert props["id"] == "unique-id-42"
diff --git a/tests/unit/test_embeddings/test_graph_embeddings_processor.py b/tests/unit/test_embeddings/test_graph_embeddings_processor.py
new file mode 100644
index 00000000..5d535349
--- /dev/null
+++ b/tests/unit/test_embeddings/test_graph_embeddings_processor.py
@@ -0,0 +1,233 @@
+"""
+Tests for graph embeddings processor — batch embedding of entity contexts.
+"""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock
+
+from trustgraph.embeddings.graph_embeddings.embeddings import Processor
+from trustgraph.schema import (
+ EntityContexts, EntityEmbeddings, GraphEmbeddings,
+ Term, IRI, Metadata,
+)
+
+
+@pytest.fixture
+def processor():
+    """Graph-embeddings Processor with a small batch size (3) for batching tests."""
+    return Processor(
+        taskgroup=AsyncMock(),
+        id="test-graph-embeddings",
+        batch_size=3,
+    )
+
+
+def _make_entity_context(name, context, chunk_id="chunk-1"):
+    """Create an entity context for testing.
+
+    The entity is a real Term/IRI; the wrapper is a MagicMock so we only
+    need the attributes the processor reads (entity, context, chunk_id).
+    """
+    entity = Term(type=IRI, iri=f"urn:entity:{name}")
+    return MagicMock(entity=entity, context=context, chunk_id=chunk_id)
+
+
+def _make_message(entities, doc_id="doc-1", user="test", collection="default"):
+    """Wrap entity contexts in an EntityContexts message mock.
+
+    msg.value() returns the EntityContexts payload, mirroring how the
+    processor unwraps incoming messages.
+    """
+    metadata = Metadata(id=doc_id, user=user, collection=collection)
+    value = EntityContexts(metadata=metadata, entities=entities)
+    msg = MagicMock()
+    msg.value.return_value = value
+    return msg
+
+
+class TestGraphEmbeddingsInit:
+    """Constructor behavior: default and custom batch_size."""
+
+    def test_default_batch_size(self):
+        p = Processor(taskgroup=AsyncMock(), id="test")
+        assert p.batch_size == 5
+
+    def test_custom_batch_size(self):
+        p = Processor(taskgroup=AsyncMock(), id="test", batch_size=20)
+        assert p.batch_size == 20
+
+
+class TestGraphEmbeddingsBatchProcessing:
+    """Batch behavior of on_message: one embed call for all contexts,
+    vector/entity pairing, output splitting by batch_size, metadata
+    propagation, and error propagation."""
+
+    @pytest.mark.asyncio
+    async def test_single_batch_call_for_all_entities(self, processor):
+        """All entity contexts should be embedded in a single API call."""
+        entities = [
+            _make_entity_context("Alice", "Alice is a person"),
+            _make_entity_context("Bob", "Bob is a developer"),
+            _make_entity_context("Acme", "Acme is a company"),
+        ]
+        msg = _make_message(entities)
+
+        mock_embed = AsyncMock(return_value=[
+            [0.1, 0.2], [0.3, 0.4], [0.5, 0.6],
+        ])
+        mock_output = AsyncMock()
+
+        def flow(name):
+            if name == "embeddings-request":
+                return MagicMock(embed=mock_embed)
+            elif name == "output":
+                return mock_output
+            return MagicMock()
+
+        await processor.on_message(msg, MagicMock(), flow)
+
+        # Single batch call with all three texts
+        mock_embed.assert_called_once_with(
+            texts=["Alice is a person", "Bob is a developer", "Acme is a company"]
+        )
+
+    @pytest.mark.asyncio
+    async def test_vectors_paired_with_correct_entities(self, processor):
+        """Each vector should be paired with its corresponding entity."""
+        entities = [
+            _make_entity_context("Alice", "ctx-A", chunk_id="c1"),
+            _make_entity_context("Bob", "ctx-B", chunk_id="c2"),
+        ]
+        msg = _make_message(entities)
+
+        vectors = [[1.0, 2.0], [3.0, 4.0]]
+        mock_embed = AsyncMock(return_value=vectors)
+        mock_output = AsyncMock()
+
+        def flow(name):
+            if name == "embeddings-request":
+                return MagicMock(embed=mock_embed)
+            elif name == "output":
+                return mock_output
+            return MagicMock()
+
+        await processor.on_message(msg, MagicMock(), flow)
+
+        # With batch_size=3, all 2 entities fit in one output message
+        mock_output.send.assert_called_once()
+        result = mock_output.send.call_args[0][0]
+        assert isinstance(result, GraphEmbeddings)
+        assert len(result.entities) == 2
+        # Order is preserved: vector[i] pairs with entity[i].
+        assert result.entities[0].vector == [1.0, 2.0]
+        assert result.entities[0].entity.iri == "urn:entity:Alice"
+        assert result.entities[0].chunk_id == "c1"
+        assert result.entities[1].vector == [3.0, 4.0]
+        assert result.entities[1].entity.iri == "urn:entity:Bob"
+
+    @pytest.mark.asyncio
+    async def test_output_batching(self, processor):
+        """Output should be split into batches of batch_size."""
+        # batch_size=3, 7 entities -> 3 output messages (3+3+1)
+        entities = [
+            _make_entity_context(f"E{i}", f"context {i}")
+            for i in range(7)
+        ]
+        msg = _make_message(entities)
+
+        vectors = [[float(i)] for i in range(7)]
+        mock_embed = AsyncMock(return_value=vectors)
+        mock_output = AsyncMock()
+
+        def flow(name):
+            if name == "embeddings-request":
+                return MagicMock(embed=mock_embed)
+            elif name == "output":
+                return mock_output
+            return MagicMock()
+
+        await processor.on_message(msg, MagicMock(), flow)
+
+        assert mock_output.send.call_count == 3
+        # First batch has 3 entities
+        batch1 = mock_output.send.call_args_list[0][0][0]
+        assert len(batch1.entities) == 3
+        # Second batch has 3 entities
+        batch2 = mock_output.send.call_args_list[1][0][0]
+        assert len(batch2.entities) == 3
+        # Third batch has 1 entity
+        batch3 = mock_output.send.call_args_list[2][0][0]
+        assert len(batch3.entities) == 1
+
+    @pytest.mark.asyncio
+    async def test_output_batches_preserve_metadata(self, processor):
+        """Each output batch should carry the original metadata."""
+        entities = [
+            _make_entity_context(f"E{i}", f"ctx {i}")
+            for i in range(5)
+        ]
+        msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main")
+
+        mock_embed = AsyncMock(return_value=[[0.0]] * 5)
+        mock_output = AsyncMock()
+
+        def flow(name):
+            if name == "embeddings-request":
+                return MagicMock(embed=mock_embed)
+            elif name == "output":
+                return mock_output
+            return MagicMock()
+
+        await processor.on_message(msg, MagicMock(), flow)
+
+        # Every batch (not just the first) must echo the source metadata.
+        for call in mock_output.send.call_args_list:
+            result = call[0][0]
+            assert result.metadata.id == "doc-42"
+            assert result.metadata.user == "alice"
+            assert result.metadata.collection == "main"
+
+    @pytest.mark.asyncio
+    async def test_single_entity(self, processor):
+        """Single entity should work with one embed call and one output."""
+        entities = [_make_entity_context("Solo", "solo context")]
+        msg = _make_message(entities)
+
+        mock_embed = AsyncMock(return_value=[[1.0, 2.0, 3.0]])
+        mock_output = AsyncMock()
+
+        def flow(name):
+            if name == "embeddings-request":
+                return MagicMock(embed=mock_embed)
+            elif name == "output":
+                return mock_output
+            return MagicMock()
+
+        await processor.on_message(msg, MagicMock(), flow)
+
+        mock_embed.assert_called_once_with(texts=["solo context"])
+        mock_output.send.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_embed_error_propagates(self, processor):
+        """Embedding service errors should propagate for retry."""
+        entities = [_make_entity_context("E", "ctx")]
+        msg = _make_message(entities)
+
+        mock_embed = AsyncMock(side_effect=RuntimeError("embedding failed"))
+
+        def flow(name):
+            if name == "embeddings-request":
+                return MagicMock(embed=mock_embed)
+            return MagicMock()
+
+        with pytest.raises(RuntimeError, match="embedding failed"):
+            await processor.on_message(msg, MagicMock(), flow)
+
+    @pytest.mark.asyncio
+    async def test_exact_batch_size(self, processor):
+        """When entity count equals batch_size, exactly one output message."""
+        entities = [
+            _make_entity_context(f"E{i}", f"ctx {i}")
+            for i in range(3)  # batch_size=3
+        ]
+        msg = _make_message(entities)
+
+        mock_embed = AsyncMock(return_value=[[0.0]] * 3)
+        mock_output = AsyncMock()
+
+        def flow(name):
+            if name == "embeddings-request":
+                return MagicMock(embed=mock_embed)
+            elif name == "output":
+                return mock_output
+            return MagicMock()
+
+        await processor.on_message(msg, MagicMock(), flow)
+
+        mock_output.send.assert_called_once()
+        assert len(mock_output.send.call_args[0][0].entities) == 3
diff --git a/tests/unit/test_extract/test_streaming_triples/__init__.py b/tests/unit/test_extract/test_streaming_triples/__init__.py
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/tests/unit/test_extract/test_streaming_triples/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py
new file mode 100644
index 00000000..b651b59e
--- /dev/null
+++ b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py
@@ -0,0 +1,407 @@
+"""
+Tests for streaming triple and entity context batching in the definitions
+KG extractor.
+
+Covers: triples batch splitting, entity context batch splitting,
+metadata preservation, provenance, and empty/null filtering.
+"""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock
+
+from trustgraph.extract.kg.definitions.extract import (
+ Processor, default_triples_batch_size, default_entity_batch_size,
+)
+from trustgraph.schema import (
+ Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
+)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _make_processor(triples_batch_size=default_triples_batch_size,
+                    entity_batch_size=default_entity_batch_size):
+    """Build a Processor without running FlowProcessor.__init__.
+
+    Only the two batch-size attributes that on_message reads are set.
+    """
+    proc = Processor.__new__(Processor)
+    proc.triples_batch_size = triples_batch_size
+    proc.entity_batch_size = entity_batch_size
+    return proc
+
+
+def _make_defn(entity, definition):
+    """Build a definition dict in the shape the prompt client returns."""
+    return {"entity": entity, "definition": definition}
+
+
+def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
+                    user="user-1", collection="col-1", document_id=""):
+    """Build a mock message wrapping a Chunk with the given metadata.
+
+    msg.value() returns the Chunk, mirroring how on_message unwraps it.
+    """
+    chunk = Chunk(
+        metadata=Metadata(
+            id=meta_id, root=root, user=user, collection=collection,
+        ),
+        chunk=text.encode("utf-8"),
+        document_id=document_id,
+    )
+    msg = MagicMock()
+    msg.value.return_value = chunk
+    return msg
+
+
+def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
+    """Build a mock flow callable wiring prompt client and publishers.
+
+    Returns (flow, triples publisher, entity-contexts publisher,
+    prompt client) so tests can inspect each side independently.
+    """
+    mock_triples_pub = AsyncMock()
+    mock_ecs_pub = AsyncMock()
+    mock_prompt_client = AsyncMock()
+    mock_prompt_client.extract_definitions = AsyncMock(
+        return_value=prompt_result
+    )
+
+    def flow(name):
+        if name == "prompt-request":
+            return mock_prompt_client
+        if name == "triples":
+            return mock_triples_pub
+        if name == "entity-contexts":
+            return mock_ecs_pub
+        if name == "llm-model":
+            return llm_model
+        if name == "ontology":
+            return ontology_uri
+        return MagicMock()
+
+    return flow, mock_triples_pub, mock_ecs_pub, mock_prompt_client
+
+
+def _sent_triples(mock_pub):
+    """Collect every Triples batch object sent to the mock publisher."""
+    return [call.args[0] for call in mock_pub.send.call_args_list]
+
+
+def _sent_ecs(mock_pub):
+    """Collect every EntityContexts batch object sent to the mock publisher."""
+    return [call.args[0] for call in mock_pub.send.call_args_list]
+
+
+def _all_triples_flat(mock_pub):
+    """Flatten all sent Triples batches into one list of Triple objects."""
+    result = []
+    for triples_msg in _sent_triples(mock_pub):
+        result.extend(triples_msg.triples)
+    return result
+
+
+def _all_entities_flat(mock_pub):
+    """Flatten all sent EntityContexts batches into one list of entities."""
+    result = []
+    for ecs_msg in _sent_ecs(mock_pub):
+        result.extend(ecs_msg.entities)
+    return result
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+class TestDefaults:
+    """Pin the module-level default batch-size constants."""
+
+    def test_default_triples_batch_size(self):
+        assert default_triples_batch_size == 50
+
+    def test_default_entity_batch_size(self):
+        assert default_entity_batch_size == 5
+
+
+class TestTriplesBatching:
+    """Splitting of emitted triples across batches of triples_batch_size."""
+
+    @pytest.mark.asyncio
+    async def test_single_batch_when_under_limit(self):
+        proc = _make_processor(triples_batch_size=100)
+        defs = [_make_defn("Cat", "A feline animal")]
+        flow, triples_pub, _, _ = _make_flow(defs)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        assert triples_pub.send.call_count == 1
+
+    @pytest.mark.asyncio
+    async def test_multiple_triples_batches(self):
+        proc = _make_processor(triples_batch_size=2)
+        defs = [
+            _make_defn("Cat", "A feline"),
+            _make_defn("Dog", "A canine"),
+        ]
+        flow, triples_pub, _, _ = _make_flow(defs)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        # 2 defs → 2 labels + 2 definitions = 4 triples + provenance
+        # With batch_size=2, should produce multiple batches
+        assert triples_pub.send.call_count > 1
+
+    @pytest.mark.asyncio
+    async def test_triples_batch_sizes_within_limit(self):
+        batch_size = 3
+        proc = _make_processor(triples_batch_size=batch_size)
+        defs = [
+            _make_defn("A", "def A"),
+            _make_defn("B", "def B"),
+            _make_defn("C", "def C"),
+        ]
+        flow, triples_pub, _, _ = _make_flow(defs)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        # No individual batch may exceed the configured limit.
+        for triples_msg in _sent_triples(triples_pub):
+            assert len(triples_msg.triples) <= batch_size
+
+
+class TestEntityContextBatching:
+    """Splitting of emitted entity contexts across batches of
+    entity_batch_size; each definition yields a name context and a
+    definition context."""
+
+    @pytest.mark.asyncio
+    async def test_single_entity_batch_when_under_limit(self):
+        proc = _make_processor(entity_batch_size=100)
+        defs = [_make_defn("Cat", "A feline")]
+        flow, _, ecs_pub, _ = _make_flow(defs)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        # 1 def → 2 entity contexts (name + definition)
+        assert ecs_pub.send.call_count == 1
+
+    @pytest.mark.asyncio
+    async def test_multiple_entity_batches(self):
+        proc = _make_processor(entity_batch_size=2)
+        defs = [
+            _make_defn("Cat", "A feline"),
+            _make_defn("Dog", "A canine"),
+        ]
+        flow, _, ecs_pub, _ = _make_flow(defs)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        # 2 defs → 4 entity contexts, batch_size=2 → 2 batches
+        assert ecs_pub.send.call_count == 2
+
+    @pytest.mark.asyncio
+    async def test_entity_batch_sizes_within_limit(self):
+        batch_size = 3
+        proc = _make_processor(entity_batch_size=batch_size)
+        defs = [
+            _make_defn("A", "def A"),
+            _make_defn("B", "def B"),
+            _make_defn("C", "def C"),
+        ]
+        flow, _, ecs_pub, _ = _make_flow(defs)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        for ecs_msg in _sent_ecs(ecs_pub):
+            assert len(ecs_msg.entities) <= batch_size
+
+    @pytest.mark.asyncio
+    async def test_entity_contexts_have_name_and_definition(self):
+        """Each definition produces 2 entity contexts: name and definition."""
+        proc = _make_processor(entity_batch_size=100)
+        defs = [_make_defn("Cat", "A feline animal")]
+        flow, _, ecs_pub, _ = _make_flow(defs)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        entities = _all_entities_flat(ecs_pub)
+        assert len(entities) == 2
+        contexts = {e.context for e in entities}
+        assert "Cat" in contexts
+        assert "A feline animal" in contexts
+
+
+class TestMetadataPreservation:
+    """Chunk metadata (id, root, user, collection) must be echoed on
+    every outgoing triples and entity-contexts batch."""
+
+    @pytest.mark.asyncio
+    async def test_triples_metadata(self):
+        proc = _make_processor(triples_batch_size=2)
+        defs = [_make_defn("X", "def X")]
+        flow, triples_pub, _, _ = _make_flow(defs)
+        msg = _make_chunk_msg(
+            "text", meta_id="c-1", root="r-1",
+            user="u-1", collection="coll-1",
+        )
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        for triples_msg in _sent_triples(triples_pub):
+            assert triples_msg.metadata.id == "c-1"
+            assert triples_msg.metadata.root == "r-1"
+            assert triples_msg.metadata.user == "u-1"
+            assert triples_msg.metadata.collection == "coll-1"
+
+    @pytest.mark.asyncio
+    async def test_entity_contexts_metadata(self):
+        proc = _make_processor(entity_batch_size=1)
+        defs = [_make_defn("X", "def X")]
+        flow, _, ecs_pub, _ = _make_flow(defs)
+        msg = _make_chunk_msg(
+            "text", meta_id="c-2", root="r-2",
+            user="u-2", collection="coll-2",
+        )
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        for ecs_msg in _sent_ecs(ecs_pub):
+            assert ecs_msg.metadata.id == "c-2"
+            assert ecs_msg.metadata.root == "r-2"
+
+
+class TestEmptyAndNullFiltering:
+    """Definitions with empty or None entity/definition fields are skipped;
+    fully-filtered input produces no output at all."""
+
+    @pytest.mark.asyncio
+    async def test_empty_entity_skipped(self):
+        proc = _make_processor()
+        defs = [
+            _make_defn("", "some definition"),
+            _make_defn("Valid", "a valid definition"),
+        ]
+        flow, triples_pub, ecs_pub, _ = _make_flow(defs)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        all_t = _all_triples_flat(triples_pub)
+        all_e = _all_entities_flat(ecs_pub)
+        # Only "Valid" should be present
+        # NOTE: IRIs are lower-cased by the extractor's to_uri helper,
+        # hence the lowercase substring checks.
+        entity_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri")}
+        assert any("valid" in iri for iri in entity_iris)
+        assert len(all_e) == 2  # name + definition for "Valid" only
+
+    @pytest.mark.asyncio
+    async def test_empty_definition_skipped(self):
+        proc = _make_processor()
+        defs = [
+            _make_defn("Entity", ""),
+            _make_defn("Good", "good definition"),
+        ]
+        flow, triples_pub, _, _ = _make_flow(defs)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        all_t = _all_triples_flat(triples_pub)
+        entity_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri")}
+        assert any("good" in iri for iri in entity_iris)
+        # "Entity" with empty def should have been skipped
+        assert not any("entity" in iri and "good" not in iri for iri in entity_iris)
+
+    @pytest.mark.asyncio
+    async def test_none_fields_skipped(self):
+        proc = _make_processor()
+        defs = [
+            _make_defn(None, "some definition"),
+            _make_defn("Entity", None),
+        ]
+        flow, triples_pub, ecs_pub, _ = _make_flow(defs)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        assert triples_pub.send.call_count == 0
+        assert ecs_pub.send.call_count == 0
+
+    @pytest.mark.asyncio
+    async def test_all_filtered_no_output(self):
+        proc = _make_processor()
+        defs = [_make_defn("", ""), _make_defn(None, None)]
+        flow, triples_pub, ecs_pub, _ = _make_flow(defs)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        assert triples_pub.send.call_count == 0
+        assert ecs_pub.send.call_count == 0
+
+    @pytest.mark.asyncio
+    async def test_empty_prompt_response(self):
+        proc = _make_processor()
+        flow, triples_pub, ecs_pub, _ = _make_flow([])
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        assert triples_pub.send.call_count == 0
+        assert ecs_pub.send.call_count == 0
+
+
+class TestProvenanceInclusion:
+    """Extractor adds provenance triples beyond the content triples."""
+
+    @pytest.mark.asyncio
+    async def test_provenance_triples_present(self):
+        proc = _make_processor(triples_batch_size=200)
+        defs = [_make_defn("Cat", "A feline")]
+        flow, triples_pub, _, _ = _make_flow(defs)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        all_t = _all_triples_flat(triples_pub)
+        # 1 def → 1 label + 1 definition = 2 content triples
+        # Provenance adds more
+        assert len(all_t) > 2
+
+
+class TestErrorHandling:
+    """on_message catches prompt failures and malformed responses
+    (logged, not raised), producing no output."""
+
+    @pytest.mark.asyncio
+    async def test_prompt_error_caught(self):
+        proc = _make_processor()
+        flow, triples_pub, ecs_pub, prompt = _make_flow([])
+        prompt.extract_definitions = AsyncMock(
+            side_effect=RuntimeError("LLM error")
+        )
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        assert triples_pub.send.call_count == 0
+        assert ecs_pub.send.call_count == 0
+
+    @pytest.mark.asyncio
+    async def test_non_list_response_caught(self):
+        proc = _make_processor()
+        flow, triples_pub, ecs_pub, prompt = _make_flow("not a list")
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        assert triples_pub.send.call_count == 0
+        assert ecs_pub.send.call_count == 0
+
+
+class TestDocumentIdProvenance:
+    """chunk_id selection: document_id when set, metadata.id otherwise."""
+
+    @pytest.mark.asyncio
+    async def test_document_id_used_for_chunk_id(self):
+        """When document_id is set, entity contexts should use it as chunk_id."""
+        proc = _make_processor(entity_batch_size=100)
+        defs = [_make_defn("Cat", "A feline")]
+        flow, _, ecs_pub, _ = _make_flow(defs)
+        msg = _make_chunk_msg("text", document_id="doc-123")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        entities = _all_entities_flat(ecs_pub)
+        for e in entities:
+            assert e.chunk_id == "doc-123"
+
+    @pytest.mark.asyncio
+    async def test_metadata_id_fallback_for_chunk_id(self):
+        """When document_id is empty, metadata.id is used as chunk_id."""
+        proc = _make_processor(entity_batch_size=100)
+        defs = [_make_defn("Cat", "A feline")]
+        flow, _, ecs_pub, _ = _make_flow(defs)
+        msg = _make_chunk_msg("text", meta_id="chunk-42", document_id="")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        entities = _all_entities_flat(ecs_pub)
+        for e in entities:
+            assert e.chunk_id == "chunk-42"
diff --git a/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py
new file mode 100644
index 00000000..cf3b1fb0
--- /dev/null
+++ b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py
@@ -0,0 +1,408 @@
+"""
+Tests for streaming triple batching in the relationships KG extractor.
+
+Covers: batch size configuration, output splitting, metadata preservation,
+provenance inclusion, empty/null filtering, and error propagation.
+"""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from trustgraph.extract.kg.relationships.extract import (
+ Processor, default_triples_batch_size,
+)
+from trustgraph.schema import (
+ Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
+)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _make_processor(triples_batch_size=default_triples_batch_size):
+    """Create a Processor without triggering FlowProcessor.__init__.
+
+    Only triples_batch_size is set — the single attribute on_message reads.
+    """
+    proc = Processor.__new__(Processor)
+    proc.triples_batch_size = triples_batch_size
+    return proc
+
+
+def _make_rel(subject, predicate, obj, object_entity=True):
+    """Build a relationship dict as returned by the prompt client.
+
+    object_entity controls whether the object is treated as an entity
+    (IRI) or a literal value by the extractor.
+    """
+    return {
+        "subject": subject,
+        "predicate": predicate,
+        "object": obj,
+        "object-entity": object_entity,
+    }
+
+
+def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
+                    user="user-1", collection="col-1", document_id=""):
+    """Build a mock message wrapping a Chunk.
+
+    msg.value() returns the Chunk, mirroring how on_message unwraps it.
+    """
+    chunk = Chunk(
+        metadata=Metadata(
+            id=meta_id, root=root, user=user, collection=collection,
+        ),
+        chunk=text.encode("utf-8"),
+        document_id=document_id,
+    )
+    msg = MagicMock()
+    msg.value.return_value = chunk
+    return msg
+
+
+def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
+    """Build a mock flow callable that provides prompt client, triples
+    producer, and parameter specs.
+
+    Returns (flow, triples publisher, prompt client).
+    """
+    mock_triples_pub = AsyncMock()
+    mock_prompt_client = AsyncMock()
+    mock_prompt_client.extract_relationships = AsyncMock(
+        return_value=prompt_result
+    )
+
+    def flow(name):
+        if name == "prompt-request":
+            return mock_prompt_client
+        if name == "triples":
+            return mock_triples_pub
+        if name == "llm-model":
+            return llm_model
+        if name == "ontology":
+            return ontology_uri
+        return MagicMock()
+
+    return flow, mock_triples_pub, mock_prompt_client
+
+
+def _sent_triples(mock_pub):
+    """Collect all Triples objects sent to a mock publisher."""
+    return [call.args[0] for call in mock_pub.send.call_args_list]
+
+
+def _all_triples_flat(mock_pub):
+    """Flatten all batches into one list of Triple objects."""
+    result = []
+    for triples_msg in _sent_triples(mock_pub):
+        result.extend(triples_msg.triples)
+    return result
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+class TestDefaultBatchSize:
+    """Pin the default triples batch size, both as a constant and as the
+    value a freshly-built Processor carries."""
+
+    def test_default_is_50(self):
+        assert default_triples_batch_size == 50
+
+    def test_processor_uses_default(self):
+        proc = _make_processor()
+        assert proc.triples_batch_size == 50
+
+
+class TestBatchSplitting:
+    """Splitting of emitted triples across send() calls of at most
+    triples_batch_size each."""
+
+    @pytest.mark.asyncio
+    async def test_single_batch_when_under_limit(self):
+        """Few triples → single send call."""
+        proc = _make_processor(triples_batch_size=50)
+        rels = [_make_rel("A", "knows", "B")]
+        flow, pub, _ = _make_flow(rels)
+        msg = _make_chunk_msg("some text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        # One relationship produces: rel triple + 3 labels + provenance
+        # All should fit in one batch of 50
+        assert pub.send.call_count == 1
+
+    @pytest.mark.asyncio
+    async def test_multiple_batches_with_small_batch_size(self):
+        """With batch_size=3 and many triples, multiple batches are sent."""
+        proc = _make_processor(triples_batch_size=3)
+        # 2 relationships → 2 rel triples + 6 labels = 8 triples + provenance
+        rels = [
+            _make_rel("A", "knows", "B"),
+            _make_rel("C", "likes", "D"),
+        ]
+        flow, pub, _ = _make_flow(rels)
+        msg = _make_chunk_msg("some text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        # Should have more than one batch
+        assert pub.send.call_count > 1
+
+    @pytest.mark.asyncio
+    async def test_batch_sizes_respect_limit(self):
+        """No batch should exceed the configured batch size."""
+        batch_size = 3
+        proc = _make_processor(triples_batch_size=batch_size)
+        rels = [
+            _make_rel("A", "knows", "B"),
+            _make_rel("C", "likes", "D"),
+            _make_rel("E", "has", "F"),
+        ]
+        flow, pub, _ = _make_flow(rels)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        for triples_msg in _sent_triples(pub):
+            assert len(triples_msg.triples) <= batch_size
+
+    @pytest.mark.asyncio
+    async def test_all_triples_present_across_batches(self):
+        """Total triples across batches equals expected count."""
+        proc = _make_processor(triples_batch_size=2)
+        # 1 relationship with object-entity=True → 1 rel + 3 labels = 4 triples
+        # + provenance triples
+        rels = [_make_rel("A", "knows", "B", object_entity=True)]
+        flow, pub, _ = _make_flow(rels)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        all_t = _all_triples_flat(pub)
+        # At minimum: 1 rel + 3 labels = 4 content triples
+        assert len(all_t) >= 4
+
+    @pytest.mark.asyncio
+    async def test_custom_batch_size(self):
+        """Processor respects custom triples_batch_size parameter."""
+        proc = _make_processor(triples_batch_size=100)
+        assert proc.triples_batch_size == 100
+
+
+class TestMetadataPreservation:
+    """Chunk metadata must be echoed on every outgoing triples batch."""
+
+    @pytest.mark.asyncio
+    async def test_metadata_forwarded_to_all_batches(self):
+        """Every batch should carry the original chunk metadata."""
+        proc = _make_processor(triples_batch_size=2)
+        rels = [_make_rel("X", "rel", "Y")]
+        flow, pub, _ = _make_flow(rels)
+        msg = _make_chunk_msg(
+            "text", meta_id="c-1", root="r-1",
+            user="u-1", collection="coll-1",
+        )
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        for triples_msg in _sent_triples(pub):
+            assert triples_msg.metadata.id == "c-1"
+            assert triples_msg.metadata.root == "r-1"
+            assert triples_msg.metadata.user == "u-1"
+            assert triples_msg.metadata.collection == "coll-1"
+
+
+class TestRelationshipTriples:
+    """Shape of emitted triples: IRI vs literal objects and label triples
+    for subject/predicate (and object, when it is an entity)."""
+
+    @pytest.mark.asyncio
+    async def test_entity_object_produces_iri(self):
+        """object-entity=True → object is an IRI, with label triple."""
+        proc = _make_processor(triples_batch_size=200)
+        rels = [_make_rel("Alice", "knows", "Bob", object_entity=True)]
+        flow, pub, _ = _make_flow(rels)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        all_t = _all_triples_flat(pub)
+        # Find the relationship triple (not a label)
+        # NOTE: IRIs are lower-cased by to_uri, hence "bob" not "Bob".
+        rel_triples = [
+            t for t in all_t
+            if t.o.type == IRI and "bob" in t.o.iri
+        ]
+        assert len(rel_triples) >= 1
+
+    @pytest.mark.asyncio
+    async def test_literal_object_produces_literal(self):
+        """object-entity=False → object is a LITERAL, no label for object."""
+        proc = _make_processor(triples_batch_size=200)
+        rels = [_make_rel("Alice", "age", "30", object_entity=False)]
+        flow, pub, _ = _make_flow(rels)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        all_t = _all_triples_flat(pub)
+        # Find the relationship triple with literal object
+        lit_triples = [
+            t for t in all_t
+            if t.o.type == LITERAL and t.o.value == "30"
+        ]
+        assert len(lit_triples) == 1
+
+    @pytest.mark.asyncio
+    async def test_labels_emitted_for_subject_and_predicate(self):
+        """Every relationship should produce label triples for s and p."""
+        proc = _make_processor(triples_batch_size=200)
+        rels = [_make_rel("Alice", "knows", "Bob")]
+        flow, pub, _ = _make_flow(rels)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        all_t = _all_triples_flat(pub)
+        label_triples = [
+            t for t in all_t
+            if t.p.type == IRI and "label" in t.p.iri.lower()
+        ]
+        labels = {t.o.value for t in label_triples}
+        assert "Alice" in labels
+        assert "knows" in labels
+        assert "Bob" in labels  # object-entity default is True
+
+
+class TestEmptyAndNullFiltering:
+    """Relationships with empty or None subject/predicate/object are
+    skipped; fully-filtered input emits nothing."""
+
+    @pytest.mark.asyncio
+    async def test_empty_string_fields_skipped(self):
+        """Relationships with empty string s/p/o are skipped."""
+        proc = _make_processor(triples_batch_size=200)
+        rels = [
+            _make_rel("", "knows", "Bob"),
+            _make_rel("Alice", "", "Bob"),
+            _make_rel("Alice", "knows", ""),
+            _make_rel("Good", "triple", "Here"),
+        ]
+        flow, pub, _ = _make_flow(rels)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        all_t = _all_triples_flat(pub)
+        # Only the "Good triple Here" relationship should produce content triples
+        # (IRIs are lower-cased, so substrings are lowercase).
+        rel_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri") and t.s.iri}
+        assert any("good" in iri for iri in rel_iris)
+        assert not any("alice" in iri for iri in rel_iris)
+
+    @pytest.mark.asyncio
+    async def test_none_fields_skipped(self):
+        """Relationships with None s/p/o are skipped."""
+        proc = _make_processor(triples_batch_size=200)
+        rels = [
+            _make_rel(None, "knows", "Bob"),
+            _make_rel("Alice", None, "Bob"),
+            _make_rel("Alice", "knows", None),
+            _make_rel("Valid", "rel", "Here"),
+        ]
+        flow, pub, _ = _make_flow(rels)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        all_t = _all_triples_flat(pub)
+        rel_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri") and t.s.iri}
+        assert any("valid" in iri for iri in rel_iris)
+        assert not any("alice" in iri for iri in rel_iris)
+
+    @pytest.mark.asyncio
+    async def test_all_filtered_produces_no_output(self):
+        """If all relationships are empty/null, nothing is emitted."""
+        proc = _make_processor(triples_batch_size=200)
+        rels = [
+            _make_rel("", "", ""),
+            _make_rel(None, None, None),
+        ]
+        flow, pub, _ = _make_flow(rels)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        assert pub.send.call_count == 0
+
+    @pytest.mark.asyncio
+    async def test_empty_prompt_response_produces_no_output(self):
+        """Empty relationship list from prompt → no triples emitted."""
+        proc = _make_processor()
+        flow, pub, _ = _make_flow([])
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        assert pub.send.call_count == 0
+
+
+class TestProvenanceInclusion:
+    """Provenance triples accompany extracted content, and are omitted
+    when nothing was extracted."""
+
+    @pytest.mark.asyncio
+    async def test_provenance_triples_present(self):
+        """Extracted relationships should include provenance triples."""
+        proc = _make_processor(triples_batch_size=200)
+        rels = [_make_rel("A", "knows", "B")]
+        flow, pub, _ = _make_flow(rels)
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        all_t = _all_triples_flat(pub)
+        # Provenance triples use GRAPH_SOURCE graph context
+        # They contain terms referencing prov: namespace or subgraph URIs
+        # We just check that total count > 4 (1 rel + 3 labels)
+        assert len(all_t) > 4
+
+    @pytest.mark.asyncio
+    async def test_no_provenance_when_no_extracted_triples(self):
+        """Empty relationships → no provenance generated."""
+        proc = _make_processor()
+        flow, pub, _ = _make_flow([_make_rel("", "x", "y")])
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        assert pub.send.call_count == 0
+
+
+class TestErrorPropagation:
+    """Prompt failures and malformed responses are caught by on_message
+    (logged, not raised) and produce no output."""
+
+    @pytest.mark.asyncio
+    async def test_prompt_error_is_caught(self):
+        """Errors from the prompt client are caught (logged, not raised)."""
+        proc = _make_processor()
+        flow, pub, prompt = _make_flow([])
+        prompt.extract_relationships = AsyncMock(
+            side_effect=RuntimeError("LLM unavailable")
+        )
+        msg = _make_chunk_msg("text")
+
+        # The outer try/except in on_message catches and logs
+        await proc.on_message(msg, MagicMock(), flow)
+
+        assert pub.send.call_count == 0
+
+    @pytest.mark.asyncio
+    async def test_non_list_response_is_caught(self):
+        """Non-list prompt response triggers RuntimeError, caught by handler."""
+        proc = _make_processor()
+        flow, pub, prompt = _make_flow("not a list")
+        msg = _make_chunk_msg("text")
+
+        await proc.on_message(msg, MagicMock(), flow)
+
+        assert pub.send.call_count == 0
+
+
+class TestToUri:
+    """URI construction: hyphenated, lower-cased, percent-encoded names."""
+
+    def test_spaces_replaced_with_hyphens(self):
+        proc = _make_processor()
+        uri = proc.to_uri("hello world")
+        assert "hello-world" in uri
+
+    def test_lowercased(self):
+        proc = _make_processor()
+        uri = proc.to_uri("Hello World")
+        assert "hello-world" in uri
+
+    def test_special_chars_encoded(self):
+        proc = _make_processor()
+        # urllib.parse.quote keeps / as safe by default
+        uri = proc.to_uri("a/b")
+        assert "a/b" in uri
+        # Characters like spaces are encoded (handled via replace → hyphen)
+        uri2 = proc.to_uri("hello world")
+        assert " " not in uri2
diff --git a/tests/unit/test_librarian/__init__.py b/tests/unit/test_librarian/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/test_librarian/test_chunked_upload.py b/tests/unit/test_librarian/test_chunked_upload.py
new file mode 100644
index 00000000..2d09bcd4
--- /dev/null
+++ b/tests/unit/test_librarian/test_chunked_upload.py
@@ -0,0 +1,716 @@
+"""
+Tests for librarian chunked upload operations:
+begin_upload, upload_chunk, complete_upload, abort_upload, get_upload_status,
+list_uploads, and stream_document.
+"""
+
+import base64
+import json
+import math
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from trustgraph.librarian.librarian import Librarian, DEFAULT_CHUNK_SIZE
+from trustgraph.exceptions import RequestError
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
def _make_librarian(min_chunk_size=1):
    """Create a Librarian with mocked blob_store and table_store."""
    # __new__ bypasses __init__ so no real storage connections are made.
    librarian = Librarian.__new__(Librarian)
    librarian.blob_store = MagicMock()
    librarian.table_store = AsyncMock()
    librarian.load_document = AsyncMock()
    librarian.min_chunk_size = min_chunk_size
    return librarian
+
+
+def _make_doc_metadata(
+ doc_id="doc-1", kind="application/pdf", user="alice", title="Test Doc"
+):
+ meta = MagicMock()
+ meta.id = doc_id
+ meta.kind = kind
+ meta.user = user
+ meta.title = title
+ meta.time = 1700000000
+ meta.comments = ""
+ meta.tags = []
+ return meta
+
+
def _make_begin_request(
    doc_id="doc-1", kind="application/pdf", user="alice",
    total_size=10_000_000, chunk_size=0
):
    """Build a mock begin_upload request carrying document metadata."""
    request = MagicMock()
    request.document_metadata = _make_doc_metadata(
        doc_id=doc_id, kind=kind, user=user,
    )
    request.total_size = total_size
    # chunk_size=0 means "let the librarian pick the default".
    request.chunk_size = chunk_size
    return request
+
+
+def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, user="alice", content=b"data"):
+ req = MagicMock()
+ req.upload_id = upload_id
+ req.chunk_index = chunk_index
+ req.user = user
+ req.content = base64.b64encode(content)
+ return req
+
+
+def _make_session(
+ user="alice", total_chunks=5, chunk_size=2_000_000,
+ total_size=10_000_000, chunks_received=None, object_id="obj-1",
+ s3_upload_id="s3-up-1", document_metadata=None, document_id="doc-1",
+):
+ if chunks_received is None:
+ chunks_received = {}
+ if document_metadata is None:
+ document_metadata = json.dumps({
+ "id": document_id, "kind": "application/pdf",
+ "user": user, "title": "Test", "time": 1700000000,
+ "comments": "", "tags": [],
+ })
+ return {
+ "user": user,
+ "total_chunks": total_chunks,
+ "chunk_size": chunk_size,
+ "total_size": total_size,
+ "chunks_received": chunks_received,
+ "object_id": object_id,
+ "s3_upload_id": s3_upload_id,
+ "document_metadata": document_metadata,
+ "document_id": document_id,
+ }
+
+
+# ---------------------------------------------------------------------------
+# begin_upload
+# ---------------------------------------------------------------------------
+
class TestBeginUpload:
    """Tests for Librarian.begin_upload: session creation and input validation."""

    @pytest.mark.asyncio
    async def test_creates_session(self):
        """A valid request returns an upload_id with default-sized chunks."""
        lib = _make_librarian()
        lib.table_store.document_exists.return_value = False
        lib.blob_store.create_multipart_upload.return_value = "s3-upload-id"

        req = _make_begin_request(total_size=10_000_000)
        resp = await lib.begin_upload(req)

        assert resp.error is None
        assert resp.upload_id is not None
        # chunk_size=0 in the request means "use DEFAULT_CHUNK_SIZE"
        assert resp.total_chunks == math.ceil(10_000_000 / DEFAULT_CHUNK_SIZE)
        assert resp.chunk_size == DEFAULT_CHUNK_SIZE

    @pytest.mark.asyncio
    async def test_custom_chunk_size(self):
        """A positive requested chunk_size overrides the default."""
        lib = _make_librarian()
        lib.table_store.document_exists.return_value = False
        lib.blob_store.create_multipart_upload.return_value = "s3-id"

        req = _make_begin_request(total_size=10_000, chunk_size=3000)
        resp = await lib.begin_upload(req)

        assert resp.chunk_size == 3000
        assert resp.total_chunks == math.ceil(10_000 / 3000)

    @pytest.mark.asyncio
    async def test_rejects_invalid_kind(self):
        """Unsupported MIME kinds are rejected before any session is created."""
        lib = _make_librarian()
        req = _make_begin_request(kind="image/png")

        with pytest.raises(RequestError, match="Invalid document kind"):
            await lib.begin_upload(req)

    @pytest.mark.asyncio
    async def test_rejects_duplicate_document(self):
        """An upload for an already-stored document id is rejected."""
        lib = _make_librarian()
        lib.table_store.document_exists.return_value = True

        req = _make_begin_request()
        with pytest.raises(RequestError, match="already exists"):
            await lib.begin_upload(req)

    @pytest.mark.asyncio
    async def test_rejects_zero_size(self):
        """total_size must be positive."""
        lib = _make_librarian()
        lib.table_store.document_exists.return_value = False

        req = _make_begin_request(total_size=0)
        with pytest.raises(RequestError, match="positive"):
            await lib.begin_upload(req)

    @pytest.mark.asyncio
    async def test_rejects_chunk_below_minimum(self):
        """A requested chunk_size below the librarian's minimum is rejected."""
        lib = _make_librarian(min_chunk_size=1024)
        lib.table_store.document_exists.return_value = False

        req = _make_begin_request(total_size=10_000, chunk_size=512)
        with pytest.raises(RequestError, match="below minimum"):
            await lib.begin_upload(req)

    @pytest.mark.asyncio
    async def test_calls_s3_create_multipart(self):
        """begin_upload starts an S3 multipart upload with the document kind."""
        lib = _make_librarian()
        lib.table_store.document_exists.return_value = False
        lib.blob_store.create_multipart_upload.return_value = "s3-id"

        req = _make_begin_request(kind="application/pdf")
        await lib.begin_upload(req)

        lib.blob_store.create_multipart_upload.assert_called_once()
        # create_multipart_upload(object_id, kind) — positional args
        args = lib.blob_store.create_multipart_upload.call_args[0]
        assert args[1] == "application/pdf"

    @pytest.mark.asyncio
    async def test_stores_session_in_cassandra(self):
        """Session bookkeeping is persisted via table_store.create_upload_session."""
        lib = _make_librarian()
        lib.table_store.document_exists.return_value = False
        lib.blob_store.create_multipart_upload.return_value = "s3-id"

        req = _make_begin_request(total_size=5_000_000)
        resp = await lib.begin_upload(req)

        lib.table_store.create_upload_session.assert_called_once()
        kwargs = lib.table_store.create_upload_session.call_args[1]
        assert kwargs["upload_id"] == resp.upload_id
        assert kwargs["total_size"] == 5_000_000
        assert kwargs["total_chunks"] == resp.total_chunks

    @pytest.mark.asyncio
    async def test_accepts_text_plain(self):
        """text/plain is an accepted document kind."""
        lib = _make_librarian()
        lib.table_store.document_exists.return_value = False
        lib.blob_store.create_multipart_upload.return_value = "s3-id"

        req = _make_begin_request(kind="text/plain", total_size=1000)
        resp = await lib.begin_upload(req)
        assert resp.error is None
+
+
+# ---------------------------------------------------------------------------
+# upload_chunk
+# ---------------------------------------------------------------------------
+
class TestUploadChunk:
    """Tests for Librarian.upload_chunk: part upload, validation, progress."""

    @pytest.mark.asyncio
    async def test_successful_chunk_upload(self):
        """A valid chunk is accepted and progress counters are reported."""
        lib = _make_librarian()
        session = _make_session(total_chunks=5, chunks_received={})
        lib.table_store.get_upload_session.return_value = session
        lib.blob_store.upload_part.return_value = "etag-1"

        req = _make_upload_chunk_request(chunk_index=0, content=b"chunk data")
        resp = await lib.upload_chunk(req)

        assert resp.error is None
        assert resp.chunk_index == 0
        assert resp.total_chunks == 5
        # The chunk is added to the dict (len=1), then +1 applied => 2
        # NOTE(review): this looks like the new chunk is counted twice
        # (once via len() after insert, once via +1) — confirm whether the
        # service intends chunks_received to be len+1 or this is an off-by-one.
        assert resp.chunks_received == 2

    @pytest.mark.asyncio
    async def test_s3_part_number_is_1_indexed(self):
        """S3 part numbers start at 1 while chunk indices start at 0."""
        lib = _make_librarian()
        session = _make_session()
        lib.table_store.get_upload_session.return_value = session
        lib.blob_store.upload_part.return_value = "etag"

        req = _make_upload_chunk_request(chunk_index=0)
        await lib.upload_chunk(req)

        kwargs = lib.blob_store.upload_part.call_args[1]
        assert kwargs["part_number"] == 1 # 0-indexed chunk → 1-indexed part

    @pytest.mark.asyncio
    async def test_chunk_index_3_becomes_part_4(self):
        """Part number is always chunk_index + 1."""
        lib = _make_librarian()
        session = _make_session()
        lib.table_store.get_upload_session.return_value = session
        lib.blob_store.upload_part.return_value = "etag"

        req = _make_upload_chunk_request(chunk_index=3)
        await lib.upload_chunk(req)

        kwargs = lib.blob_store.upload_part.call_args[1]
        assert kwargs["part_number"] == 4

    @pytest.mark.asyncio
    async def test_rejects_expired_session(self):
        """A missing (expired) session yields a 'not found' error."""
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = None

        req = _make_upload_chunk_request()
        with pytest.raises(RequestError, match="not found"):
            await lib.upload_chunk(req)

    @pytest.mark.asyncio
    async def test_rejects_wrong_user(self):
        """Only the session owner may upload chunks."""
        lib = _make_librarian()
        session = _make_session(user="alice")
        lib.table_store.get_upload_session.return_value = session

        req = _make_upload_chunk_request(user="bob")
        with pytest.raises(RequestError, match="Not authorized"):
            await lib.upload_chunk(req)

    @pytest.mark.asyncio
    async def test_rejects_negative_chunk_index(self):
        """chunk_index below 0 is invalid."""
        lib = _make_librarian()
        session = _make_session(total_chunks=5)
        lib.table_store.get_upload_session.return_value = session

        req = _make_upload_chunk_request(chunk_index=-1)
        with pytest.raises(RequestError, match="Invalid chunk index"):
            await lib.upload_chunk(req)

    @pytest.mark.asyncio
    async def test_rejects_out_of_range_chunk_index(self):
        """chunk_index >= total_chunks is invalid."""
        lib = _make_librarian()
        session = _make_session(total_chunks=5)
        lib.table_store.get_upload_session.return_value = session

        req = _make_upload_chunk_request(chunk_index=5)
        with pytest.raises(RequestError, match="Invalid chunk index"):
            await lib.upload_chunk(req)

    @pytest.mark.asyncio
    async def test_progress_tracking(self):
        """Progress counters account for previously-received chunks."""
        lib = _make_librarian()
        session = _make_session(
            total_chunks=4, chunk_size=1000, total_size=3500,
            chunks_received={0: "e1", 1: "e2"},
        )
        lib.table_store.get_upload_session.return_value = session
        lib.blob_store.upload_part.return_value = "e3"

        req = _make_upload_chunk_request(chunk_index=2)
        resp = await lib.upload_chunk(req)

        # Dict gets chunk 2 added (len=3), then +1 => 4
        assert resp.chunks_received == 4
        assert resp.total_chunks == 4
        assert resp.total_bytes == 3500

    @pytest.mark.asyncio
    async def test_bytes_capped_at_total_size(self):
        """bytes_received should not exceed total_size for the final chunk."""
        lib = _make_librarian()
        session = _make_session(
            total_chunks=2, chunk_size=3000, total_size=5000,
            chunks_received={0: "e1"},
        )
        lib.table_store.get_upload_session.return_value = session
        lib.blob_store.upload_part.return_value = "e2"

        req = _make_upload_chunk_request(chunk_index=1)
        resp = await lib.upload_chunk(req)

        # 3 chunks × 3000 = 9000 > 5000, so capped
        assert resp.bytes_received <= 5000

    @pytest.mark.asyncio
    async def test_base64_decodes_content(self):
        """The wire content is base64; the blob store receives raw bytes."""
        lib = _make_librarian()
        session = _make_session()
        lib.table_store.get_upload_session.return_value = session
        lib.blob_store.upload_part.return_value = "etag"

        raw = b"hello world binary data"
        req = _make_upload_chunk_request(content=raw)
        await lib.upload_chunk(req)

        kwargs = lib.blob_store.upload_part.call_args[1]
        assert kwargs["data"] == raw
+
+
+# ---------------------------------------------------------------------------
+# complete_upload
+# ---------------------------------------------------------------------------
+
class TestCompleteUpload:
    """Tests for Librarian.complete_upload: assembly, ordering, validation."""

    @staticmethod
    def _request(upload_id="up-1", user="alice"):
        # Shared builder for the small complete_upload request object.
        request = MagicMock()
        request.upload_id = upload_id
        request.user = user
        return request

    @pytest.mark.asyncio
    async def test_successful_completion(self):
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = _make_session(
            total_chunks=3,
            chunks_received={0: "e1", 1: "e2", 2: "e3"},
        )

        resp = await lib.complete_upload(self._request())

        assert resp.error is None
        assert resp.document_id == "doc-1"
        lib.blob_store.complete_multipart_upload.assert_called_once()
        lib.table_store.add_document.assert_called_once()
        lib.table_store.delete_upload_session.assert_called_once_with("up-1")

    @pytest.mark.asyncio
    async def test_parts_sorted_by_index(self):
        lib = _make_librarian()
        # Chunks arrived out of order; completion must sort them.
        lib.table_store.get_upload_session.return_value = _make_session(
            total_chunks=3,
            chunks_received={2: "e3", 0: "e1", 1: "e2"},
        )

        await lib.complete_upload(self._request())

        parts = lib.blob_store.complete_multipart_upload.call_args[1]["parts"]
        # Part numbers must come out sorted and 1-indexed.
        assert [number for number, _ in parts] == [1, 2, 3]

    @pytest.mark.asyncio
    async def test_rejects_missing_chunks(self):
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = _make_session(
            total_chunks=3,
            chunks_received={0: "e1", 2: "e3"},  # chunk 1 missing
        )

        with pytest.raises(RequestError, match="Missing chunks"):
            await lib.complete_upload(self._request())

    @pytest.mark.asyncio
    async def test_rejects_expired_session(self):
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = None

        with pytest.raises(RequestError, match="not found"):
            await lib.complete_upload(self._request(upload_id="up-gone"))

    @pytest.mark.asyncio
    async def test_rejects_wrong_user(self):
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = _make_session(user="alice")

        with pytest.raises(RequestError, match="Not authorized"):
            await lib.complete_upload(self._request(user="bob"))
+
+
+# ---------------------------------------------------------------------------
+# abort_upload
+# ---------------------------------------------------------------------------
+
class TestAbortUpload:
    """Tests for Librarian.abort_upload: S3 abort plus session cleanup."""

    @staticmethod
    def _request(upload_id="up-1", user="alice"):
        request = MagicMock()
        request.upload_id = upload_id
        request.user = user
        return request

    @pytest.mark.asyncio
    async def test_aborts_and_cleans_up(self):
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = _make_session()

        resp = await lib.abort_upload(self._request())

        assert resp.error is None
        # Both the S3 multipart upload and the session row are removed.
        lib.blob_store.abort_multipart_upload.assert_called_once_with(
            object_id="obj-1", upload_id="s3-up-1"
        )
        lib.table_store.delete_upload_session.assert_called_once_with("up-1")

    @pytest.mark.asyncio
    async def test_rejects_expired_session(self):
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = None

        with pytest.raises(RequestError, match="not found"):
            await lib.abort_upload(self._request(upload_id="up-gone"))

    @pytest.mark.asyncio
    async def test_rejects_wrong_user(self):
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = _make_session(user="alice")

        with pytest.raises(RequestError, match="Not authorized"):
            await lib.abort_upload(self._request(user="bob"))
+
+
+# ---------------------------------------------------------------------------
+# get_upload_status
+# ---------------------------------------------------------------------------
+
class TestGetUploadStatus:
    """Tests for Librarian.get_upload_status: progress and state reporting."""

    @staticmethod
    def _request(upload_id="up-1", user="alice"):
        request = MagicMock()
        request.upload_id = upload_id
        request.user = user
        return request

    @pytest.mark.asyncio
    async def test_in_progress_status(self):
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = _make_session(
            total_chunks=5, chunk_size=2000, total_size=10_000,
            chunks_received={0: "e1", 2: "e3", 4: "e5"},
        )

        resp = await lib.get_upload_status(self._request())

        assert resp.upload_state == "in-progress"
        assert resp.chunks_received == 3
        assert resp.total_chunks == 5
        assert sorted(resp.received_chunks) == [0, 2, 4]
        assert sorted(resp.missing_chunks) == [1, 3]
        assert resp.total_bytes == 10_000

    @pytest.mark.asyncio
    async def test_expired_session(self):
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = None

        resp = await lib.get_upload_status(self._request(upload_id="up-expired"))

        assert resp.upload_state == "expired"

    @pytest.mark.asyncio
    async def test_all_chunks_received(self):
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = _make_session(
            total_chunks=3, chunk_size=1000, total_size=2500,
            chunks_received={0: "e1", 1: "e2", 2: "e3"},
        )

        resp = await lib.get_upload_status(self._request())

        assert resp.missing_chunks == []
        assert resp.chunks_received == 3
        # 3 chunks x 1000 bytes = 3000, but total_size is 2500: capped
        assert resp.bytes_received <= 2500

    @pytest.mark.asyncio
    async def test_rejects_wrong_user(self):
        lib = _make_librarian()
        lib.table_store.get_upload_session.return_value = _make_session(user="alice")

        with pytest.raises(RequestError, match="Not authorized"):
            await lib.get_upload_status(self._request(user="bob"))
+
+
+# ---------------------------------------------------------------------------
+# stream_document
+# ---------------------------------------------------------------------------
+
class TestStreamDocument:
    """Tests for Librarian.stream_document: chunked download streaming."""

    @pytest.mark.asyncio
    async def test_streams_chunks_with_progress(self):
        """A 5000-byte document at chunk_size 2000 streams as 3 chunks."""
        lib = _make_librarian()
        lib.table_store.get_document_object_id.return_value = "obj-1"
        lib.blob_store.get_size = AsyncMock(return_value=5000)
        lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000)

        req = MagicMock()
        req.user = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 2000

        chunks = []
        async for resp in lib.stream_document(req):
            chunks.append(resp)

        assert len(chunks) == 3 # ceil(5000/2000)
        assert chunks[0].chunk_index == 0
        assert chunks[0].total_chunks == 3
        # Only the last chunk carries is_final=True.
        assert chunks[0].is_final is False
        assert chunks[-1].is_final is True
        assert chunks[-1].chunk_index == 2

    @pytest.mark.asyncio
    async def test_single_chunk_document(self):
        """A document smaller than chunk_size arrives as one final chunk."""
        lib = _make_librarian()
        lib.table_store.get_document_object_id.return_value = "obj-1"
        lib.blob_store.get_size = AsyncMock(return_value=500)
        lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500)

        req = MagicMock()
        req.user = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 2000

        chunks = []
        async for resp in lib.stream_document(req):
            chunks.append(resp)

        assert len(chunks) == 1
        assert chunks[0].is_final is True
        assert chunks[0].bytes_received == 500
        assert chunks[0].total_bytes == 500

    @pytest.mark.asyncio
    async def test_byte_ranges_correct(self):
        """get_range is called with (object_id, offset, length) per chunk."""
        lib = _make_librarian()
        lib.table_store.get_document_object_id.return_value = "obj-1"
        lib.blob_store.get_size = AsyncMock(return_value=5000)
        lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100)

        req = MagicMock()
        req.user = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 2000

        chunks = []
        async for resp in lib.stream_document(req):
            chunks.append(resp)

        # Verify the byte ranges passed to get_range
        calls = lib.blob_store.get_range.call_args_list
        assert calls[0][0] == ("obj-1", 0, 2000)
        assert calls[1][0] == ("obj-1", 2000, 2000)
        assert calls[2][0] == ("obj-1", 4000, 1000) # Last chunk: 5000-4000

    @pytest.mark.asyncio
    async def test_default_chunk_size(self):
        """chunk_size=0 in the request falls back to the default size."""
        lib = _make_librarian()
        lib.table_store.get_document_object_id.return_value = "obj-1"
        lib.blob_store.get_size = AsyncMock(return_value=2_000_000)
        lib.blob_store.get_range = AsyncMock(return_value=b"x")

        req = MagicMock()
        req.user = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 0 # Should use default 1MB

        chunks = []
        async for resp in lib.stream_document(req):
            chunks.append(resp)

        assert len(chunks) == 2 # ceil(2MB / 1MB)

    @pytest.mark.asyncio
    async def test_content_is_base64_encoded(self):
        """Streamed chunk content is base64-encoded for the wire."""
        lib = _make_librarian()
        lib.table_store.get_document_object_id.return_value = "obj-1"
        lib.blob_store.get_size = AsyncMock(return_value=100)
        raw = b"hello world"
        lib.blob_store.get_range = AsyncMock(return_value=raw)

        req = MagicMock()
        req.user = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 1000

        chunks = []
        async for resp in lib.stream_document(req):
            chunks.append(resp)

        assert chunks[0].content == base64.b64encode(raw)

    @pytest.mark.asyncio
    async def test_rejects_chunk_below_minimum(self):
        """A requested chunk_size below min_chunk_size raises RequestError."""
        lib = _make_librarian(min_chunk_size=1024)
        lib.table_store.get_document_object_id.return_value = "obj-1"
        lib.blob_store.get_size = AsyncMock(return_value=5000)

        req = MagicMock()
        req.user = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 512

        with pytest.raises(RequestError, match="below minimum"):
            async for _ in lib.stream_document(req):
                pass
+
+
+# ---------------------------------------------------------------------------
+# list_uploads
+# ---------------------------------------------------------------------------
+
class TestListUploads:
    """Tests for Librarian.list_uploads: enumerating a user's sessions."""

    @pytest.mark.asyncio
    async def test_returns_sessions(self):
        lib = _make_librarian()
        session_row = {
            "upload_id": "up-1",
            "document_id": "doc-1",
            "document_metadata": '{"id":"doc-1"}',
            "total_size": 10000,
            "chunk_size": 2000,
            "total_chunks": 5,
            "chunks_received": {0: "e1", 1: "e2"},
            "created_at": "2024-01-01",
        }
        lib.table_store.list_upload_sessions.return_value = [session_row]

        request = MagicMock()
        request.user = "alice"

        resp = await lib.list_uploads(request)

        assert resp.error is None
        assert len(resp.upload_sessions) == 1
        first = resp.upload_sessions[0]
        assert first.upload_id == "up-1"
        assert first.total_chunks == 5

    @pytest.mark.asyncio
    async def test_empty_uploads(self):
        lib = _make_librarian()
        lib.table_store.list_upload_sessions.return_value = []

        request = MagicMock()
        request.user = "alice"

        resp = await lib.list_uploads(request)

        assert resp.upload_sessions == []
diff --git a/tests/unit/test_provenance/__init__.py b/tests/unit/test_provenance/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/test_provenance/test_agent_provenance.py b/tests/unit/test_provenance/test_agent_provenance.py
new file mode 100644
index 00000000..9377fe19
--- /dev/null
+++ b/tests/unit/test_provenance/test_agent_provenance.py
@@ -0,0 +1,324 @@
+"""
+Tests for agent provenance triple builder functions.
+"""
+
+import json
+import pytest
+
+from trustgraph.schema import Triple, Term, IRI, LITERAL
+
+from trustgraph.provenance.agent import (
+ agent_session_triples,
+ agent_iteration_triples,
+ agent_final_triples,
+)
+
+from trustgraph.provenance.namespaces import (
+ RDF_TYPE, RDFS_LABEL,
+ PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME,
+ TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
+ TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
+ TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
+ TG_AGENT_QUESTION,
+)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
def find_triple(triples, predicate, subject=None):
    """Return the first triple whose predicate IRI matches (optionally
    filtered by subject IRI), or None if there is no match."""
    return next(
        (t for t in triples
         if t.p.iri == predicate and (subject is None or t.s.iri == subject)),
        None,
    )
+
+
def find_triples(triples, predicate, subject=None):
    """Return every triple with the given predicate IRI, optionally
    restricted to a particular subject IRI."""
    matched = []
    for triple in triples:
        if triple.p.iri != predicate:
            continue
        if subject is not None and triple.s.iri != subject:
            continue
        matched.append(triple)
    return matched
+
+
def has_type(triples, subject, rdf_type):
    """True if triples assert (subject, rdf:type, rdf_type) with an IRI object."""
    return any(
        t.s.iri == subject
        and t.p.iri == RDF_TYPE
        and t.o.type == IRI
        and t.o.iri == rdf_type
        for t in triples
    )
+
+
+# ---------------------------------------------------------------------------
+# agent_session_triples
+# ---------------------------------------------------------------------------
+
class TestAgentSessionTriples:
    """Tests for agent_session_triples: the root provenance activity."""

    SESSION_URI = "urn:trustgraph:agent:test-session"

    def test_session_types(self):
        triples = agent_session_triples(
            self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z"
        )
        # A session is a PROV activity and both question types.
        for rdf_type in (PROV_ACTIVITY, TG_QUESTION, TG_AGENT_QUESTION):
            assert has_type(triples, self.SESSION_URI, rdf_type)

    def test_session_query_text(self):
        triples = agent_session_triples(
            self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z"
        )
        query_triple = find_triple(triples, TG_QUERY, self.SESSION_URI)
        assert query_triple is not None
        assert query_triple.o.value == "What is X?"

    def test_session_timestamp(self):
        triples = agent_session_triples(
            self.SESSION_URI, "Q", "2024-06-15T10:00:00Z"
        )
        started = find_triple(triples, PROV_STARTED_AT_TIME, self.SESSION_URI)
        assert started is not None
        assert started.o.value == "2024-06-15T10:00:00Z"

    def test_session_default_timestamp(self):
        # Omitting the timestamp still records a non-empty start time.
        triples = agent_session_triples(self.SESSION_URI, "Q")
        started = find_triple(triples, PROV_STARTED_AT_TIME, self.SESSION_URI)
        assert started is not None
        assert started.o.value

    def test_session_label(self):
        triples = agent_session_triples(
            self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
        )
        label = find_triple(triples, RDFS_LABEL, self.SESSION_URI)
        assert label is not None
        assert label.o.value == "Agent Question"

    def test_session_triple_count(self):
        triples = agent_session_triples(
            self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
        )
        assert len(triples) == 6
+
+
+# ---------------------------------------------------------------------------
+# agent_iteration_triples
+# ---------------------------------------------------------------------------
+
class TestAgentIterationTriples:
    """Tests for agent_iteration_triples: per-iteration provenance entities."""

    ITER_URI = "urn:trustgraph:agent:test-session/i1"
    PARENT_URI = "urn:trustgraph:agent:test-session"

    def test_iteration_types(self):
        """An iteration is typed as both a PROV entity and a tg Analysis."""
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            thought="thinking", action="search", observation="found it",
        )
        assert has_type(triples, self.ITER_URI, PROV_ENTITY)
        assert has_type(triples, self.ITER_URI, TG_ANALYSIS)

    def test_iteration_derived_from_parent(self):
        """Each iteration is prov:wasDerivedFrom its parent URI."""
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="search",
        )
        derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
        assert derived is not None
        assert derived.o.iri == self.PARENT_URI

    def test_iteration_label_includes_action(self):
        """The rdfs:label text mentions the action name."""
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="graph-rag-query",
        )
        label = find_triple(triples, RDFS_LABEL, self.ITER_URI)
        assert label is not None
        assert "graph-rag-query" in label.o.value

    def test_iteration_thought_inline(self):
        """Without a thought document id, the thought text is stored inline."""
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            thought="I need to search for info",
            action="search",
        )
        thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
        assert thought is not None
        assert thought.o.value == "I need to search for info"

    def test_iteration_thought_document_preferred(self):
        """When thought_document_id is provided, inline thought is not stored."""
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            thought="inline thought",
            action="search",
            thought_document_id="urn:doc:thought-1",
        )
        thought_doc = find_triple(triples, TG_THOUGHT_DOCUMENT, self.ITER_URI)
        assert thought_doc is not None
        assert thought_doc.o.iri == "urn:doc:thought-1"
        thought_inline = find_triple(triples, TG_THOUGHT, self.ITER_URI)
        assert thought_inline is None

    def test_iteration_observation_inline(self):
        """Without an observation document id, observation text is inline."""
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="search",
            observation="Found 3 results",
        )
        obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
        assert obs is not None
        assert obs.o.value == "Found 3 results"

    def test_iteration_observation_document_preferred(self):
        """A document reference supersedes inline observation text."""
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="search",
            observation="inline obs",
            observation_document_id="urn:doc:obs-1",
        )
        obs_doc = find_triple(triples, TG_OBSERVATION_DOCUMENT, self.ITER_URI)
        assert obs_doc is not None
        assert obs_doc.o.iri == "urn:doc:obs-1"
        obs_inline = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
        assert obs_inline is None

    def test_iteration_action_recorded(self):
        """The action name is stored under tg:action as a literal."""
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="graph-rag-query",
        )
        action = find_triple(triples, TG_ACTION, self.ITER_URI)
        assert action is not None
        assert action.o.value == "graph-rag-query"

    def test_iteration_arguments_json_encoded(self):
        """Tool arguments round-trip through a JSON-encoded literal."""
        args = {"query": "test query", "limit": 10}
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="search",
            arguments=args,
        )
        arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI)
        assert arguments is not None
        parsed = json.loads(arguments.o.value)
        assert parsed == args

    def test_iteration_default_arguments_empty_dict(self):
        """Omitting arguments records an empty JSON object, not nothing."""
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="search",
        )
        arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI)
        assert arguments is not None
        parsed = json.loads(arguments.o.value)
        assert parsed == {}

    def test_iteration_no_thought_or_observation(self):
        """Minimal iteration with just action — no thought or observation triples."""
        triples = agent_iteration_triples(
            self.ITER_URI, self.PARENT_URI,
            action="noop",
        )
        thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
        obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
        assert thought is None
        assert obs is None

    def test_iteration_chaining(self):
        """Second iteration derives from first iteration, not session."""
        iter1_uri = "urn:trustgraph:agent:sess/i1"
        iter2_uri = "urn:trustgraph:agent:sess/i2"

        triples1 = agent_iteration_triples(
            iter1_uri, self.PARENT_URI, action="step1",
        )
        triples2 = agent_iteration_triples(
            iter2_uri, iter1_uri, action="step2",
        )

        derived1 = find_triple(triples1, PROV_WAS_DERIVED_FROM, iter1_uri)
        assert derived1.o.iri == self.PARENT_URI

        derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri)
        assert derived2.o.iri == iter1_uri
+
+
+# ---------------------------------------------------------------------------
+# agent_final_triples
+# ---------------------------------------------------------------------------
+
class TestAgentFinalTriples:
    """Tests for agent_final_triples: the concluding provenance entity."""

    FINAL_URI = "urn:trustgraph:agent:test-session/final"
    PARENT_URI = "urn:trustgraph:agent:test-session/i3"

    def test_final_types(self):
        triples = agent_final_triples(self.FINAL_URI, self.PARENT_URI, answer="42")
        assert has_type(triples, self.FINAL_URI, PROV_ENTITY)
        assert has_type(triples, self.FINAL_URI, TG_CONCLUSION)

    def test_final_derived_from_parent(self):
        triples = agent_final_triples(self.FINAL_URI, self.PARENT_URI, answer="42")
        derivation = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
        assert derivation is not None
        assert derivation.o.iri == self.PARENT_URI

    def test_final_label(self):
        triples = agent_final_triples(self.FINAL_URI, self.PARENT_URI, answer="42")
        label = find_triple(triples, RDFS_LABEL, self.FINAL_URI)
        assert label is not None
        assert label.o.value == "Conclusion"

    def test_final_inline_answer(self):
        triples = agent_final_triples(
            self.FINAL_URI, self.PARENT_URI, answer="The answer is 42"
        )
        answer_triple = find_triple(triples, TG_ANSWER, self.FINAL_URI)
        assert answer_triple is not None
        assert answer_triple.o.value == "The answer is 42"

    def test_final_document_reference(self):
        triples = agent_final_triples(
            self.FINAL_URI, self.PARENT_URI,
            document_id="urn:trustgraph:agent:sess/answer",
        )
        document = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
        assert document is not None
        assert document.o.type == IRI
        assert document.o.iri == "urn:trustgraph:agent:sess/answer"

    def test_final_document_takes_precedence(self):
        # With both supplied, the document reference wins over inline text.
        triples = agent_final_triples(
            self.FINAL_URI, self.PARENT_URI,
            answer="inline",
            document_id="urn:doc:123",
        )
        assert find_triple(triples, TG_DOCUMENT, self.FINAL_URI) is not None
        assert find_triple(triples, TG_ANSWER, self.FINAL_URI) is None

    def test_final_no_answer_or_document(self):
        triples = agent_final_triples(self.FINAL_URI, self.PARENT_URI)
        assert find_triple(triples, TG_ANSWER, self.FINAL_URI) is None
        assert find_triple(triples, TG_DOCUMENT, self.FINAL_URI) is None

    def test_final_derives_from_session_when_no_iterations(self):
        """When agent answers immediately, final derives from session."""
        session_uri = "urn:trustgraph:agent:test-session"
        triples = agent_final_triples(
            self.FINAL_URI, session_uri, answer="direct answer"
        )
        derivation = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
        assert derivation.o.iri == session_uri
diff --git a/tests/unit/test_provenance/test_explainability.py b/tests/unit/test_provenance/test_explainability.py
new file mode 100644
index 00000000..1f27cc61
--- /dev/null
+++ b/tests/unit/test_provenance/test_explainability.py
@@ -0,0 +1,563 @@
+"""
+Tests for the explainability API (entity parsing, wire format conversion,
+and ExplainabilityClient).
+"""
+
+import pytest
+from unittest.mock import MagicMock, patch
+
+from trustgraph.api.explainability import (
+ EdgeSelection,
+ ExplainEntity,
+ Question,
+ Exploration,
+ Focus,
+ Synthesis,
+ Analysis,
+ Conclusion,
+ parse_edge_selection_triples,
+ extract_term_value,
+ wire_triples_to_tuples,
+ ExplainabilityClient,
+ TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
+ TG_CONTENT, TG_DOCUMENT, TG_CHUNK_COUNT,
+ TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
+ TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
+ TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
+ TG_ANALYSIS, TG_CONCLUSION,
+ TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
+ PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
+ RDF_TYPE, RDFS_LABEL,
+)
+
+
+# ---------------------------------------------------------------------------
+# Entity from_triples parsing
+# ---------------------------------------------------------------------------
+
+class TestExplainEntityFromTriples:
+    """Test ExplainEntity.from_triples dispatches to correct subclass."""
+
+    # Each test supplies plain (subject, predicate, object) tuples; the
+    # RDF_TYPE triples drive subclass selection and the remaining
+    # predicates populate the typed fields of the resulting entity.
+
+    def test_graphrag_question(self):
+        triples = [
+            ("urn:q:1", RDF_TYPE, TG_QUESTION),
+            ("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
+            ("urn:q:1", TG_QUERY, "What is AI?"),
+            ("urn:q:1", PROV_STARTED_AT_TIME, "2024-01-01T00:00:00Z"),
+        ]
+        entity = ExplainEntity.from_triples("urn:q:1", triples)
+        assert isinstance(entity, Question)
+        assert entity.query == "What is AI?"
+        assert entity.timestamp == "2024-01-01T00:00:00Z"
+        assert entity.question_type == "graph-rag"
+
+    def test_docrag_question(self):
+        triples = [
+            ("urn:q:2", RDF_TYPE, TG_QUESTION),
+            ("urn:q:2", RDF_TYPE, TG_DOC_RAG_QUESTION),
+            ("urn:q:2", TG_QUERY, "Find info"),
+        ]
+        entity = ExplainEntity.from_triples("urn:q:2", triples)
+        assert isinstance(entity, Question)
+        assert entity.question_type == "document-rag"
+
+    def test_agent_question(self):
+        triples = [
+            ("urn:q:3", RDF_TYPE, TG_QUESTION),
+            ("urn:q:3", RDF_TYPE, TG_AGENT_QUESTION),
+            ("urn:q:3", TG_QUERY, "Agent query"),
+        ]
+        entity = ExplainEntity.from_triples("urn:q:3", triples)
+        assert isinstance(entity, Question)
+        assert entity.question_type == "agent"
+
+    def test_exploration(self):
+        triples = [
+            ("urn:exp:1", RDF_TYPE, TG_EXPLORATION),
+            ("urn:exp:1", TG_EDGE_COUNT, "15"),
+        ]
+        entity = ExplainEntity.from_triples("urn:exp:1", triples)
+        assert isinstance(entity, Exploration)
+        assert entity.edge_count == 15
+
+    def test_exploration_with_chunk_count(self):
+        triples = [
+            ("urn:exp:2", RDF_TYPE, TG_EXPLORATION),
+            ("urn:exp:2", TG_CHUNK_COUNT, "5"),
+        ]
+        entity = ExplainEntity.from_triples("urn:exp:2", triples)
+        assert isinstance(entity, Exploration)
+        assert entity.chunk_count == 5
+
+    def test_exploration_invalid_count(self):
+        # Non-numeric counts must not raise; the parser falls back to 0.
+        triples = [
+            ("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
+            ("urn:exp:3", TG_EDGE_COUNT, "not-a-number"),
+        ]
+        entity = ExplainEntity.from_triples("urn:exp:3", triples)
+        assert isinstance(entity, Exploration)
+        assert entity.edge_count == 0
+
+    def test_focus(self):
+        # Repeated TG_SELECTED_EDGE predicates accumulate into a list.
+        triples = [
+            ("urn:foc:1", RDF_TYPE, TG_FOCUS),
+            ("urn:foc:1", TG_SELECTED_EDGE, "urn:edge:1"),
+            ("urn:foc:1", TG_SELECTED_EDGE, "urn:edge:2"),
+        ]
+        entity = ExplainEntity.from_triples("urn:foc:1", triples)
+        assert isinstance(entity, Focus)
+        assert len(entity.selected_edge_uris) == 2
+        assert "urn:edge:1" in entity.selected_edge_uris
+        assert "urn:edge:2" in entity.selected_edge_uris
+
+    def test_synthesis_with_content(self):
+        triples = [
+            ("urn:syn:1", RDF_TYPE, TG_SYNTHESIS),
+            ("urn:syn:1", TG_CONTENT, "The answer is 42"),
+        ]
+        entity = ExplainEntity.from_triples("urn:syn:1", triples)
+        assert isinstance(entity, Synthesis)
+        assert entity.content == "The answer is 42"
+        assert entity.document_uri == ""
+
+    def test_synthesis_with_document(self):
+        triples = [
+            ("urn:syn:2", RDF_TYPE, TG_SYNTHESIS),
+            ("urn:syn:2", TG_DOCUMENT, "urn:doc:answer-1"),
+        ]
+        entity = ExplainEntity.from_triples("urn:syn:2", triples)
+        assert isinstance(entity, Synthesis)
+        assert entity.document_uri == "urn:doc:answer-1"
+
+    def test_analysis(self):
+        triples = [
+            ("urn:ana:1", RDF_TYPE, TG_ANALYSIS),
+            ("urn:ana:1", TG_THOUGHT, "I should search"),
+            ("urn:ana:1", TG_ACTION, "graph-rag-query"),
+            ("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'),
+            ("urn:ana:1", TG_OBSERVATION, "Found results"),
+        ]
+        entity = ExplainEntity.from_triples("urn:ana:1", triples)
+        assert isinstance(entity, Analysis)
+        assert entity.thought == "I should search"
+        assert entity.action == "graph-rag-query"
+        assert entity.arguments == '{"query": "test"}'
+        assert entity.observation == "Found results"
+
+    def test_analysis_with_document_refs(self):
+        # Thought/observation may live in the librarian rather than inline.
+        triples = [
+            ("urn:ana:2", RDF_TYPE, TG_ANALYSIS),
+            ("urn:ana:2", TG_ACTION, "search"),
+            ("urn:ana:2", TG_THOUGHT_DOCUMENT, "urn:doc:thought-1"),
+            ("urn:ana:2", TG_OBSERVATION_DOCUMENT, "urn:doc:obs-1"),
+        ]
+        entity = ExplainEntity.from_triples("urn:ana:2", triples)
+        assert isinstance(entity, Analysis)
+        assert entity.thought_document_uri == "urn:doc:thought-1"
+        assert entity.observation_document_uri == "urn:doc:obs-1"
+
+    def test_conclusion_with_answer(self):
+        triples = [
+            ("urn:conc:1", RDF_TYPE, TG_CONCLUSION),
+            ("urn:conc:1", TG_ANSWER, "The final answer"),
+        ]
+        entity = ExplainEntity.from_triples("urn:conc:1", triples)
+        assert isinstance(entity, Conclusion)
+        assert entity.answer == "The final answer"
+
+    def test_conclusion_with_document(self):
+        triples = [
+            ("urn:conc:2", RDF_TYPE, TG_CONCLUSION),
+            ("urn:conc:2", TG_DOCUMENT, "urn:doc:final"),
+        ]
+        entity = ExplainEntity.from_triples("urn:conc:2", triples)
+        assert isinstance(entity, Conclusion)
+        assert entity.document_uri == "urn:doc:final"
+
+    def test_unknown_type(self):
+        # Unrecognised RDF types fall back to the generic base entity.
+        triples = [
+            ("urn:x:1", RDF_TYPE, "http://example.com/UnknownType"),
+        ]
+        entity = ExplainEntity.from_triples("urn:x:1", triples)
+        assert isinstance(entity, ExplainEntity)
+        assert entity.entity_type == "unknown"
+
+
+# ---------------------------------------------------------------------------
+# parse_edge_selection_triples
+# ---------------------------------------------------------------------------
+
+class TestParseEdgeSelectionTriples:
+    """Test parse_edge_selection_triples over (s, p, o) tuples.
+
+    The subject URI, TG_EDGE dict, and TG_REASONING string are each
+    optional; missing pieces default to "" / None / "".
+    """
+
+    def test_with_edge_and_reasoning(self):
+        triples = [
+            ("urn:edge:1", TG_EDGE, {"s": "Alice", "p": "knows", "o": "Bob"}),
+            ("urn:edge:1", TG_REASONING, "Alice and Bob are connected"),
+        ]
+        result = parse_edge_selection_triples(triples)
+        assert isinstance(result, EdgeSelection)
+        assert result.uri == "urn:edge:1"
+        assert result.edge == {"s": "Alice", "p": "knows", "o": "Bob"}
+        assert result.reasoning == "Alice and Bob are connected"
+
+    def test_with_edge_only(self):
+        triples = [
+            ("urn:edge:2", TG_EDGE, {"s": "A", "p": "r", "o": "B"}),
+        ]
+        result = parse_edge_selection_triples(triples)
+        assert result.edge is not None
+        assert result.reasoning == ""
+
+    def test_with_reasoning_only(self):
+        triples = [
+            ("urn:edge:3", TG_REASONING, "some reason"),
+        ]
+        result = parse_edge_selection_triples(triples)
+        assert result.edge is None
+        assert result.reasoning == "some reason"
+
+    def test_empty_triples(self):
+        result = parse_edge_selection_triples([])
+        assert result.uri == ""
+        assert result.edge is None
+        assert result.reasoning == ""
+
+    def test_edge_must_be_dict(self):
+        """Non-dict values for TG_EDGE should not be treated as edges."""
+        triples = [
+            ("urn:edge:4", TG_EDGE, "not-a-dict"),
+        ]
+        result = parse_edge_selection_triples(triples)
+        assert result.edge is None
+
+
+# ---------------------------------------------------------------------------
+# extract_term_value
+# ---------------------------------------------------------------------------
+
+class TestExtractTermValue:
+    """Test extract_term_value for both short ("t"/"i"/"v") and long
+    ("type"/"iri"/"value") wire key formats, including quoted triples."""
+
+    def test_iri_short_format(self):
+        assert extract_term_value({"t": "i", "i": "urn:test"}) == "urn:test"
+
+    def test_iri_long_format(self):
+        assert extract_term_value({"type": "i", "iri": "urn:test"}) == "urn:test"
+
+    def test_literal_short_format(self):
+        assert extract_term_value({"t": "l", "v": "hello"}) == "hello"
+
+    def test_literal_long_format(self):
+        assert extract_term_value({"type": "l", "value": "hello"}) == "hello"
+
+    def test_quoted_triple(self):
+        # Quoted (RDF-star style) triples flatten to an {s, p, o} dict.
+        term = {
+            "t": "t",
+            "tr": {
+                "s": {"t": "i", "i": "urn:s"},
+                "p": {"t": "i", "i": "urn:p"},
+                "o": {"t": "i", "i": "urn:o"},
+            }
+        }
+        result = extract_term_value(term)
+        assert result == {"s": "urn:s", "p": "urn:p", "o": "urn:o"}
+
+    def test_quoted_triple_long_format(self):
+        term = {
+            "type": "t",
+            "triple": {
+                "s": {"type": "i", "iri": "urn:s"},
+                "p": {"type": "i", "iri": "urn:p"},
+                "o": {"type": "l", "value": "val"},
+            }
+        }
+        result = extract_term_value(term)
+        assert result == {"s": "urn:s", "p": "urn:p", "o": "val"}
+
+    def test_unknown_type_fallback(self):
+        # An unrecognised type tag still yields whatever value is present.
+        result = extract_term_value({"t": "x", "i": "urn:fallback"})
+        assert result == "urn:fallback"
+
+
+# ---------------------------------------------------------------------------
+# wire_triples_to_tuples
+# ---------------------------------------------------------------------------
+
+class TestWireTriplesToTuples:
+    """Test wire_triples_to_tuples converts a list of wire-format triple
+    dicts into plain (s, p, o) tuples, preserving order."""
+
+    def test_basic_conversion(self):
+        wire = [
+            {
+                "s": {"t": "i", "i": "urn:s1"},
+                "p": {"t": "i", "i": "urn:p1"},
+                "o": {"t": "l", "v": "value1"},
+            },
+        ]
+        result = wire_triples_to_tuples(wire)
+        assert len(result) == 1
+        assert result[0] == ("urn:s1", "urn:p1", "value1")
+
+    def test_multiple_triples(self):
+        wire = [
+            {
+                "s": {"t": "i", "i": "urn:s1"},
+                "p": {"t": "i", "i": "urn:p1"},
+                "o": {"t": "l", "v": "v1"},
+            },
+            {
+                "s": {"t": "i", "i": "urn:s2"},
+                "p": {"t": "i", "i": "urn:p2"},
+                "o": {"t": "i", "i": "urn:o2"},
+            },
+        ]
+        result = wire_triples_to_tuples(wire)
+        assert len(result) == 2
+        assert result[0] == ("urn:s1", "urn:p1", "v1")
+        assert result[1] == ("urn:s2", "urn:p2", "urn:o2")
+
+    def test_empty_list(self):
+        assert wire_triples_to_tuples([]) == []
+
+    def test_missing_fields(self):
+        # Empty term dicts must not raise; a tuple is still produced.
+        wire = [{"s": {}, "p": {}, "o": {}}]
+        result = wire_triples_to_tuples(wire)
+        assert len(result) == 1
+
+
+# ---------------------------------------------------------------------------
+# ExplainabilityClient
+# ---------------------------------------------------------------------------
+
+def _make_wire_triples(tuples):
+    """Convert (s, p, o) tuples to wire format for mocking.
+
+    Subjects and predicates are always encoded as IRIs. Objects use a
+    heuristic: values starting with "urn:" or "http" become IRI terms,
+    everything else becomes a literal term.
+    """
+    result = []
+    for s, p, o in tuples:
+        entry = {
+            "s": {"t": "i", "i": s},
+            "p": {"t": "i", "i": p},
+        }
+        if o.startswith("urn:") or o.startswith("http"):
+            entry["o"] = {"t": "i", "i": o}
+        else:
+            entry["o"] = {"t": "l", "v": o}
+        result.append(entry)
+    return result
+
+
+class TestExplainabilityClientFetchEntity:
+    """Test ExplainabilityClient.fetch_entity.
+
+    The client polls triples_query until two consecutive reads return
+    the same data (quiescence), so mocks queue each expected read.
+    """
+
+    def test_fetch_question_entity(self):
+        wire = _make_wire_triples([
+            ("urn:q:1", RDF_TYPE, TG_QUESTION),
+            ("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
+            ("urn:q:1", TG_QUERY, "What is AI?"),
+            ("urn:q:1", PROV_STARTED_AT_TIME, "2024-01-01T00:00:00Z"),
+        ])
+
+        mock_flow = MagicMock()
+        # Return same results twice for quiescence
+        mock_flow.triples_query.side_effect = [wire, wire]
+
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+        entity = client.fetch_entity("urn:q:1", graph="urn:graph:retrieval")
+
+        assert isinstance(entity, Question)
+        assert entity.query == "What is AI?"
+        assert entity.question_type == "graph-rag"
+
+    def test_fetch_returns_none_when_no_data(self):
+        mock_flow = MagicMock()
+        mock_flow.triples_query.return_value = []
+
+        # max_retries bounds the polling so the test terminates quickly.
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0, max_retries=2)
+        entity = client.fetch_entity("urn:nonexistent")
+
+        assert entity is None
+
+    def test_fetch_retries_on_empty_results(self):
+        wire = _make_wire_triples([
+            ("urn:q:1", RDF_TYPE, TG_QUESTION),
+            ("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
+            ("urn:q:1", TG_QUERY, "Q"),
+        ])
+
+        mock_flow = MagicMock()
+        # Empty, then data, then same data (stable)
+        mock_flow.triples_query.side_effect = [[], wire, wire]
+
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+        entity = client.fetch_entity("urn:q:1")
+
+        assert isinstance(entity, Question)
+        assert mock_flow.triples_query.call_count == 3
+
+
+class TestExplainabilityClientResolveLabel:
+    """Test ExplainabilityClient label resolution: rdfs:label lookup,
+    fallback to the raw URI, per-client caching, and non-URI pass-through."""
+
+    def test_resolve_label_found(self):
+        mock_flow = MagicMock()
+        mock_flow.triples_query.return_value = _make_wire_triples([
+            ("urn:entity:1", RDFS_LABEL, "Entity One"),
+        ])
+
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+        label = client.resolve_label("urn:entity:1")
+        assert label == "Entity One"
+
+    def test_resolve_label_not_found(self):
+        # With no label triple, the URI itself is returned unchanged.
+        mock_flow = MagicMock()
+        mock_flow.triples_query.return_value = []
+
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+        label = client.resolve_label("urn:entity:1")
+        assert label == "urn:entity:1"
+
+    def test_resolve_label_cached(self):
+        mock_flow = MagicMock()
+        mock_flow.triples_query.return_value = _make_wire_triples([
+            ("urn:entity:1", RDFS_LABEL, "Entity One"),
+        ])
+
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+        client.resolve_label("urn:entity:1")
+        client.resolve_label("urn:entity:1")
+
+        # Only one query should be made
+        assert mock_flow.triples_query.call_count == 1
+
+    def test_resolve_label_non_uri(self):
+        # Plain text and empty strings skip the triple store entirely.
+        mock_flow = MagicMock()
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+        assert client.resolve_label("plain text") == "plain text"
+        assert client.resolve_label("") == ""
+        mock_flow.triples_query.assert_not_called()
+
+    def test_resolve_edge_labels(self):
+        mock_flow = MagicMock()
+
+        # Serve a label triple per subject; anything unknown gets no label.
+        def mock_query(s=None, p=None, **kwargs):
+            labels = {
+                "urn:e:Alice": "Alice",
+                "urn:r:knows": "knows",
+                "urn:e:Bob": "Bob",
+            }
+            if s in labels:
+                return _make_wire_triples([(s, RDFS_LABEL, labels[s])])
+            return []
+
+        mock_flow.triples_query.side_effect = mock_query
+
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+        s, p, o = client.resolve_edge_labels(
+            {"s": "urn:e:Alice", "p": "urn:r:knows", "o": "urn:e:Bob"}
+        )
+        assert s == "Alice"
+        assert p == "knows"
+        assert o == "Bob"
+
+
+class TestExplainabilityClientContentFetching:
+    """Test content retrieval for Synthesis/Conclusion/Analysis entities:
+    inline content is preferred, document URIs fall back to the librarian,
+    and oversized content is truncated."""
+
+    def test_fetch_synthesis_inline_content(self):
+        mock_flow = MagicMock()
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+
+        synthesis = Synthesis(uri="urn:syn:1", content="inline answer")
+        result = client.fetch_synthesis_content(synthesis, api=None)
+        assert result == "inline answer"
+
+    def test_fetch_synthesis_truncated_content(self):
+        mock_flow = MagicMock()
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+
+        long_content = "x" * 20000
+        synthesis = Synthesis(uri="urn:syn:1", content=long_content)
+        result = client.fetch_synthesis_content(synthesis, api=None, max_content=100)
+        assert len(result) < 20000
+        assert result.endswith("... [truncated]")
+
+    def test_fetch_synthesis_from_librarian(self):
+        # Librarian returns bytes; the client decodes them to str.
+        mock_flow = MagicMock()
+        mock_api = MagicMock()
+        mock_library = MagicMock()
+        mock_api.library.return_value = mock_library
+        mock_library.get_document_content.return_value = b"librarian content"
+
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+        synthesis = Synthesis(
+            uri="urn:syn:1",
+            document_uri="urn:document:abc123"
+        )
+        result = client.fetch_synthesis_content(synthesis, api=mock_api)
+        assert result == "librarian content"
+
+    def test_fetch_synthesis_no_content_or_document(self):
+        mock_flow = MagicMock()
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+
+        synthesis = Synthesis(uri="urn:syn:1")
+        result = client.fetch_synthesis_content(synthesis, api=None)
+        assert result == ""
+
+    def test_fetch_conclusion_inline(self):
+        mock_flow = MagicMock()
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+
+        conclusion = Conclusion(uri="urn:conc:1", answer="42")
+        result = client.fetch_conclusion_content(conclusion, api=None)
+        assert result == "42"
+
+    def test_fetch_analysis_content_from_librarian(self):
+        mock_flow = MagicMock()
+        mock_api = MagicMock()
+        mock_library = MagicMock()
+        mock_api.library.return_value = mock_library
+        # First call serves the thought document, second the observation.
+        mock_library.get_document_content.side_effect = [
+            b"thought content",
+            b"observation content",
+        ]
+
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+        analysis = Analysis(
+            uri="urn:ana:1",
+            action="search",
+            thought_document_uri="urn:doc:thought",
+            observation_document_uri="urn:doc:obs",
+        )
+        client.fetch_analysis_content(analysis, api=mock_api)
+        assert analysis.thought == "thought content"
+        assert analysis.observation == "observation content"
+
+    def test_fetch_analysis_skips_when_inline_exists(self):
+        mock_flow = MagicMock()
+        mock_api = MagicMock()
+
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+        analysis = Analysis(
+            uri="urn:ana:1",
+            action="search",
+            thought="already have thought",
+            observation="already have observation",
+            thought_document_uri="urn:doc:thought",
+            observation_document_uri="urn:doc:obs",
+        )
+        client.fetch_analysis_content(analysis, api=mock_api)
+        # Should not call librarian since inline content exists
+        mock_api.library.assert_not_called()
+
+
+class TestExplainabilityClientDetectSessionType:
+    """Test detect_session_type classifies sessions from URI prefixes."""
+
+    def test_detect_agent_from_uri(self):
+        mock_flow = MagicMock()
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+        assert client.detect_session_type("urn:trustgraph:agent:abc") == "agent"
+
+    def test_detect_graphrag_from_uri(self):
+        mock_flow = MagicMock()
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+        assert client.detect_session_type("urn:trustgraph:question:abc") == "graphrag"
+
+    def test_detect_docrag_from_uri(self):
+        mock_flow = MagicMock()
+        client = ExplainabilityClient(mock_flow, retry_delay=0.0)
+        assert client.detect_session_type("urn:trustgraph:docrag:abc") == "docrag"
diff --git a/tests/unit/test_provenance/test_triples.py b/tests/unit/test_provenance/test_triples.py
new file mode 100644
index 00000000..91074097
--- /dev/null
+++ b/tests/unit/test_provenance/test_triples.py
@@ -0,0 +1,785 @@
+"""
+Tests for provenance triple builder functions (extraction-time and query-time).
+"""
+
+import pytest
+from unittest.mock import patch
+
+from trustgraph.schema import Triple, Term, IRI, LITERAL, TRIPLE
+
+from trustgraph.provenance.triples import (
+ set_graph,
+ document_triples,
+ derived_entity_triples,
+ subgraph_provenance_triples,
+ question_triples,
+ exploration_triples,
+ focus_triples,
+ synthesis_triples,
+ docrag_question_triples,
+ docrag_exploration_triples,
+ docrag_synthesis_triples,
+)
+
+from trustgraph.provenance.namespaces import (
+ RDF_TYPE, RDFS_LABEL,
+ PROV_ENTITY, PROV_ACTIVITY, PROV_AGENT,
+ PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
+ PROV_USED, PROV_WAS_ASSOCIATED_WITH, PROV_STARTED_AT_TIME,
+ DC_TITLE, DC_SOURCE, DC_DATE, DC_CREATOR,
+ TG_PAGE_COUNT, TG_MIME_TYPE, TG_PAGE_NUMBER,
+ TG_CHUNK_INDEX, TG_CHAR_OFFSET, TG_CHAR_LENGTH,
+ TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
+ TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS,
+ TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
+ TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
+ TG_CONTENT, TG_DOCUMENT,
+ TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
+ TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
+ TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
+ GRAPH_SOURCE, GRAPH_RETRIEVAL,
+)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def find_triple(triples, predicate, subject=None):
+    """Find first triple matching predicate (and optionally subject).
+
+    Returns the matching Triple, or None when nothing matches.
+    """
+    for t in triples:
+        if t.p.iri == predicate:
+            if subject is None or t.s.iri == subject:
+                return t
+    return None
+
+
+def find_triples(triples, predicate, subject=None):
+    """Find all triples matching predicate (and optionally subject).
+
+    Returns a (possibly empty) list, preserving input order.
+    """
+    return [
+        t for t in triples
+        if t.p.iri == predicate and (subject is None or t.s.iri == subject)
+    ]
+
+
+def has_type(triples, subject, rdf_type):
+    """Check if subject has rdf:type rdf_type.
+
+    Only IRI-valued objects count; literal objects never match.
+    """
+    for t in triples:
+        if (t.s.iri == subject and t.p.iri == RDF_TYPE
+                and t.o.type == IRI and t.o.iri == rdf_type):
+            return True
+    return False
+
+
+# ---------------------------------------------------------------------------
+# set_graph
+# ---------------------------------------------------------------------------
+
+class TestSetGraph:
+    """Test set_graph stamps a graph URI on copies of input triples
+    without mutating the originals."""
+
+    def test_sets_graph_on_all_triples(self):
+        triples = [
+            Triple(
+                s=Term(type=IRI, iri="urn:s1"),
+                p=Term(type=IRI, iri="urn:p1"),
+                o=Term(type=LITERAL, value="v1"),
+            ),
+            Triple(
+                s=Term(type=IRI, iri="urn:s2"),
+                p=Term(type=IRI, iri="urn:p2"),
+                o=Term(type=LITERAL, value="v2"),
+            ),
+        ]
+        result = set_graph(triples, GRAPH_RETRIEVAL)
+        assert len(result) == 2
+        for t in result:
+            assert t.g == GRAPH_RETRIEVAL
+
+    def test_does_not_modify_originals(self):
+        # set_graph must copy: the input triple keeps g is None.
+        original = Triple(
+            s=Term(type=IRI, iri="urn:s"),
+            p=Term(type=IRI, iri="urn:p"),
+            o=Term(type=LITERAL, value="v"),
+        )
+        result = set_graph([original], "urn:graph:test")
+        assert original.g is None
+        assert result[0].g == "urn:graph:test"
+
+    def test_empty_list(self):
+        result = set_graph([], GRAPH_SOURCE)
+        assert result == []
+
+    def test_preserves_spo(self):
+        original = Triple(
+            s=Term(type=IRI, iri="urn:s"),
+            p=Term(type=IRI, iri="urn:p"),
+            o=Term(type=LITERAL, value="hello"),
+        )
+        result = set_graph([original], "urn:g")[0]
+        assert result.s.iri == "urn:s"
+        assert result.p.iri == "urn:p"
+        assert result.o.value == "hello"
+
+
+# ---------------------------------------------------------------------------
+# document_triples
+# ---------------------------------------------------------------------------
+
+class TestDocumentTriples:
+    """Test document_triples emits prov:Entity + tg document typing plus
+    one triple per optional Dublin Core / trustgraph metadata field."""
+
+    # Shared subject URI for every document under test.
+    DOC_URI = "https://example.com/doc/abc"
+
+    def test_minimal_document(self):
+        # With no metadata, only the two rdf:type triples are emitted.
+        triples = document_triples(self.DOC_URI)
+        assert has_type(triples, self.DOC_URI, PROV_ENTITY)
+        assert has_type(triples, self.DOC_URI, TG_DOCUMENT_TYPE)
+        assert len(triples) == 2
+
+    def test_with_title(self):
+        triples = document_triples(self.DOC_URI, title="My Doc")
+        title_t = find_triple(triples, DC_TITLE)
+        assert title_t is not None
+        assert title_t.o.value == "My Doc"
+        # Title also creates an rdfs:label
+        label_t = find_triple(triples, RDFS_LABEL)
+        assert label_t is not None
+        assert label_t.o.value == "My Doc"
+
+    def test_with_source(self):
+        # Source is recorded as an IRI term, not a literal.
+        triples = document_triples(self.DOC_URI, source="https://source.com/f.pdf")
+        source_t = find_triple(triples, DC_SOURCE)
+        assert source_t is not None
+        assert source_t.o.type == IRI
+        assert source_t.o.iri == "https://source.com/f.pdf"
+
+    def test_with_date(self):
+        triples = document_triples(self.DOC_URI, date="2024-01-15")
+        date_t = find_triple(triples, DC_DATE)
+        assert date_t is not None
+        assert date_t.o.value == "2024-01-15"
+
+    def test_with_creator(self):
+        triples = document_triples(self.DOC_URI, creator="Alice")
+        creator_t = find_triple(triples, DC_CREATOR)
+        assert creator_t is not None
+        assert creator_t.o.value == "Alice"
+
+    def test_with_page_count(self):
+        # Integer metadata is serialised as a string literal.
+        triples = document_triples(self.DOC_URI, page_count=42)
+        pc_t = find_triple(triples, TG_PAGE_COUNT)
+        assert pc_t is not None
+        assert pc_t.o.value == "42"
+
+    def test_with_page_count_zero(self):
+        # Zero is falsy but still a valid count and must be emitted.
+        triples = document_triples(self.DOC_URI, page_count=0)
+        pc_t = find_triple(triples, TG_PAGE_COUNT)
+        assert pc_t is not None
+        assert pc_t.o.value == "0"
+
+    def test_with_mime_type(self):
+        triples = document_triples(self.DOC_URI, mime_type="application/pdf")
+        mt_t = find_triple(triples, TG_MIME_TYPE)
+        assert mt_t is not None
+        assert mt_t.o.value == "application/pdf"
+
+    def test_all_metadata(self):
+        triples = document_triples(
+            self.DOC_URI,
+            title="Test",
+            source="https://s.com",
+            date="2024-01-01",
+            creator="Bob",
+            page_count=10,
+            mime_type="application/pdf",
+        )
+        # 2 type triples + title + label + source + date + creator + page_count + mime_type
+        assert len(triples) == 9
+
+    def test_subject_is_doc_uri(self):
+        triples = document_triples(self.DOC_URI, title="T")
+        for t in triples:
+            assert t.s.iri == self.DOC_URI
+
+
+# ---------------------------------------------------------------------------
+# derived_entity_triples
+# ---------------------------------------------------------------------------
+
+class TestDerivedEntityTriples:
+    """Test derived_entity_triples builds the PROV chain for a derived
+    entity (page or chunk): entity typing, wasDerivedFrom the parent,
+    a generating activity with component metadata, and an agent."""
+
+    # Child entity and its source document.
+    ENTITY_URI = "https://example.com/doc/abc/p1"
+    PARENT_URI = "https://example.com/doc/abc"
+
+    def test_page_entity_has_page_type(self):
+        # page_number selects the page-specific rdf:type.
+        triples = derived_entity_triples(
+            self.ENTITY_URI, self.PARENT_URI,
+            "pdf-extractor", "1.0",
+            page_number=1,
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
+        assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
+
+    def test_chunk_entity_has_chunk_type(self):
+        # chunk_index selects the chunk-specific rdf:type.
+        triples = derived_entity_triples(
+            self.ENTITY_URI, self.PARENT_URI,
+            "chunker", "1.0",
+            chunk_index=0,
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        assert has_type(triples, self.ENTITY_URI, TG_CHUNK_TYPE)
+
+    def test_no_specific_type_without_page_or_chunk(self):
+        triples = derived_entity_triples(
+            self.ENTITY_URI, self.PARENT_URI,
+            "component", "1.0",
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
+        assert not has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
+        assert not has_type(triples, self.ENTITY_URI, TG_CHUNK_TYPE)
+
+    def test_was_derived_from_parent(self):
+        triples = derived_entity_triples(
+            self.ENTITY_URI, self.PARENT_URI,
+            "pdf-extractor", "1.0",
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ENTITY_URI)
+        assert derived is not None
+        assert derived.o.iri == self.PARENT_URI
+
+    def test_activity_created(self):
+        triples = derived_entity_triples(
+            self.ENTITY_URI, self.PARENT_URI,
+            "pdf-extractor", "1.0",
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        # Entity was generated by an activity
+        gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ENTITY_URI)
+        assert gen is not None
+        act_uri = gen.o.iri
+
+        # Activity has correct type and metadata
+        assert has_type(triples, act_uri, PROV_ACTIVITY)
+
+        # Activity used the parent
+        used = find_triple(triples, PROV_USED, act_uri)
+        assert used is not None
+        assert used.o.iri == self.PARENT_URI
+
+        # Activity has component version
+        version = find_triple(triples, TG_COMPONENT_VERSION, act_uri)
+        assert version is not None
+        assert version.o.value == "1.0"
+
+    def test_agent_created(self):
+        triples = derived_entity_triples(
+            self.ENTITY_URI, self.PARENT_URI,
+            "pdf-extractor", "1.0",
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        # Find the agent URI via wasAssociatedWith
+        gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ENTITY_URI)
+        act_uri = gen.o.iri
+        assoc = find_triple(triples, PROV_WAS_ASSOCIATED_WITH, act_uri)
+        assert assoc is not None
+        agt_uri = assoc.o.iri
+
+        # Agent is typed and labelled with the component name.
+        assert has_type(triples, agt_uri, PROV_AGENT)
+        label = find_triple(triples, RDFS_LABEL, agt_uri)
+        assert label is not None
+        assert label.o.value == "pdf-extractor"
+
+    def test_timestamp_recorded(self):
+        triples = derived_entity_triples(
+            self.ENTITY_URI, self.PARENT_URI,
+            "pdf-extractor", "1.0",
+            timestamp="2024-06-15T12:30:00Z",
+        )
+        ts = find_triple(triples, PROV_STARTED_AT_TIME)
+        assert ts is not None
+        assert ts.o.value == "2024-06-15T12:30:00Z"
+
+    def test_default_timestamp_generated(self):
+        # Omitting timestamp still yields a non-empty startedAtTime.
+        triples = derived_entity_triples(
+            self.ENTITY_URI, self.PARENT_URI,
+            "pdf-extractor", "1.0",
+        )
+        ts = find_triple(triples, PROV_STARTED_AT_TIME)
+        assert ts is not None
+        assert len(ts.o.value) > 0
+
+    def test_optional_label(self):
+        triples = derived_entity_triples(
+            self.ENTITY_URI, self.PARENT_URI,
+            "pdf-extractor", "1.0",
+            label="Page 1",
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        label = find_triple(triples, RDFS_LABEL, self.ENTITY_URI)
+        assert label is not None
+        assert label.o.value == "Page 1"
+
+    def test_page_number_recorded(self):
+        triples = derived_entity_triples(
+            self.ENTITY_URI, self.PARENT_URI,
+            "pdf-extractor", "1.0",
+            page_number=3,
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        pn = find_triple(triples, TG_PAGE_NUMBER, self.ENTITY_URI)
+        assert pn is not None
+        assert pn.o.value == "3"
+
+    def test_chunk_metadata_recorded(self):
+        triples = derived_entity_triples(
+            self.ENTITY_URI, self.PARENT_URI,
+            "chunker", "2.0",
+            chunk_index=5,
+            char_offset=1000,
+            char_length=500,
+            chunk_size=512,
+            chunk_overlap=64,
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        ci = find_triple(triples, TG_CHUNK_INDEX, self.ENTITY_URI)
+        assert ci is not None and ci.o.value == "5"
+
+        co = find_triple(triples, TG_CHAR_OFFSET, self.ENTITY_URI)
+        assert co is not None and co.o.value == "1000"
+
+        cl = find_triple(triples, TG_CHAR_LENGTH, self.ENTITY_URI)
+        assert cl is not None and cl.o.value == "500"
+
+        # chunk_size and chunk_overlap are on the activity, not the entity
+        cs = find_triple(triples, TG_CHUNK_SIZE)
+        assert cs is not None and cs.o.value == "512"
+
+        ov = find_triple(triples, TG_CHUNK_OVERLAP)
+        assert ov is not None and ov.o.value == "64"
+
+
+# ---------------------------------------------------------------------------
+# subgraph_provenance_triples
+# ---------------------------------------------------------------------------
+
+class TestSubgraphProvenanceTriples:
+    """Test subgraph_provenance_triples: extracted triples are embedded as
+    quoted (TG_CONTAINS) triples on a subgraph entity, which carries the
+    full PROV chain back to the source chunk."""
+
+    # Subgraph entity and the chunk it was extracted from.
+    SG_URI = "https://trustgraph.ai/subgraph/test-sg"
+    CHUNK_URI = "https://example.com/doc/abc/p1/c0"
+
+    def _make_extracted_triple(self, s="urn:e:Alice", p="urn:r:knows", o="urn:e:Bob"):
+        # Helper: build an all-IRI extracted triple for embedding.
+        return Triple(
+            s=Term(type=IRI, iri=s),
+            p=Term(type=IRI, iri=p),
+            o=Term(type=IRI, iri=o),
+        )
+
+    def test_contains_quoted_triples(self):
+        extracted = [self._make_extracted_triple()]
+        triples = subgraph_provenance_triples(
+            self.SG_URI, extracted, self.CHUNK_URI,
+            "kg-extractor", "1.0",
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        contains = find_triples(triples, TG_CONTAINS, self.SG_URI)
+        assert len(contains) == 1
+        # The object is a quoted TRIPLE term wrapping the extracted s/p/o.
+        assert contains[0].o.type == TRIPLE
+        assert contains[0].o.triple.s.iri == "urn:e:Alice"
+        assert contains[0].o.triple.p.iri == "urn:r:knows"
+        assert contains[0].o.triple.o.iri == "urn:e:Bob"
+
+    def test_multiple_extracted_triples(self):
+        extracted = [
+            self._make_extracted_triple("urn:e:A", "urn:r:x", "urn:e:B"),
+            self._make_extracted_triple("urn:e:C", "urn:r:y", "urn:e:D"),
+            self._make_extracted_triple("urn:e:E", "urn:r:z", "urn:e:F"),
+        ]
+        triples = subgraph_provenance_triples(
+            self.SG_URI, extracted, self.CHUNK_URI,
+            "kg-extractor", "1.0",
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        contains = find_triples(triples, TG_CONTAINS, self.SG_URI)
+        assert len(contains) == 3
+
+    def test_empty_extracted_triples(self):
+        triples = subgraph_provenance_triples(
+            self.SG_URI, [], self.CHUNK_URI,
+            "kg-extractor", "1.0",
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        contains = find_triples(triples, TG_CONTAINS, self.SG_URI)
+        assert len(contains) == 0
+        # Should still have subgraph provenance metadata
+        assert has_type(triples, self.SG_URI, TG_SUBGRAPH_TYPE)
+
+    def test_subgraph_has_correct_types(self):
+        triples = subgraph_provenance_triples(
+            self.SG_URI, [], self.CHUNK_URI,
+            "kg-extractor", "1.0",
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        assert has_type(triples, self.SG_URI, PROV_ENTITY)
+        assert has_type(triples, self.SG_URI, TG_SUBGRAPH_TYPE)
+
+    def test_derived_from_chunk(self):
+        triples = subgraph_provenance_triples(
+            self.SG_URI, [], self.CHUNK_URI,
+            "kg-extractor", "1.0",
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SG_URI)
+        assert derived is not None
+        assert derived.o.iri == self.CHUNK_URI
+
+    def test_activity_and_agent(self):
+        triples = subgraph_provenance_triples(
+            self.SG_URI, [], self.CHUNK_URI,
+            "kg-extractor", "1.0",
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        # Walk wasGeneratedBy to the activity, then verify its metadata.
+        gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.SG_URI)
+        assert gen is not None
+        act_uri = gen.o.iri
+
+        assert has_type(triples, act_uri, PROV_ACTIVITY)
+
+        used = find_triple(triples, PROV_USED, act_uri)
+        assert used is not None
+        assert used.o.iri == self.CHUNK_URI
+
+        version = find_triple(triples, TG_COMPONENT_VERSION, act_uri)
+        assert version is not None
+        assert version.o.value == "1.0"
+
+    def test_optional_llm_model(self):
+        triples = subgraph_provenance_triples(
+            self.SG_URI, [], self.CHUNK_URI,
+            "kg-extractor", "1.0",
+            llm_model="claude-3-opus",
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        llm = find_triple(triples, TG_LLM_MODEL)
+        assert llm is not None
+        assert llm.o.value == "claude-3-opus"
+
+    def test_no_llm_model_when_omitted(self):
+        triples = subgraph_provenance_triples(
+            self.SG_URI, [], self.CHUNK_URI,
+            "kg-extractor", "1.0",
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        llm = find_triple(triples, TG_LLM_MODEL)
+        assert llm is None
+
+    def test_optional_ontology(self):
+        # Ontology reference is recorded as an IRI term.
+        triples = subgraph_provenance_triples(
+            self.SG_URI, [], self.CHUNK_URI,
+            "kg-extractor", "1.0",
+            ontology_uri="https://example.com/ontology/v1",
+            timestamp="2024-01-01T00:00:00Z",
+        )
+        ont = find_triple(triples, TG_ONTOLOGY)
+        assert ont is not None
+        assert ont.o.type == IRI
+        assert ont.o.iri == "https://example.com/ontology/v1"
+
+
+# ---------------------------------------------------------------------------
+# GraphRAG query-time triples
+# ---------------------------------------------------------------------------
+
+class TestQuestionTriples:
+
+ Q_URI = "urn:trustgraph:question:test-session"
+
+ def test_question_types(self):
+ triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z")
+ assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
+ assert has_type(triples, self.Q_URI, TG_QUESTION)
+ assert has_type(triples, self.Q_URI, TG_GRAPH_RAG_QUESTION)
+
+ def test_question_query_text(self):
+ triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z")
+ query = find_triple(triples, TG_QUERY, self.Q_URI)
+ assert query is not None
+ assert query.o.value == "What is AI?"
+
+ def test_question_timestamp(self):
+ triples = question_triples(self.Q_URI, "Q", "2024-06-15T10:00:00Z")
+ ts = find_triple(triples, PROV_STARTED_AT_TIME, self.Q_URI)
+ assert ts is not None
+ assert ts.o.value == "2024-06-15T10:00:00Z"
+
+ def test_question_default_timestamp(self):
+ triples = question_triples(self.Q_URI, "Q")
+ ts = find_triple(triples, PROV_STARTED_AT_TIME, self.Q_URI)
+ assert ts is not None
+ assert len(ts.o.value) > 0
+
+ def test_question_label(self):
+ triples = question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z")
+ label = find_triple(triples, RDFS_LABEL, self.Q_URI)
+ assert label is not None
+ assert label.o.value == "GraphRAG Question"
+
+ def test_question_triple_count(self):
+ triples = question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z")
+ assert len(triples) == 6
+
+
+class TestExplorationTriples:
+
+ EXP_URI = "urn:trustgraph:prov:exploration:test-session"
+ Q_URI = "urn:trustgraph:question:test-session"
+
+ def test_exploration_types(self):
+ triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
+ assert has_type(triples, self.EXP_URI, PROV_ENTITY)
+ assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
+
+ def test_exploration_generated_by_question(self):
+ triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
+ gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI)
+ assert gen is not None
+ assert gen.o.iri == self.Q_URI
+
+ def test_exploration_edge_count(self):
+ triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
+ ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
+ assert ec is not None
+ assert ec.o.value == "15"
+
+ def test_exploration_zero_edges(self):
+ triples = exploration_triples(self.EXP_URI, self.Q_URI, 0)
+ ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
+ assert ec is not None
+ assert ec.o.value == "0"
+
+ def test_exploration_triple_count(self):
+ triples = exploration_triples(self.EXP_URI, self.Q_URI, 10)
+ assert len(triples) == 5
+
+
+class TestFocusTriples:
+
+ FOC_URI = "urn:trustgraph:prov:focus:test-session"
+ EXP_URI = "urn:trustgraph:prov:exploration:test-session"
+ SESSION_ID = "test-session"
+
+ def test_focus_types(self):
+ triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID)
+ assert has_type(triples, self.FOC_URI, PROV_ENTITY)
+ assert has_type(triples, self.FOC_URI, TG_FOCUS)
+
+ def test_focus_derived_from_exploration(self):
+ triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID)
+ derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FOC_URI)
+ assert derived is not None
+ assert derived.o.iri == self.EXP_URI
+
+ def test_focus_no_edges(self):
+ triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID)
+ selected = find_triples(triples, TG_SELECTED_EDGE)
+ assert len(selected) == 0
+
+ def test_focus_with_edges_and_reasoning(self):
+ edges = [
+ {
+ "edge": ("urn:e:Alice", "urn:r:knows", "urn:e:Bob"),
+ "reasoning": "Alice is connected to Bob",
+ },
+ {
+ "edge": ("urn:e:Bob", "urn:r:worksAt", "urn:e:Acme"),
+ "reasoning": "Bob works at Acme",
+ },
+ ]
+ triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
+
+ # Two selectedEdge links
+ selected = find_triples(triples, TG_SELECTED_EDGE, self.FOC_URI)
+ assert len(selected) == 2
+
+ # Each edge selection has a quoted triple
+ edge_triples = find_triples(triples, TG_EDGE)
+ assert len(edge_triples) == 2
+ for et in edge_triples:
+ assert et.o.type == TRIPLE
+
+ # Each edge selection has reasoning
+ reasoning_triples = find_triples(triples, TG_REASONING)
+ assert len(reasoning_triples) == 2
+
+ def test_focus_edge_without_reasoning(self):
+ edges = [
+ {"edge": ("urn:e:A", "urn:r:x", "urn:e:B"), "reasoning": ""},
+ ]
+ triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
+ reasoning = find_triples(triples, TG_REASONING)
+ assert len(reasoning) == 0
+
+ def test_focus_edge_without_edge_data(self):
+ edges = [
+ {"edge": None, "reasoning": "some reasoning"},
+ ]
+ triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
+ selected = find_triples(triples, TG_SELECTED_EDGE)
+ assert len(selected) == 0
+
+ def test_focus_quoted_triple_content(self):
+ edges = [
+ {
+ "edge": ("urn:e:Alice", "urn:r:knows", "urn:e:Bob"),
+ "reasoning": "test",
+ },
+ ]
+ triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
+ edge_t = find_triple(triples, TG_EDGE)
+ qt = edge_t.o.triple
+ assert qt.s.iri == "urn:e:Alice"
+ assert qt.p.iri == "urn:r:knows"
+ assert qt.o.iri == "urn:e:Bob"
+
+
+class TestSynthesisTriples:
+
+ SYN_URI = "urn:trustgraph:prov:synthesis:test-session"
+ FOC_URI = "urn:trustgraph:prov:focus:test-session"
+
+ def test_synthesis_types(self):
+ triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
+ assert has_type(triples, self.SYN_URI, PROV_ENTITY)
+ assert has_type(triples, self.SYN_URI, TG_SYNTHESIS)
+
+ def test_synthesis_derived_from_focus(self):
+ triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
+ derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI)
+ assert derived is not None
+ assert derived.o.iri == self.FOC_URI
+
+ def test_synthesis_with_inline_content(self):
+ triples = synthesis_triples(self.SYN_URI, self.FOC_URI, answer_text="The answer is 42")
+ content = find_triple(triples, TG_CONTENT, self.SYN_URI)
+ assert content is not None
+ assert content.o.value == "The answer is 42"
+
+ def test_synthesis_with_document_reference(self):
+ triples = synthesis_triples(
+ self.SYN_URI, self.FOC_URI,
+ document_id="urn:trustgraph:question:abc/answer",
+ )
+ doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
+ assert doc is not None
+ assert doc.o.type == IRI
+ assert doc.o.iri == "urn:trustgraph:question:abc/answer"
+
+ def test_synthesis_document_takes_precedence(self):
+ """When both document_id and answer_text are provided, document_id wins."""
+ triples = synthesis_triples(
+ self.SYN_URI, self.FOC_URI,
+ answer_text="inline",
+ document_id="urn:doc:123",
+ )
+ doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
+ assert doc is not None
+ content = find_triple(triples, TG_CONTENT, self.SYN_URI)
+ assert content is None
+
+ def test_synthesis_no_content_or_document(self):
+ triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
+ content = find_triple(triples, TG_CONTENT, self.SYN_URI)
+ doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
+ assert content is None
+ assert doc is None
+
+
+# ---------------------------------------------------------------------------
+# DocumentRAG query-time triples
+# ---------------------------------------------------------------------------
+
+class TestDocRagQuestionTriples:
+
+ Q_URI = "urn:trustgraph:docrag:test-session"
+
+ def test_docrag_question_types(self):
+ triples = docrag_question_triples(self.Q_URI, "Find info", "2024-01-01T00:00:00Z")
+ assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
+ assert has_type(triples, self.Q_URI, TG_QUESTION)
+ assert has_type(triples, self.Q_URI, TG_DOC_RAG_QUESTION)
+
+ def test_docrag_question_label(self):
+ triples = docrag_question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z")
+ label = find_triple(triples, RDFS_LABEL, self.Q_URI)
+ assert label.o.value == "DocumentRAG Question"
+
+ def test_docrag_question_query_text(self):
+ triples = docrag_question_triples(self.Q_URI, "search query", "2024-01-01T00:00:00Z")
+ query = find_triple(triples, TG_QUERY, self.Q_URI)
+ assert query.o.value == "search query"
+
+
+class TestDocRagExplorationTriples:
+
+ EXP_URI = "urn:trustgraph:docrag:test/exploration"
+ Q_URI = "urn:trustgraph:docrag:test"
+
+ def test_docrag_exploration_types(self):
+ triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5)
+ assert has_type(triples, self.EXP_URI, PROV_ENTITY)
+ assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
+
+ def test_docrag_exploration_generated_by(self):
+ triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5)
+ gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI)
+ assert gen.o.iri == self.Q_URI
+
+ def test_docrag_exploration_chunk_count(self):
+ triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 7)
+ cc = find_triple(triples, TG_CHUNK_COUNT, self.EXP_URI)
+ assert cc.o.value == "7"
+
+ def test_docrag_exploration_without_chunk_ids(self):
+ triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3)
+ chunks = find_triples(triples, TG_SELECTED_CHUNK)
+ assert len(chunks) == 0
+
+ def test_docrag_exploration_with_chunk_ids(self):
+ chunk_ids = ["urn:chunk:1", "urn:chunk:2", "urn:chunk:3"]
+ triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3, chunk_ids)
+ chunks = find_triples(triples, TG_SELECTED_CHUNK, self.EXP_URI)
+ assert len(chunks) == 3
+ chunk_uris = {t.o.iri for t in chunks}
+ assert chunk_uris == set(chunk_ids)
+
+
+class TestDocRagSynthesisTriples:
+
+ SYN_URI = "urn:trustgraph:docrag:test/synthesis"
+ EXP_URI = "urn:trustgraph:docrag:test/exploration"
+
+ def test_docrag_synthesis_types(self):
+ triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
+ assert has_type(triples, self.SYN_URI, PROV_ENTITY)
+ assert has_type(triples, self.SYN_URI, TG_SYNTHESIS)
+
+ def test_docrag_synthesis_derived_from_exploration(self):
+ """DocRAG skips the focus step — synthesis derives from exploration."""
+ triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
+ derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI)
+ assert derived.o.iri == self.EXP_URI
+
+ def test_docrag_synthesis_with_inline(self):
+ triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI, answer_text="answer")
+ content = find_triple(triples, TG_CONTENT, self.SYN_URI)
+ assert content.o.value == "answer"
+
+ def test_docrag_synthesis_with_document(self):
+ triples = docrag_synthesis_triples(
+ self.SYN_URI, self.EXP_URI, document_id="urn:doc:ans"
+ )
+ doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
+ assert doc.o.iri == "urn:doc:ans"
+ content = find_triple(triples, TG_CONTENT, self.SYN_URI)
+ assert content is None
diff --git a/tests/unit/test_provenance/test_uris.py b/tests/unit/test_provenance/test_uris.py
new file mode 100644
index 00000000..0e69734c
--- /dev/null
+++ b/tests/unit/test_provenance/test_uris.py
@@ -0,0 +1,292 @@
+"""
+Tests for provenance URI generation functions.
+"""
+
+import pytest
+from unittest.mock import patch
+
+from trustgraph.provenance.uris import (
+ TRUSTGRAPH_BASE,
+ _encode_id,
+ document_uri,
+ page_uri,
+ chunk_uri_from_page,
+ chunk_uri_from_doc,
+ activity_uri,
+ subgraph_uri,
+ agent_uri,
+ question_uri,
+ exploration_uri,
+ focus_uri,
+ synthesis_uri,
+ edge_selection_uri,
+ agent_session_uri,
+ agent_iteration_uri,
+ agent_final_uri,
+ docrag_question_uri,
+ docrag_exploration_uri,
+ docrag_synthesis_uri,
+)
+
+
+class TestEncodeId:
+ """Tests for the _encode_id helper."""
+
+ def test_plain_string(self):
+ assert _encode_id("abc123") == "abc123"
+
+ def test_string_with_spaces(self):
+ assert _encode_id("hello world") == "hello%20world"
+
+ def test_string_with_slashes(self):
+ assert _encode_id("a/b/c") == "a%2Fb%2Fc"
+
+ def test_integer_input(self):
+ assert _encode_id(42) == "42"
+
+ def test_empty_string(self):
+ assert _encode_id("") == ""
+
+ def test_special_characters(self):
+ result = _encode_id("name@domain.com")
+ assert "@" not in result or result == "name%40domain.com"
+
+
+class TestDocumentUris:
+ """Tests for document, page, and chunk URI generation."""
+
+ def test_document_uri_passthrough(self):
+ iri = "https://example.com/doc/123"
+ assert document_uri(iri) == iri
+
+ def test_page_uri_format(self):
+ result = page_uri("https://example.com/doc/123", 5)
+ assert result == "https://example.com/doc/123/p5"
+
+ def test_page_uri_page_zero(self):
+ result = page_uri("https://example.com/doc/123", 0)
+ assert result == "https://example.com/doc/123/p0"
+
+ def test_chunk_uri_from_page_format(self):
+ result = chunk_uri_from_page("https://example.com/doc/123", 2, 3)
+ assert result == "https://example.com/doc/123/p2/c3"
+
+ def test_chunk_uri_from_doc_format(self):
+ result = chunk_uri_from_doc("https://example.com/doc/123", 7)
+ assert result == "https://example.com/doc/123/c7"
+
+ def test_page_uri_preserves_doc_iri(self):
+ doc = "urn:isbn:978-3-16-148410-0"
+ result = page_uri(doc, 1)
+ assert result.startswith(doc)
+
+ def test_chunk_from_page_hierarchy(self):
+ """Chunk URI should contain both page and chunk identifiers."""
+ result = chunk_uri_from_page("https://example.com/doc", 3, 5)
+ assert "/p3/" in result
+ assert result.endswith("/c5")
+
+
+class TestActivityAndSubgraphUris:
+ """Tests for activity_uri, subgraph_uri, and agent_uri."""
+
+ def test_activity_uri_with_id(self):
+ result = activity_uri("my-activity-id")
+ assert result == f"{TRUSTGRAPH_BASE}/activity/my-activity-id"
+
+ def test_activity_uri_auto_generates_uuid(self):
+ result = activity_uri()
+ assert result.startswith(f"{TRUSTGRAPH_BASE}/activity/")
+ # UUID part should be non-empty
+ uuid_part = result.split("/activity/")[1]
+ assert len(uuid_part) > 0
+
+ def test_activity_uri_unique_uuids(self):
+ r1 = activity_uri()
+ r2 = activity_uri()
+ assert r1 != r2
+
+ def test_activity_uri_encodes_special_chars(self):
+ result = activity_uri("id with spaces")
+ assert "id%20with%20spaces" in result
+
+ def test_subgraph_uri_with_id(self):
+ result = subgraph_uri("sg-123")
+ assert result == f"{TRUSTGRAPH_BASE}/subgraph/sg-123"
+
+ def test_subgraph_uri_auto_generates_uuid(self):
+ result = subgraph_uri()
+ assert result.startswith(f"{TRUSTGRAPH_BASE}/subgraph/")
+ uuid_part = result.split("/subgraph/")[1]
+ assert len(uuid_part) > 0
+
+ def test_subgraph_uri_unique_uuids(self):
+ r1 = subgraph_uri()
+ r2 = subgraph_uri()
+ assert r1 != r2
+
+ def test_agent_uri_format(self):
+ result = agent_uri("pdf-extractor")
+ assert result == f"{TRUSTGRAPH_BASE}/agent/pdf-extractor"
+
+ def test_agent_uri_encodes_special_chars(self):
+ result = agent_uri("my component")
+ assert "my%20component" in result
+
+
+class TestGraphRagQueryUris:
+ """Tests for GraphRAG query-time provenance URIs."""
+
+ FIXED_UUID = "550e8400-e29b-41d4-a716-446655440000"
+
+ def test_question_uri_with_session_id(self):
+ result = question_uri(self.FIXED_UUID)
+ assert result == f"urn:trustgraph:question:{self.FIXED_UUID}"
+
+ def test_question_uri_auto_generates(self):
+ result = question_uri()
+ assert result.startswith("urn:trustgraph:question:")
+ uuid_part = result.split("urn:trustgraph:question:")[1]
+ assert len(uuid_part) > 0
+
+ def test_question_uri_unique(self):
+ r1 = question_uri()
+ r2 = question_uri()
+ assert r1 != r2
+
+ def test_exploration_uri_format(self):
+ result = exploration_uri(self.FIXED_UUID)
+ assert result == f"urn:trustgraph:prov:exploration:{self.FIXED_UUID}"
+
+ def test_focus_uri_format(self):
+ result = focus_uri(self.FIXED_UUID)
+ assert result == f"urn:trustgraph:prov:focus:{self.FIXED_UUID}"
+
+ def test_synthesis_uri_format(self):
+ result = synthesis_uri(self.FIXED_UUID)
+ assert result == f"urn:trustgraph:prov:synthesis:{self.FIXED_UUID}"
+
+ def test_edge_selection_uri_format(self):
+ result = edge_selection_uri(self.FIXED_UUID, 3)
+ assert result == f"urn:trustgraph:prov:edge:{self.FIXED_UUID}:3"
+
+ def test_edge_selection_uri_zero_index(self):
+ result = edge_selection_uri(self.FIXED_UUID, 0)
+ assert result.endswith(":0")
+
+ def test_session_uris_share_session_id(self):
+ """All URIs for a session should contain the same session ID."""
+ sid = self.FIXED_UUID
+ q = question_uri(sid)
+ e = exploration_uri(sid)
+ f = focus_uri(sid)
+ s = synthesis_uri(sid)
+ for uri in [q, e, f, s]:
+ assert sid in uri
+
+
+class TestAgentProvenanceUris:
+ """Tests for agent provenance URIs."""
+
+ FIXED_UUID = "661e8400-e29b-41d4-a716-446655440000"
+
+ def test_agent_session_uri_with_id(self):
+ result = agent_session_uri(self.FIXED_UUID)
+ assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}"
+
+ def test_agent_session_uri_auto_generates(self):
+ result = agent_session_uri()
+ assert result.startswith("urn:trustgraph:agent:")
+
+ def test_agent_session_uri_unique(self):
+ r1 = agent_session_uri()
+ r2 = agent_session_uri()
+ assert r1 != r2
+
+ def test_agent_iteration_uri_format(self):
+ result = agent_iteration_uri(self.FIXED_UUID, 1)
+ assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}/i1"
+
+ def test_agent_iteration_uri_numbering(self):
+ r1 = agent_iteration_uri(self.FIXED_UUID, 1)
+ r2 = agent_iteration_uri(self.FIXED_UUID, 2)
+ assert r1 != r2
+ assert r1.endswith("/i1")
+ assert r2.endswith("/i2")
+
+ def test_agent_final_uri_format(self):
+ result = agent_final_uri(self.FIXED_UUID)
+ assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}/final"
+
+ def test_agent_uris_share_session_id(self):
+ sid = self.FIXED_UUID
+ session = agent_session_uri(sid)
+ iteration = agent_iteration_uri(sid, 1)
+ final = agent_final_uri(sid)
+ for uri in [session, iteration, final]:
+ assert sid in uri
+
+
+class TestDocRagProvenanceUris:
+ """Tests for Document RAG provenance URIs."""
+
+ FIXED_UUID = "772e8400-e29b-41d4-a716-446655440000"
+
+ def test_docrag_question_uri_with_id(self):
+ result = docrag_question_uri(self.FIXED_UUID)
+ assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}"
+
+ def test_docrag_question_uri_auto_generates(self):
+ result = docrag_question_uri()
+ assert result.startswith("urn:trustgraph:docrag:")
+
+ def test_docrag_question_uri_unique(self):
+ r1 = docrag_question_uri()
+ r2 = docrag_question_uri()
+ assert r1 != r2
+
+ def test_docrag_exploration_uri_format(self):
+ result = docrag_exploration_uri(self.FIXED_UUID)
+ assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}/exploration"
+
+ def test_docrag_synthesis_uri_format(self):
+ result = docrag_synthesis_uri(self.FIXED_UUID)
+ assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}/synthesis"
+
+ def test_docrag_uris_share_session_id(self):
+ sid = self.FIXED_UUID
+ q = docrag_question_uri(sid)
+ e = docrag_exploration_uri(sid)
+ s = docrag_synthesis_uri(sid)
+ for uri in [q, e, s]:
+ assert sid in uri
+
+
+class TestUriNamespaceIsolation:
+ """Verify that different provenance types use distinct URI namespaces."""
+
+ FIXED_UUID = "883e8400-e29b-41d4-a716-446655440000"
+
+ def test_graphrag_vs_agent_namespace(self):
+ graphrag = question_uri(self.FIXED_UUID)
+ agent = agent_session_uri(self.FIXED_UUID)
+ assert graphrag != agent
+ assert "question" in graphrag
+ assert "agent" in agent
+
+ def test_graphrag_vs_docrag_namespace(self):
+ graphrag = question_uri(self.FIXED_UUID)
+ docrag = docrag_question_uri(self.FIXED_UUID)
+ assert graphrag != docrag
+
+ def test_agent_vs_docrag_namespace(self):
+ agent = agent_session_uri(self.FIXED_UUID)
+ docrag = docrag_question_uri(self.FIXED_UUID)
+ assert agent != docrag
+
+ def test_extraction_vs_query_namespace(self):
+ """Extraction URIs use https://, query URIs use urn:."""
+ ext = activity_uri(self.FIXED_UUID)
+ query = question_uri(self.FIXED_UUID)
+ assert ext.startswith("https://")
+ assert query.startswith("urn:")
diff --git a/tests/unit/test_provenance/test_vocabulary.py b/tests/unit/test_provenance/test_vocabulary.py
new file mode 100644
index 00000000..a3c644e8
--- /dev/null
+++ b/tests/unit/test_provenance/test_vocabulary.py
@@ -0,0 +1,124 @@
+"""
+Tests for provenance vocabulary bootstrap.
+"""
+
+import pytest
+
+from trustgraph.schema import Triple, Term, IRI, LITERAL
+
+from trustgraph.provenance.vocabulary import (
+ get_vocabulary_triples,
+ PROV_CLASS_LABELS,
+ PROV_PREDICATE_LABELS,
+ DC_PREDICATE_LABELS,
+ SCHEMA_LABELS,
+ SKOS_LABELS,
+ TG_CLASS_LABELS,
+ TG_PREDICATE_LABELS,
+)
+
+from trustgraph.provenance.namespaces import (
+ RDFS_LABEL,
+ PROV_ENTITY, PROV_ACTIVITY, PROV_AGENT,
+ PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
+ PROV_USED, PROV_WAS_ASSOCIATED_WITH, PROV_STARTED_AT_TIME,
+ DC_TITLE, DC_SOURCE, DC_DATE, DC_CREATOR,
+ TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
+)
+
+
+class TestVocabularyTriples:
+ """Tests for the vocabulary bootstrap function."""
+
+ def test_returns_list_of_triples(self):
+ result = get_vocabulary_triples()
+ assert isinstance(result, list)
+ assert len(result) > 0
+ for t in result:
+ assert isinstance(t, Triple)
+
+ def test_all_triples_are_label_triples(self):
+ """Every vocabulary triple should use rdfs:label as predicate."""
+ for t in get_vocabulary_triples():
+ assert t.p.type == IRI
+ assert t.p.iri == RDFS_LABEL
+
+ def test_all_subjects_are_iris(self):
+ for t in get_vocabulary_triples():
+ assert t.s.type == IRI
+ assert len(t.s.iri) > 0
+
+ def test_all_objects_are_literals(self):
+ for t in get_vocabulary_triples():
+ assert t.o.type == LITERAL
+ assert len(t.o.value) > 0
+
+ def test_no_duplicate_subjects(self):
+ subjects = [t.s.iri for t in get_vocabulary_triples()]
+ assert len(subjects) == len(set(subjects))
+
+ def test_includes_prov_classes(self):
+ subjects = {t.s.iri for t in get_vocabulary_triples()}
+ assert PROV_ENTITY in subjects
+ assert PROV_ACTIVITY in subjects
+ assert PROV_AGENT in subjects
+
+ def test_includes_prov_predicates(self):
+ subjects = {t.s.iri for t in get_vocabulary_triples()}
+ assert PROV_WAS_DERIVED_FROM in subjects
+ assert PROV_WAS_GENERATED_BY in subjects
+ assert PROV_USED in subjects
+ assert PROV_WAS_ASSOCIATED_WITH in subjects
+ assert PROV_STARTED_AT_TIME in subjects
+
+ def test_includes_dc_predicates(self):
+ subjects = {t.s.iri for t in get_vocabulary_triples()}
+ assert DC_TITLE in subjects
+ assert DC_SOURCE in subjects
+ assert DC_DATE in subjects
+ assert DC_CREATOR in subjects
+
+ def test_includes_tg_classes(self):
+ subjects = {t.s.iri for t in get_vocabulary_triples()}
+ assert TG_DOCUMENT_TYPE in subjects
+ assert TG_PAGE_TYPE in subjects
+ assert TG_CHUNK_TYPE in subjects
+ assert TG_SUBGRAPH_TYPE in subjects
+
+ def test_component_lists_sum_to_total(self):
+ total = get_vocabulary_triples()
+ components = (
+ PROV_CLASS_LABELS +
+ PROV_PREDICATE_LABELS +
+ DC_PREDICATE_LABELS +
+ SCHEMA_LABELS +
+ SKOS_LABELS +
+ TG_CLASS_LABELS +
+ TG_PREDICATE_LABELS
+ )
+ assert len(total) == len(components)
+
+ def test_idempotent(self):
+ """Calling twice should return equivalent triples."""
+ r1 = get_vocabulary_triples()
+ r2 = get_vocabulary_triples()
+ assert len(r1) == len(r2)
+ for t1, t2 in zip(r1, r2):
+ assert t1.s.iri == t2.s.iri
+ assert t1.o.value == t2.o.value
+
+
+class TestNamespaceConstants:
+ """Verify namespace constants are well-formed IRIs."""
+
+ def test_prov_namespace_prefix(self):
+ assert PROV_ENTITY.startswith("http://www.w3.org/ns/prov#")
+
+ def test_dc_namespace_prefix(self):
+ assert DC_TITLE.startswith("http://purl.org/dc/elements/1.1/")
+
+ def test_tg_namespace_prefix(self):
+ assert TG_DOCUMENT_TYPE.startswith("https://trustgraph.ai/ns/")
+
+ def test_rdfs_label_iri(self):
+ assert RDFS_LABEL == "http://www.w3.org/2000/01/rdf-schema#label"
diff --git a/tests/unit/test_rdf/__init__.py b/tests/unit/test_rdf/__init__.py
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/tests/unit/test_rdf/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/unit/test_rdf/test_rdf_primitives.py b/tests/unit/test_rdf/test_rdf_primitives.py
new file mode 100644
index 00000000..2498677b
--- /dev/null
+++ b/tests/unit/test_rdf/test_rdf_primitives.py
@@ -0,0 +1,309 @@
+"""
+Tests for RDF 1.2 type system primitives: Term dataclass (IRI, blank node,
+typed literal, language-tagged literal, quoted triple), Triple/Quad dataclass
+with named graph support, and the knowledge/defs helper types.
+"""
+
+import pytest
+
+from trustgraph.schema import Term, Triple, IRI, BLANK, LITERAL, TRIPLE
+
+
+# ---------------------------------------------------------------------------
+# Type constants
+# ---------------------------------------------------------------------------
+
+class TestTypeConstants:
+
+ def test_iri_constant(self):
+ assert IRI == "i"
+
+ def test_blank_constant(self):
+ assert BLANK == "b"
+
+ def test_literal_constant(self):
+ assert LITERAL == "l"
+
+ def test_triple_constant(self):
+ assert TRIPLE == "t"
+
+ def test_constants_are_distinct(self):
+ vals = {IRI, BLANK, LITERAL, TRIPLE}
+ assert len(vals) == 4
+
+
+# ---------------------------------------------------------------------------
+# IRI terms
+# ---------------------------------------------------------------------------
+
+class TestIriTerm:
+
+ def test_create_iri(self):
+ t = Term(type=IRI, iri="http://example.org/Alice")
+ assert t.type == IRI
+ assert t.iri == "http://example.org/Alice"
+
+ def test_iri_defaults_empty(self):
+ t = Term(type=IRI)
+ assert t.iri == ""
+
+ def test_iri_with_fragment(self):
+ t = Term(type=IRI, iri="http://example.org/ontology#Person")
+ assert "#Person" in t.iri
+
+ def test_iri_with_unicode(self):
+ t = Term(type=IRI, iri="http://example.org/概念")
+ assert "概念" in t.iri
+
+ def test_iri_other_fields_default(self):
+ t = Term(type=IRI, iri="http://example.org/x")
+ assert t.id == ""
+ assert t.value == ""
+ assert t.datatype == ""
+ assert t.language == ""
+ assert t.triple is None
+
+
+# ---------------------------------------------------------------------------
+# Blank node terms
+# ---------------------------------------------------------------------------
+
+class TestBlankNodeTerm:
+
+ def test_create_blank_node(self):
+ t = Term(type=BLANK, id="_:b0")
+ assert t.type == BLANK
+ assert t.id == "_:b0"
+
+ def test_blank_node_defaults_empty(self):
+ t = Term(type=BLANK)
+ assert t.id == ""
+
+ def test_blank_node_arbitrary_id(self):
+ t = Term(type=BLANK, id="node-abc-123")
+ assert t.id == "node-abc-123"
+
+
+# ---------------------------------------------------------------------------
+# Typed literals (XSD datatypes)
+# ---------------------------------------------------------------------------
+
+class TestTypedLiteral:
+
+ def test_plain_literal(self):
+ t = Term(type=LITERAL, value="hello")
+ assert t.type == LITERAL
+ assert t.value == "hello"
+ assert t.datatype == ""
+ assert t.language == ""
+
+ def test_xsd_integer(self):
+ t = Term(
+ type=LITERAL, value="42",
+ datatype="http://www.w3.org/2001/XMLSchema#integer",
+ )
+ assert t.value == "42"
+ assert "integer" in t.datatype
+
+ def test_xsd_boolean(self):
+ t = Term(
+ type=LITERAL, value="true",
+ datatype="http://www.w3.org/2001/XMLSchema#boolean",
+ )
+ assert t.datatype.endswith("#boolean")
+
+ def test_xsd_date(self):
+ t = Term(
+ type=LITERAL, value="2026-03-13",
+ datatype="http://www.w3.org/2001/XMLSchema#date",
+ )
+ assert t.value == "2026-03-13"
+ assert t.datatype.endswith("#date")
+
+ def test_xsd_double(self):
+ t = Term(
+ type=LITERAL, value="3.14",
+ datatype="http://www.w3.org/2001/XMLSchema#double",
+ )
+ assert t.datatype.endswith("#double")
+
+ def test_empty_value_literal(self):
+ t = Term(type=LITERAL, value="")
+ assert t.value == ""
+
+
+# ---------------------------------------------------------------------------
+# Language-tagged literals
+# ---------------------------------------------------------------------------
+
+class TestLanguageTaggedLiteral:
+
+ def test_english_tag(self):
+ t = Term(type=LITERAL, value="hello", language="en")
+ assert t.language == "en"
+ assert t.datatype == ""
+
+ def test_french_tag(self):
+ t = Term(type=LITERAL, value="bonjour", language="fr")
+ assert t.language == "fr"
+
+ def test_bcp47_subtag(self):
+ t = Term(type=LITERAL, value="colour", language="en-GB")
+ assert t.language == "en-GB"
+
+ def test_language_and_datatype_mutually_exclusive(self):
+ """Both can be set on the dataclass, but semantically only one should be used."""
+ t = Term(type=LITERAL, value="x", language="en",
+ datatype="http://www.w3.org/2001/XMLSchema#string")
+ # Dataclass allows both — translators should respect mutual exclusivity
+ assert t.language == "en"
+ assert t.datatype != ""
+
+
+# ---------------------------------------------------------------------------
+# Quoted triples (RDF-star)
+# ---------------------------------------------------------------------------
+
+class TestQuotedTriple:
+
+ def test_term_with_nested_triple(self):
+ inner = Triple(
+ s=Term(type=IRI, iri="http://example.org/Alice"),
+ p=Term(type=IRI, iri="http://xmlns.com/foaf/0.1/knows"),
+ o=Term(type=IRI, iri="http://example.org/Bob"),
+ )
+ qt = Term(type=TRIPLE, triple=inner)
+ assert qt.type == TRIPLE
+ assert qt.triple is inner
+ assert qt.triple.s.iri == "http://example.org/Alice"
+
+ def test_quoted_triple_as_object(self):
+ """A triple whose object is a quoted triple (RDF-star)."""
+ inner = Triple(
+ s=Term(type=IRI, iri="http://example.org/Hope"),
+ p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
+ o=Term(type=LITERAL, value="A feeling of expectation"),
+ )
+ outer = Triple(
+ s=Term(type=IRI, iri="urn:subgraph:123"),
+ p=Term(type=IRI, iri="http://trustgraph.ai/tg/contains"),
+ o=Term(type=TRIPLE, triple=inner),
+ )
+ assert outer.o.type == TRIPLE
+ assert outer.o.triple.o.value == "A feeling of expectation"
+
+ def test_quoted_triple_none(self):
+ t = Term(type=TRIPLE, triple=None)
+ assert t.triple is None
+
+
+# ---------------------------------------------------------------------------
+# Triple / Quad (named graph)
+# ---------------------------------------------------------------------------
+
+class TestTripleQuad:
+
+ def test_default_graph_is_none(self):
+ t = Triple(
+ s=Term(type=IRI, iri="http://example.org/s"),
+ p=Term(type=IRI, iri="http://example.org/p"),
+ o=Term(type=LITERAL, value="val"),
+ )
+ assert t.g is None
+
+ def test_named_graph(self):
+ t = Triple(
+ s=Term(type=IRI, iri="http://example.org/s"),
+ p=Term(type=IRI, iri="http://example.org/p"),
+ o=Term(type=LITERAL, value="val"),
+ g="urn:graph:source",
+ )
+ assert t.g == "urn:graph:source"
+
+ def test_empty_string_graph(self):
+ t = Triple(g="")
+ assert t.g == ""
+
+ def test_triple_with_all_none_terms(self):
+ t = Triple()
+ assert t.s is None
+ assert t.p is None
+ assert t.o is None
+ assert t.g is None
+
+ def test_triple_equality(self):
+ """Dataclass equality based on field values."""
+ t1 = Triple(
+ s=Term(type=IRI, iri="http://example.org/A"),
+ p=Term(type=IRI, iri="http://example.org/B"),
+ o=Term(type=LITERAL, value="C"),
+ )
+ t2 = Triple(
+ s=Term(type=IRI, iri="http://example.org/A"),
+ p=Term(type=IRI, iri="http://example.org/B"),
+ o=Term(type=LITERAL, value="C"),
+ )
+ assert t1 == t2
+
+
+# ---------------------------------------------------------------------------
+# knowledge/defs helper types
+# ---------------------------------------------------------------------------
+
+class TestKnowledgeDefs:
+
+ def test_uri_type(self):
+ from trustgraph.knowledge.defs import Uri
+ u = Uri("http://example.org/x")
+ assert u.is_uri() is True
+ assert u.is_literal() is False
+ assert u.is_triple() is False
+ assert str(u) == "http://example.org/x"
+
+ def test_literal_type(self):
+ from trustgraph.knowledge.defs import Literal
+ l = Literal("hello world")
+ assert l.is_uri() is False
+ assert l.is_literal() is True
+ assert l.is_triple() is False
+ assert str(l) == "hello world"
+
+ def test_quoted_triple_type(self):
+ from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal
+ qt = QuotedTriple(
+ s=Uri("http://example.org/s"),
+ p=Uri("http://example.org/p"),
+ o=Literal("val"),
+ )
+ assert qt.is_uri() is False
+ assert qt.is_literal() is False
+ assert qt.is_triple() is True
+ assert qt.s == "http://example.org/s"
+ assert qt.o == "val"
+
+ def test_quoted_triple_repr(self):
+ from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal
+ qt = QuotedTriple(
+ s=Uri("http://example.org/A"),
+ p=Uri("http://example.org/B"),
+ o=Literal("C"),
+ )
+ r = repr(qt)
+ assert "<<" in r
+ assert ">>" in r
+ assert "http://example.org/A" in r
+
+ def test_quoted_triple_nested(self):
+ """QuotedTriple can contain another QuotedTriple as object."""
+ from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal
+ inner = QuotedTriple(
+ s=Uri("http://example.org/s"),
+ p=Uri("http://example.org/p"),
+ o=Literal("v"),
+ )
+ outer = QuotedTriple(
+ s=Uri("http://example.org/s2"),
+ p=Uri("http://example.org/p2"),
+ o=inner,
+ )
+ assert outer.o.is_triple() is True
diff --git a/tests/unit/test_rdf/test_rdf_storage_helpers.py b/tests/unit/test_rdf/test_rdf_storage_helpers.py
new file mode 100644
index 00000000..7bd1807a
--- /dev/null
+++ b/tests/unit/test_rdf/test_rdf_storage_helpers.py
@@ -0,0 +1,217 @@
+"""
+Tests for RDF storage helper functions used by the Cassandra triple writer:
+serialize_triple, get_term_value, get_term_otype, get_term_dtype, get_term_lang.
+"""
+
+import json
+import pytest
+
+from trustgraph.schema import Term, Triple, IRI, BLANK, LITERAL, TRIPLE
+from trustgraph.storage.triples.cassandra.write import (
+ serialize_triple,
+ get_term_value,
+ get_term_otype,
+ get_term_dtype,
+ get_term_lang,
+)
+
+
+# ---------------------------------------------------------------------------
+# get_term_otype — maps Term.type to storage object type code
+# ---------------------------------------------------------------------------
+
+class TestGetTermOtype:
+
+ def test_iri_maps_to_u(self):
+ assert get_term_otype(Term(type=IRI, iri="http://x")) == "u"
+
+ def test_blank_maps_to_u(self):
+ assert get_term_otype(Term(type=BLANK, id="_:b0")) == "u"
+
+ def test_literal_maps_to_l(self):
+ assert get_term_otype(Term(type=LITERAL, value="x")) == "l"
+
+ def test_triple_maps_to_t(self):
+ assert get_term_otype(Term(type=TRIPLE)) == "t"
+
+ def test_none_defaults_to_u(self):
+ assert get_term_otype(None) == "u"
+
+ def test_unknown_type_defaults_to_u(self):
+ assert get_term_otype(Term(type="z")) == "u"
+
+
+# ---------------------------------------------------------------------------
+# get_term_dtype — extracts XSD datatype from literals
+# ---------------------------------------------------------------------------
+
+class TestGetTermDtype:
+
+ def test_literal_with_datatype(self):
+ t = Term(type=LITERAL, value="42",
+ datatype="http://www.w3.org/2001/XMLSchema#integer")
+ assert get_term_dtype(t) == "http://www.w3.org/2001/XMLSchema#integer"
+
+ def test_literal_without_datatype(self):
+ t = Term(type=LITERAL, value="hello")
+ assert get_term_dtype(t) == ""
+
+ def test_iri_returns_empty(self):
+ assert get_term_dtype(Term(type=IRI, iri="http://x")) == ""
+
+ def test_none_returns_empty(self):
+ assert get_term_dtype(None) == ""
+
+
+# ---------------------------------------------------------------------------
+# get_term_lang — extracts language tag from literals
+# ---------------------------------------------------------------------------
+
+class TestGetTermLang:
+
+ def test_literal_with_language(self):
+ t = Term(type=LITERAL, value="bonjour", language="fr")
+ assert get_term_lang(t) == "fr"
+
+ def test_literal_without_language(self):
+ t = Term(type=LITERAL, value="hello")
+ assert get_term_lang(t) == ""
+
+ def test_iri_returns_empty(self):
+ assert get_term_lang(Term(type=IRI, iri="http://x")) == ""
+
+ def test_none_returns_empty(self):
+ assert get_term_lang(None) == ""
+
+ def test_bcp47_subtag_preserved(self):
+ t = Term(type=LITERAL, value="colour", language="en-GB")
+ assert get_term_lang(t) == "en-GB"
+
+
+# ---------------------------------------------------------------------------
+# get_term_value — extracts string value from any Term
+# ---------------------------------------------------------------------------
+
+class TestGetTermValue:
+
+ def test_iri_returns_iri(self):
+ t = Term(type=IRI, iri="http://example.org/Alice")
+ assert get_term_value(t) == "http://example.org/Alice"
+
+ def test_literal_returns_value(self):
+ t = Term(type=LITERAL, value="hello")
+ assert get_term_value(t) == "hello"
+
+ def test_blank_returns_id(self):
+ t = Term(type=BLANK, id="_:b0")
+ assert get_term_value(t) == "_:b0"
+
+ def test_none_returns_none(self):
+ assert get_term_value(None) is None
+
+ def test_triple_returns_serialized_json(self):
+ inner = Triple(
+ s=Term(type=IRI, iri="http://example.org/s"),
+ p=Term(type=IRI, iri="http://example.org/p"),
+ o=Term(type=LITERAL, value="val"),
+ )
+ t = Term(type=TRIPLE, triple=inner)
+ result = get_term_value(t)
+ parsed = json.loads(result)
+ assert parsed["s"]["type"] == "i"
+ assert parsed["s"]["iri"] == "http://example.org/s"
+ assert parsed["o"]["value"] == "val"
+
+
+# ---------------------------------------------------------------------------
+# serialize_triple — full Triple → JSON serialization
+# ---------------------------------------------------------------------------
+
+class TestSerializeTriple:
+    """serialize_triple renders a full Triple into its JSON storage form."""
+
+    def test_serialize_iri_triple(self):
+        # All three positions are IRIs; each serialises with type code "i".
+        t = Triple(
+            s=Term(type=IRI, iri="http://example.org/A"),
+            p=Term(type=IRI, iri="http://example.org/rel"),
+            o=Term(type=IRI, iri="http://example.org/B"),
+        )
+        result = json.loads(serialize_triple(t))
+        assert result["s"]["type"] == "i"
+        assert result["s"]["iri"] == "http://example.org/A"
+        assert result["p"]["iri"] == "http://example.org/rel"
+        assert result["o"]["type"] == "i"
+
+    def test_serialize_literal_object(self):
+        # Literal objects serialise with type code "l" and a "value" key.
+        t = Triple(
+            s=Term(type=IRI, iri="http://example.org/s"),
+            p=Term(type=IRI, iri="http://example.org/p"),
+            o=Term(type=LITERAL, value="hello"),
+        )
+        result = json.loads(serialize_triple(t))
+        assert result["o"]["type"] == "l"
+        assert result["o"]["value"] == "hello"
+
+    def test_serialize_typed_literal(self):
+        # An XSD datatype on the literal is carried through verbatim.
+        t = Triple(
+            s=Term(type=IRI, iri="http://example.org/s"),
+            p=Term(type=IRI, iri="http://example.org/p"),
+            o=Term(type=LITERAL, value="42",
+                   datatype="http://www.w3.org/2001/XMLSchema#integer"),
+        )
+        result = json.loads(serialize_triple(t))
+        assert result["o"]["datatype"] == "http://www.w3.org/2001/XMLSchema#integer"
+
+    def test_serialize_language_tagged_literal(self):
+        # A language tag on the literal is carried through verbatim.
+        t = Triple(
+            s=Term(type=IRI, iri="http://example.org/s"),
+            p=Term(type=IRI, iri="http://example.org/p"),
+            o=Term(type=LITERAL, value="bonjour", language="fr"),
+        )
+        result = json.loads(serialize_triple(t))
+        assert result["o"]["language"] == "fr"
+
+    def test_serialize_blank_node(self):
+        # Blank-node subjects serialise with type code "b" and an "id" key.
+        t = Triple(
+            s=Term(type=BLANK, id="_:b0"),
+            p=Term(type=IRI, iri="http://example.org/p"),
+            o=Term(type=LITERAL, value="v"),
+        )
+        result = json.loads(serialize_triple(t))
+        assert result["s"]["type"] == "b"
+        assert result["s"]["id"] == "_:b0"
+
+    def test_serialize_nested_quoted_triple(self):
+        inner = Triple(
+            s=Term(type=IRI, iri="http://example.org/inner-s"),
+            p=Term(type=IRI, iri="http://example.org/inner-p"),
+            o=Term(type=LITERAL, value="inner-val"),
+        )
+        outer = Triple(
+            s=Term(type=IRI, iri="http://example.org/outer-s"),
+            p=Term(type=IRI, iri="http://example.org/outer-p"),
+            o=Term(type=TRIPLE, triple=inner),
+        )
+        result = json.loads(serialize_triple(outer))
+        # The inner triple is stored as a JSON *string*, so it needs a
+        # second parse (double-encoded on purpose).
+        nested = json.loads(result["o"]["triple"])
+        assert nested["s"]["iri"] == "http://example.org/inner-s"
+        assert nested["o"]["value"] == "inner-val"
+
+    def test_serialize_none_returns_none(self):
+        assert serialize_triple(None) is None
+
+    def test_serialize_none_terms(self):
+        # Missing terms serialise as JSON nulls rather than raising.
+        t = Triple(s=None, p=None, o=None)
+        result = json.loads(serialize_triple(t))
+        assert result["s"] is None
+        assert result["p"] is None
+        assert result["o"] is None
+
+    def test_serialize_plain_literal_omits_datatype_and_language(self):
+        # A plain literal must not emit empty datatype/language keys.
+        t = Triple(
+            s=Term(type=IRI, iri="http://example.org/s"),
+            p=Term(type=IRI, iri="http://example.org/p"),
+            o=Term(type=LITERAL, value="plain"),
+        )
+        result = json.loads(serialize_triple(t))
+        assert "datatype" not in result["o"]
+        assert "language" not in result["o"]
diff --git a/tests/unit/test_rdf/test_rdf_wire_format.py b/tests/unit/test_rdf/test_rdf_wire_format.py
new file mode 100644
index 00000000..a0bbd27a
--- /dev/null
+++ b/tests/unit/test_rdf/test_rdf_wire_format.py
@@ -0,0 +1,357 @@
+"""
+Tests for RDF wire format translators: TermTranslator and TripleTranslator
+round-trip encoding for all RDF 1.2 term types (IRI, blank node, typed literal,
+language-tagged literal, quoted triple) and named graph quads.
+"""
+
+import pytest
+
+from trustgraph.schema import Term, Triple, IRI, BLANK, LITERAL, TRIPLE
+from trustgraph.messaging.translators.primitives import (
+ TermTranslator, TripleTranslator, SubgraphTranslator,
+)
+
+
+@pytest.fixture
+def term_tx():
+    # Fresh TermTranslator for each test.
+    return TermTranslator()
+
+
+@pytest.fixture
+def triple_tx():
+    # Fresh TripleTranslator for each test.
+    return TripleTranslator()
+
+
+# ---------------------------------------------------------------------------
+# TermTranslator — IRI
+# ---------------------------------------------------------------------------
+
+class TestTermTranslatorIri:
+
+ def test_iri_to_pulsar(self, term_tx):
+ data = {"t": "i", "i": "http://example.org/Alice"}
+ term = term_tx.to_pulsar(data)
+ assert term.type == IRI
+ assert term.iri == "http://example.org/Alice"
+
+ def test_iri_from_pulsar(self, term_tx):
+ term = Term(type=IRI, iri="http://example.org/Bob")
+ wire = term_tx.from_pulsar(term)
+ assert wire == {"t": "i", "i": "http://example.org/Bob"}
+
+ def test_iri_round_trip(self, term_tx):
+ original = Term(type=IRI, iri="http://example.org/round")
+ wire = term_tx.from_pulsar(original)
+ restored = term_tx.to_pulsar(wire)
+ assert restored == original
+
+
+# ---------------------------------------------------------------------------
+# TermTranslator — Blank node
+# ---------------------------------------------------------------------------
+
+class TestTermTranslatorBlank:
+
+ def test_blank_to_pulsar(self, term_tx):
+ data = {"t": "b", "d": "_:b42"}
+ term = term_tx.to_pulsar(data)
+ assert term.type == BLANK
+ assert term.id == "_:b42"
+
+ def test_blank_from_pulsar(self, term_tx):
+ term = Term(type=BLANK, id="_:node1")
+ wire = term_tx.from_pulsar(term)
+ assert wire == {"t": "b", "d": "_:node1"}
+
+ def test_blank_round_trip(self, term_tx):
+ original = Term(type=BLANK, id="_:x")
+ wire = term_tx.from_pulsar(original)
+ restored = term_tx.to_pulsar(wire)
+ assert restored == original
+
+
+# ---------------------------------------------------------------------------
+# TermTranslator — Typed literal (XSD)
+# ---------------------------------------------------------------------------
+
+class TestTermTranslatorTypedLiteral:
+
+ def test_plain_literal_to_pulsar(self, term_tx):
+ data = {"t": "l", "v": "hello"}
+ term = term_tx.to_pulsar(data)
+ assert term.type == LITERAL
+ assert term.value == "hello"
+ assert term.datatype == ""
+ assert term.language == ""
+
+ def test_xsd_integer_to_pulsar(self, term_tx):
+ data = {
+ "t": "l", "v": "42",
+ "dt": "http://www.w3.org/2001/XMLSchema#integer",
+ }
+ term = term_tx.to_pulsar(data)
+ assert term.value == "42"
+ assert term.datatype.endswith("#integer")
+
+ def test_typed_literal_from_pulsar(self, term_tx):
+ term = Term(
+ type=LITERAL, value="3.14",
+ datatype="http://www.w3.org/2001/XMLSchema#double",
+ )
+ wire = term_tx.from_pulsar(term)
+ assert wire["t"] == "l"
+ assert wire["v"] == "3.14"
+ assert wire["dt"] == "http://www.w3.org/2001/XMLSchema#double"
+ assert "ln" not in wire # No language tag
+
+ def test_typed_literal_round_trip(self, term_tx):
+ original = Term(
+ type=LITERAL, value="true",
+ datatype="http://www.w3.org/2001/XMLSchema#boolean",
+ )
+ wire = term_tx.from_pulsar(original)
+ restored = term_tx.to_pulsar(wire)
+ assert restored == original
+
+ def test_plain_literal_omits_dt_and_ln(self, term_tx):
+ term = Term(type=LITERAL, value="x")
+ wire = term_tx.from_pulsar(term)
+ assert "dt" not in wire
+ assert "ln" not in wire
+
+
+# ---------------------------------------------------------------------------
+# TermTranslator — Language-tagged literal
+# ---------------------------------------------------------------------------
+
+class TestTermTranslatorLangLiteral:
+
+ def test_language_tag_to_pulsar(self, term_tx):
+ data = {"t": "l", "v": "bonjour", "ln": "fr"}
+ term = term_tx.to_pulsar(data)
+ assert term.value == "bonjour"
+ assert term.language == "fr"
+
+ def test_language_tag_from_pulsar(self, term_tx):
+ term = Term(type=LITERAL, value="colour", language="en-GB")
+ wire = term_tx.from_pulsar(term)
+ assert wire["ln"] == "en-GB"
+ assert "dt" not in wire # No datatype
+
+ def test_language_tag_round_trip(self, term_tx):
+ original = Term(type=LITERAL, value="hola", language="es")
+ wire = term_tx.from_pulsar(original)
+ restored = term_tx.to_pulsar(wire)
+ assert restored == original
+
+
+# ---------------------------------------------------------------------------
+# TermTranslator — Quoted triple (RDF-star)
+# ---------------------------------------------------------------------------
+
+class TestTermTranslatorQuotedTriple:
+    """RDF-star quoted terms: wire type "t" with the inner triple under "tr"."""
+
+    def test_quoted_triple_to_pulsar(self, term_tx):
+        data = {
+            "t": "t",
+            "tr": {
+                "s": {"t": "i", "i": "http://example.org/Alice"},
+                "p": {"t": "i", "i": "http://xmlns.com/foaf/0.1/knows"},
+                "o": {"t": "i", "i": "http://example.org/Bob"},
+            },
+        }
+        term = term_tx.to_pulsar(data)
+        assert term.type == TRIPLE
+        assert term.triple is not None
+        assert term.triple.s.iri == "http://example.org/Alice"
+        assert term.triple.o.iri == "http://example.org/Bob"
+
+    def test_quoted_triple_from_pulsar(self, term_tx):
+        inner = Triple(
+            s=Term(type=IRI, iri="http://example.org/s"),
+            p=Term(type=IRI, iri="http://example.org/p"),
+            o=Term(type=LITERAL, value="val"),
+        )
+        term = Term(type=TRIPLE, triple=inner)
+        wire = term_tx.from_pulsar(term)
+        assert wire["t"] == "t"
+        assert "tr" in wire
+        # Inner terms use the same compact wire keys as top-level terms.
+        assert wire["tr"]["s"]["i"] == "http://example.org/s"
+        assert wire["tr"]["o"]["v"] == "val"
+
+    def test_quoted_triple_round_trip(self, term_tx):
+        inner = Triple(
+            s=Term(type=IRI, iri="http://example.org/A"),
+            p=Term(type=IRI, iri="http://example.org/B"),
+            o=Term(type=LITERAL, value="C", language="en"),
+        )
+        original = Term(type=TRIPLE, triple=inner)
+        wire = term_tx.from_pulsar(original)
+        restored = term_tx.to_pulsar(wire)
+        assert restored.type == TRIPLE
+        assert restored.triple.s == original.triple.s
+        assert restored.triple.o == original.triple.o
+
+    def test_quoted_triple_none_triple(self, term_tx):
+        # A TRIPLE term with no payload encodes as just its type marker.
+        term = Term(type=TRIPLE, triple=None)
+        wire = term_tx.from_pulsar(term)
+        assert wire == {"t": "t"}
+        # And back
+        restored = term_tx.to_pulsar(wire)
+        assert restored.type == TRIPLE
+        assert restored.triple is None
+
+    def test_quoted_triple_with_literal_object(self, term_tx):
+        data = {
+            "t": "t",
+            "tr": {
+                "s": {"t": "i", "i": "http://example.org/Hope"},
+                "p": {"t": "i", "i": "http://www.w3.org/2004/02/skos/core#definition"},
+                "o": {"t": "l", "v": "A feeling of expectation"},
+            },
+        }
+        term = term_tx.to_pulsar(data)
+        assert term.triple.o.type == LITERAL
+        assert term.triple.o.value == "A feeling of expectation"
+
+
+# ---------------------------------------------------------------------------
+# TermTranslator — Edge cases
+# ---------------------------------------------------------------------------
+
+class TestTermTranslatorEdgeCases:
+
+ def test_unknown_type(self, term_tx):
+ data = {"t": "z"}
+ term = term_tx.to_pulsar(data)
+ assert term.type == "z"
+
+ def test_empty_type(self, term_tx):
+ data = {}
+ term = term_tx.to_pulsar(data)
+ assert term.type == ""
+
+ def test_missing_iri_field(self, term_tx):
+ data = {"t": "i"}
+ term = term_tx.to_pulsar(data)
+ assert term.iri == ""
+
+ def test_missing_literal_fields(self, term_tx):
+ data = {"t": "l"}
+ term = term_tx.to_pulsar(data)
+ assert term.value == ""
+ assert term.datatype == ""
+ assert term.language == ""
+
+
+# ---------------------------------------------------------------------------
+# TripleTranslator
+# ---------------------------------------------------------------------------
+
+class TestTripleTranslator:
+    """TripleTranslator encodes s/p/o terms plus an optional named graph "g"."""
+
+    def test_triple_to_pulsar(self, triple_tx):
+        data = {
+            "s": {"t": "i", "i": "http://example.org/s"},
+            "p": {"t": "i", "i": "http://example.org/p"},
+            "o": {"t": "l", "v": "object"},
+        }
+        triple = triple_tx.to_pulsar(data)
+        assert triple.s.iri == "http://example.org/s"
+        assert triple.o.value == "object"
+        # No "g" key on the wire means no named graph.
+        assert triple.g is None
+
+    def test_triple_from_pulsar(self, triple_tx):
+        triple = Triple(
+            s=Term(type=IRI, iri="http://example.org/A"),
+            p=Term(type=IRI, iri="http://example.org/B"),
+            o=Term(type=LITERAL, value="C"),
+        )
+        wire = triple_tx.from_pulsar(triple)
+        assert wire["s"]["t"] == "i"
+        assert wire["o"]["v"] == "C"
+        assert "g" not in wire
+
+    def test_quad_with_named_graph(self, triple_tx):
+        # A "g" key turns the triple into a quad with a named graph.
+        data = {
+            "s": {"t": "i", "i": "http://example.org/s"},
+            "p": {"t": "i", "i": "http://example.org/p"},
+            "o": {"t": "l", "v": "val"},
+            "g": "urn:graph:source",
+        }
+        quad = triple_tx.to_pulsar(data)
+        assert quad.g == "urn:graph:source"
+
+    def test_quad_from_pulsar_includes_graph(self, triple_tx):
+        quad = Triple(
+            s=Term(type=IRI, iri="http://example.org/s"),
+            p=Term(type=IRI, iri="http://example.org/p"),
+            o=Term(type=LITERAL, value="v"),
+            g="urn:graph:retrieval",
+        )
+        wire = triple_tx.from_pulsar(quad)
+        assert wire["g"] == "urn:graph:retrieval"
+
+    def test_quad_round_trip(self, triple_tx):
+        original = Triple(
+            s=Term(type=IRI, iri="http://example.org/s"),
+            p=Term(type=IRI, iri="http://example.org/p"),
+            o=Term(type=LITERAL, value="v"),
+            g="urn:graph:source",
+        )
+        wire = triple_tx.from_pulsar(original)
+        restored = triple_tx.to_pulsar(wire)
+        assert restored == original
+
+    def test_none_graph_omitted_from_wire(self, triple_tx):
+        # An explicit None graph must not emit an empty "g" key.
+        triple = Triple(
+            s=Term(type=IRI, iri="http://example.org/s"),
+            p=Term(type=IRI, iri="http://example.org/p"),
+            o=Term(type=LITERAL, value="v"),
+            g=None,
+        )
+        wire = triple_tx.from_pulsar(triple)
+        assert "g" not in wire
+
+    def test_missing_terms_handled(self, triple_tx):
+        # An empty wire dict decodes to a triple of None terms, not an error.
+        data = {}
+        triple = triple_tx.to_pulsar(data)
+        assert triple.s is None
+        assert triple.p is None
+        assert triple.o is None
+
+
+# ---------------------------------------------------------------------------
+# SubgraphTranslator
+# ---------------------------------------------------------------------------
+
+class TestSubgraphTranslator:
+
+ def test_subgraph_round_trip(self):
+ tx = SubgraphTranslator()
+ triples = [
+ Triple(
+ s=Term(type=IRI, iri="http://example.org/A"),
+ p=Term(type=IRI, iri="http://example.org/rel"),
+ o=Term(type=LITERAL, value="v1"),
+ ),
+ Triple(
+ s=Term(type=IRI, iri="http://example.org/B"),
+ p=Term(type=IRI, iri="http://example.org/rel"),
+ o=Term(type=IRI, iri="http://example.org/C"),
+ g="urn:graph:source",
+ ),
+ ]
+ wire_list = tx.from_pulsar(triples)
+ assert len(wire_list) == 2
+ assert wire_list[1]["g"] == "urn:graph:source"
+
+ restored = tx.to_pulsar(wire_list)
+ assert len(restored) == 2
+ assert restored[0] == triples[0]
+ assert restored[1] == triples[1]
+
+ def test_empty_subgraph(self):
+ tx = SubgraphTranslator()
+ assert tx.to_pulsar([]) == []
+ assert tx.from_pulsar([]) == []
diff --git a/tests/unit/test_reliability/__init__.py b/tests/unit/test_reliability/__init__.py
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/tests/unit/test_reliability/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/unit/test_reliability/test_metadata_preservation.py b/tests/unit/test_reliability/test_metadata_preservation.py
new file mode 100644
index 00000000..2fabed58
--- /dev/null
+++ b/tests/unit/test_reliability/test_metadata_preservation.py
@@ -0,0 +1,144 @@
+"""
+Tests for pipeline metadata preservation: DocumentMetadata and
+ProcessingMetadata round-trip through translators, field preservation,
+and default handling.
+"""
+
+import pytest
+
+from trustgraph.schema import DocumentMetadata, ProcessingMetadata, Triple, Term, IRI
+from trustgraph.messaging.translators.metadata import (
+ DocumentMetadataTranslator,
+ ProcessingMetadataTranslator,
+)
+
+
+# ---------------------------------------------------------------------------
+# DocumentMetadata translator
+# ---------------------------------------------------------------------------
+
+class TestDocumentMetadataTranslator:
+    """DocumentMetadata wire format: kebab-case keys, falsy fields omitted."""
+
+    def setup_method(self):
+        self.tx = DocumentMetadataTranslator()
+
+    def test_full_round_trip(self):
+        data = {
+            "id": "doc-123",
+            "time": 1710000000,
+            "kind": "application/pdf",
+            "title": "Test Document",
+            "comments": "No comments",
+            "metadata": [],
+            "user": "alice",
+            "tags": ["finance", "q4"],
+            "parent-id": "doc-100",
+            "document-type": "page",
+        }
+        obj = self.tx.to_pulsar(data)
+        assert obj.id == "doc-123"
+        assert obj.time == 1710000000
+        assert obj.kind == "application/pdf"
+        assert obj.title == "Test Document"
+        assert obj.user == "alice"
+        assert obj.tags == ["finance", "q4"]
+        # Kebab-case wire keys map onto snake_case attributes.
+        assert obj.parent_id == "doc-100"
+        assert obj.document_type == "page"
+
+        wire = self.tx.from_pulsar(obj)
+        assert wire["id"] == "doc-123"
+        assert wire["user"] == "alice"
+        assert wire["parent-id"] == "doc-100"
+        assert wire["document-type"] == "page"
+
+    def test_defaults_for_missing_fields(self):
+        obj = self.tx.to_pulsar({})
+        assert obj.parent_id == ""
+        assert obj.document_type == "source"
+
+    def test_metadata_triples_preserved(self):
+        # Embedded metadata triples decode via the triple wire format.
+        triple_wire = [{
+            "s": {"t": "i", "i": "http://example.org/s"},
+            "p": {"t": "i", "i": "http://example.org/p"},
+            "o": {"t": "i", "i": "http://example.org/o"},
+        }]
+        data = {"metadata": triple_wire}
+        obj = self.tx.to_pulsar(data)
+        assert len(obj.metadata) == 1
+        assert obj.metadata[0].s.iri == "http://example.org/s"
+
+    def test_none_metadata_handled(self):
+        # Explicit null metadata decodes as an empty list, not None.
+        data = {"metadata": None}
+        obj = self.tx.to_pulsar(data)
+        assert obj.metadata == []
+
+    def test_empty_tags_preserved(self):
+        # An empty tag list is distinct from an absent one and must survive.
+        data = {"tags": []}
+        obj = self.tx.to_pulsar(data)
+        wire = self.tx.from_pulsar(obj)
+        assert wire["tags"] == []
+
+    def test_falsy_fields_omitted_from_wire(self):
+        """Empty string fields should be omitted from wire format."""
+        obj = DocumentMetadata(id="", time=0, user="")
+        wire = self.tx.from_pulsar(obj)
+        assert "id" not in wire
+        assert "user" not in wire
+
+
+# ---------------------------------------------------------------------------
+# ProcessingMetadata translator
+# ---------------------------------------------------------------------------
+
+class TestProcessingMetadataTranslator:
+    """ProcessingMetadata wire format: kebab-case keys, None tags omitted."""
+
+    def setup_method(self):
+        self.tx = ProcessingMetadataTranslator()
+
+    def test_full_round_trip(self):
+        data = {
+            "id": "proc-1",
+            "document-id": "doc-123",
+            "time": 1710000000,
+            "flow": "default",
+            "user": "alice",
+            "collection": "my-collection",
+            "tags": ["tag1"],
+        }
+        obj = self.tx.to_pulsar(data)
+        assert obj.id == "proc-1"
+        # Kebab-case wire key maps onto the snake_case attribute.
+        assert obj.document_id == "doc-123"
+        assert obj.flow == "default"
+        assert obj.user == "alice"
+        assert obj.collection == "my-collection"
+        assert obj.tags == ["tag1"]
+
+        wire = self.tx.from_pulsar(obj)
+        assert wire["id"] == "proc-1"
+        assert wire["document-id"] == "doc-123"
+        assert wire["user"] == "alice"
+        assert wire["collection"] == "my-collection"
+
+    def test_missing_fields_use_defaults(self):
+        # Absent fields decode to None rather than raising.
+        obj = self.tx.to_pulsar({})
+        assert obj.id is None
+        assert obj.user is None
+        assert obj.collection is None
+
+    def test_tags_none_omitted(self):
+        obj = ProcessingMetadata(tags=None)
+        wire = self.tx.from_pulsar(obj)
+        assert "tags" not in wire
+
+    def test_tags_empty_list_preserved(self):
+        # An empty tag list is distinct from None and must survive.
+        obj = ProcessingMetadata(tags=[])
+        wire = self.tx.from_pulsar(obj)
+        assert wire["tags"] == []
+
+    def test_user_and_collection_preserved(self):
+        """Core pipeline routing fields must survive round-trip."""
+        data = {"user": "bob", "collection": "research"}
+        obj = self.tx.to_pulsar(data)
+        wire = self.tx.from_pulsar(obj)
+        assert wire["user"] == "bob"
+        assert wire["collection"] == "research"
diff --git a/tests/unit/test_reliability/test_null_embedding_protection.py b/tests/unit/test_reliability/test_null_embedding_protection.py
new file mode 100644
index 00000000..41a5c621
--- /dev/null
+++ b/tests/unit/test_reliability/test_null_embedding_protection.py
@@ -0,0 +1,314 @@
+"""
+Tests for null embedding protection: empty/None vector skipping, entity
+validation, dimension-aware collection creation, and query-time empty
+vector handling.
+
+Tests the pure functions and logic without Qdrant connections.
+"""
+
+import pytest
+from unittest.mock import MagicMock, patch, AsyncMock
+
+from trustgraph.schema import Term, IRI, LITERAL, BLANK
+
+
+# ---------------------------------------------------------------------------
+# Graph embeddings: get_term_value
+# ---------------------------------------------------------------------------
+
+class TestGraphEmbeddingsGetTermValue:
+
+ def test_iri_returns_iri(self):
+ from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value
+ t = Term(type=IRI, iri="http://example.org/x")
+ assert get_term_value(t) == "http://example.org/x"
+
+ def test_literal_returns_value(self):
+ from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value
+ t = Term(type=LITERAL, value="hello")
+ assert get_term_value(t) == "hello"
+
+ def test_blank_returns_id(self):
+ from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value
+ t = Term(type=BLANK, id="_:b0")
+ assert get_term_value(t) == "_:b0"
+
+ def test_none_returns_none(self):
+ from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value
+ assert get_term_value(None) is None
+
+ def test_blank_with_value_fallback(self):
+ from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value
+ t = Term(type=BLANK, id="", value="fallback")
+ assert get_term_value(t) == "fallback"
+
+
+# ---------------------------------------------------------------------------
+# Document embeddings: null vector protection
+# ---------------------------------------------------------------------------
+
+class TestDocEmbeddingsNullProtection:
+
+ @pytest.mark.asyncio
+ async def test_empty_vector_skipped(self):
+ """Embeddings with empty vectors should be silently skipped."""
+ from trustgraph.storage.doc_embeddings.qdrant.write import Processor
+
+ proc = Processor.__new__(Processor)
+ proc.qdrant = MagicMock()
+
+ # Mock collection_exists for config check
+ proc.collection_exists = MagicMock(return_value=True)
+
+ msg = MagicMock()
+ msg.metadata.user = "user1"
+ msg.metadata.collection = "col1"
+
+ emb = MagicMock()
+ emb.chunk_id = "chunk-1"
+ emb.vector = [] # Empty vector
+ msg.chunks = [emb]
+
+ await proc.store_document_embeddings(msg)
+
+ # No upsert should be called
+ proc.qdrant.upsert.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_none_vector_skipped(self):
+ from trustgraph.storage.doc_embeddings.qdrant.write import Processor
+
+ proc = Processor.__new__(Processor)
+ proc.qdrant = MagicMock()
+ proc.collection_exists = MagicMock(return_value=True)
+
+ msg = MagicMock()
+ msg.metadata.user = "user1"
+ msg.metadata.collection = "col1"
+
+ emb = MagicMock()
+ emb.chunk_id = "chunk-1"
+ emb.vector = None # None vector
+ msg.chunks = [emb]
+
+ await proc.store_document_embeddings(msg)
+ proc.qdrant.upsert.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_empty_chunk_id_skipped(self):
+ from trustgraph.storage.doc_embeddings.qdrant.write import Processor
+
+ proc = Processor.__new__(Processor)
+ proc.qdrant = MagicMock()
+ proc.collection_exists = MagicMock(return_value=True)
+
+ msg = MagicMock()
+ msg.metadata.user = "user1"
+ msg.metadata.collection = "col1"
+
+ emb = MagicMock()
+ emb.chunk_id = "" # Empty chunk ID
+ emb.vector = [0.1, 0.2, 0.3]
+ msg.chunks = [emb]
+
+ await proc.store_document_embeddings(msg)
+ proc.qdrant.upsert.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_valid_embedding_upserted(self):
+ from trustgraph.storage.doc_embeddings.qdrant.write import Processor
+
+ proc = Processor.__new__(Processor)
+ proc.qdrant = MagicMock()
+ proc.qdrant.collection_exists.return_value = True
+ proc.collection_exists = MagicMock(return_value=True)
+
+ msg = MagicMock()
+ msg.metadata.user = "user1"
+ msg.metadata.collection = "col1"
+
+ emb = MagicMock()
+ emb.chunk_id = "chunk-1"
+ emb.vector = [0.1, 0.2, 0.3]
+ msg.chunks = [emb]
+
+ await proc.store_document_embeddings(msg)
+ proc.qdrant.upsert.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_dimension_in_collection_name(self):
+ """Collection name should include vector dimension."""
+ from trustgraph.storage.doc_embeddings.qdrant.write import Processor
+
+ proc = Processor.__new__(Processor)
+ proc.qdrant = MagicMock()
+ proc.qdrant.collection_exists.return_value = True
+ proc.collection_exists = MagicMock(return_value=True)
+
+ msg = MagicMock()
+ msg.metadata.user = "alice"
+ msg.metadata.collection = "docs"
+
+ emb = MagicMock()
+ emb.chunk_id = "c1"
+ emb.vector = [0.0] * 384 # 384-dim vector
+ msg.chunks = [emb]
+
+ await proc.store_document_embeddings(msg)
+
+ call_args = proc.qdrant.upsert.call_args
+ assert "d_alice_docs_384" in call_args[1]["collection_name"]
+
+
+# ---------------------------------------------------------------------------
+# Graph embeddings: null entity and vector protection
+# ---------------------------------------------------------------------------
+
+class TestGraphEmbeddingsNullProtection:
+
+ @pytest.mark.asyncio
+ async def test_empty_entity_skipped(self):
+ from trustgraph.storage.graph_embeddings.qdrant.write import Processor
+
+ proc = Processor.__new__(Processor)
+ proc.qdrant = MagicMock()
+ proc.collection_exists = MagicMock(return_value=True)
+
+ msg = MagicMock()
+ msg.metadata.user = "user1"
+ msg.metadata.collection = "col1"
+
+ entity = MagicMock()
+ entity.entity = Term(type=IRI, iri="") # Empty IRI
+ entity.vector = [0.1, 0.2, 0.3]
+ msg.entities = [entity]
+
+ await proc.store_graph_embeddings(msg)
+ proc.qdrant.upsert.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_none_entity_skipped(self):
+ from trustgraph.storage.graph_embeddings.qdrant.write import Processor
+
+ proc = Processor.__new__(Processor)
+ proc.qdrant = MagicMock()
+ proc.collection_exists = MagicMock(return_value=True)
+
+ msg = MagicMock()
+ msg.metadata.user = "user1"
+ msg.metadata.collection = "col1"
+
+ entity = MagicMock()
+ entity.entity = None # Null entity
+ entity.vector = [0.1, 0.2, 0.3]
+ msg.entities = [entity]
+
+ await proc.store_graph_embeddings(msg)
+ proc.qdrant.upsert.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_empty_vector_skipped(self):
+ from trustgraph.storage.graph_embeddings.qdrant.write import Processor
+
+ proc = Processor.__new__(Processor)
+ proc.qdrant = MagicMock()
+ proc.collection_exists = MagicMock(return_value=True)
+
+ msg = MagicMock()
+ msg.metadata.user = "user1"
+ msg.metadata.collection = "col1"
+
+ entity = MagicMock()
+ entity.entity = Term(type=IRI, iri="http://example.org/x")
+ entity.vector = [] # Empty vector
+ msg.entities = [entity]
+
+ await proc.store_graph_embeddings(msg)
+ proc.qdrant.upsert.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_valid_entity_and_vector_upserted(self):
+ from trustgraph.storage.graph_embeddings.qdrant.write import Processor
+
+ proc = Processor.__new__(Processor)
+ proc.qdrant = MagicMock()
+ proc.qdrant.collection_exists.return_value = True
+ proc.collection_exists = MagicMock(return_value=True)
+
+ msg = MagicMock()
+ msg.metadata.user = "user1"
+ msg.metadata.collection = "col1"
+
+ entity = MagicMock()
+ entity.entity = Term(type=IRI, iri="http://example.org/Alice")
+ entity.vector = [0.1, 0.2, 0.3]
+ entity.chunk_id = "c1"
+ msg.entities = [entity]
+
+ await proc.store_graph_embeddings(msg)
+ proc.qdrant.upsert.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_lazy_collection_creation_on_new_dimension(self):
+ from trustgraph.storage.graph_embeddings.qdrant.write import Processor
+
+ proc = Processor.__new__(Processor)
+ proc.qdrant = MagicMock()
+ proc.qdrant.collection_exists.return_value = False
+ proc.collection_exists = MagicMock(return_value=True)
+
+ msg = MagicMock()
+ msg.metadata.user = "alice"
+ msg.metadata.collection = "graphs"
+
+ entity = MagicMock()
+ entity.entity = Term(type=IRI, iri="http://example.org/x")
+ entity.vector = [0.0] * 768
+ entity.chunk_id = ""
+ msg.entities = [entity]
+
+ await proc.store_graph_embeddings(msg)
+
+ # Collection should be created with correct dimension
+ proc.qdrant.create_collection.assert_called_once()
+ create_args = proc.qdrant.create_collection.call_args
+ assert create_args[1]["collection_name"] == "t_alice_graphs_768"
+
+
+# ---------------------------------------------------------------------------
+# Collection validation — deleted-while-in-flight protection
+# ---------------------------------------------------------------------------
+
+class TestCollectionValidation:
+
+ @pytest.mark.asyncio
+ async def test_doc_embeddings_dropped_for_deleted_collection(self):
+ from trustgraph.storage.doc_embeddings.qdrant.write import Processor
+
+ proc = Processor.__new__(Processor)
+ proc.qdrant = MagicMock()
+ proc.collection_exists = MagicMock(return_value=False)
+
+ msg = MagicMock()
+ msg.metadata.user = "user1"
+ msg.metadata.collection = "deleted-col"
+ msg.chunks = [MagicMock()]
+
+ await proc.store_document_embeddings(msg)
+ proc.qdrant.upsert.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_graph_embeddings_dropped_for_deleted_collection(self):
+ from trustgraph.storage.graph_embeddings.qdrant.write import Processor
+
+ proc = Processor.__new__(Processor)
+ proc.qdrant = MagicMock()
+ proc.collection_exists = MagicMock(return_value=False)
+
+ msg = MagicMock()
+ msg.metadata.user = "user1"
+ msg.metadata.collection = "deleted-col"
+ msg.entities = [MagicMock()]
+
+ await proc.store_graph_embeddings(msg)
+ proc.qdrant.upsert.assert_not_called()
diff --git a/tests/unit/test_reliability/test_retry_backoff.py b/tests/unit/test_reliability/test_retry_backoff.py
new file mode 100644
index 00000000..94a3e806
--- /dev/null
+++ b/tests/unit/test_reliability/test_retry_backoff.py
@@ -0,0 +1,153 @@
+"""
+Tests for retry and backoff strategies: Consumer rate-limit retry loop,
+timeout expiry, TooManyRequests exception propagation, and configurable
+retry parameters.
+"""
+
+import asyncio
+import time
+import pytest
+from unittest.mock import MagicMock, AsyncMock, patch
+
+from trustgraph.exceptions import TooManyRequests
+from trustgraph.base.consumer import Consumer
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
def _make_consumer(rate_limit_retry_time=10, rate_limit_timeout=7200):
    """Build a Consumer instance without running its real constructor."""
    c = Consumer.__new__(Consumer)  # skip __init__: no broker connection needed
    c.rate_limit_retry_time = rate_limit_retry_time
    c.rate_limit_timeout = rate_limit_timeout
    c.metrics = None
    c.consumer = MagicMock()
    return c
+
+
+# ---------------------------------------------------------------------------
+# TooManyRequests exception
+# ---------------------------------------------------------------------------
+
class TestTooManyRequestsException:
    """Basic contract of the TooManyRequests exception type."""

    def test_is_exception(self):
        assert issubclass(TooManyRequests, Exception)

    def test_with_message(self):
        exc = TooManyRequests("rate limited")
        assert "rate limited" in str(exc)

    def test_without_message(self):
        # Constructable with no arguments at all.
        assert isinstance(TooManyRequests(), TooManyRequests)
+
+
+# ---------------------------------------------------------------------------
+# Consumer retry configuration
+# ---------------------------------------------------------------------------
+
class TestConsumerRetryConfig:
    """Retry/backoff parameters should be configurable with sane defaults."""

    def test_default_retry_time(self):
        assert _make_consumer().rate_limit_retry_time == 10

    def test_default_timeout(self):
        assert _make_consumer().rate_limit_timeout == 7200

    def test_custom_retry_time(self):
        assert _make_consumer(rate_limit_retry_time=5).rate_limit_retry_time == 5

    def test_custom_timeout(self):
        assert _make_consumer(rate_limit_timeout=300).rate_limit_timeout == 300
+
+
+# ---------------------------------------------------------------------------
+# Rate limit metrics
+# ---------------------------------------------------------------------------
+
class TestRateLimitMetrics:
    """Rate-limit events should be recordable through the metrics object."""

    def test_metrics_rate_limit_called(self):
        """Metrics should record rate limit events when available."""
        c = _make_consumer()
        c.metrics = MagicMock()

        # Mirror what the consumer does when it hits a rate limit.
        c.metrics.rate_limit()

        c.metrics.rate_limit.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# Message acknowledgment on error
+# ---------------------------------------------------------------------------
+
class TestMessageAckOnError:
    """The consumer backend must support negative acknowledgment."""

    def test_consumer_has_negative_acknowledge(self):
        c = _make_consumer()
        message = MagicMock()

        # This is what the consumer does when the rate-limit timeout expires.
        c.consumer.negative_acknowledge(message)

        c.consumer.negative_acknowledge.assert_called_once_with(message)
+
+
+# ---------------------------------------------------------------------------
+# TooManyRequests propagation across services
+# ---------------------------------------------------------------------------
+
class TestTooManyRequestsPropagation:
    """TooManyRequests must survive the re-raise patterns used by services."""

    def test_llm_service_propagates(self):
        """LLM services should re-raise TooManyRequests for consumer retry."""
        with pytest.raises(TooManyRequests):
            raise TooManyRequests()

    def test_embeddings_service_propagates(self):
        """Embeddings services should re-raise TooManyRequests for consumer retry."""
        with pytest.raises(TooManyRequests):
            try:
                raise TooManyRequests("rate limited")
            except TooManyRequests as exc:
                # Services inspect then re-raise so the consumer can retry.
                assert isinstance(exc, TooManyRequests)
                raise

    def test_too_many_requests_not_caught_by_generic(self):
        """TooManyRequests should be distinguishable from generic exceptions."""
        hit_specific = False
        try:
            raise TooManyRequests("rate limited")
        except TooManyRequests:
            hit_specific = True
        except Exception:
            pass
        assert hit_specific
+
+
+# ---------------------------------------------------------------------------
+# Client-side error type mapping
+# ---------------------------------------------------------------------------
+
class TestClientErrorTypeMapping:
    """Wire-format error type strings used by clients."""

    def test_too_many_requests_wire_type(self):
        """The wire format error type for rate limiting is 'too-many-requests'."""
        from trustgraph.schema import Error
        rate_err = Error(type="too-many-requests", message="slow down")
        assert rate_err.type == "too-many-requests"

    def test_generic_error_wire_type(self):
        from trustgraph.schema import Error
        generic_err = Error(type="internal-error", message="something broke")
        assert generic_err.type == "internal-error"
        assert generic_err.type != "too-many-requests"
diff --git a/tests/unit/test_reliability/test_subscriber_resilience.py b/tests/unit/test_reliability/test_subscriber_resilience.py
new file mode 100644
index 00000000..4aac1161
--- /dev/null
+++ b/tests/unit/test_reliability/test_subscriber_resilience.py
@@ -0,0 +1,233 @@
+"""
+Tests for message queue subscriber resilience: unexpected message handling,
+orphaned message detection, backpressure strategies, graceful draining,
+and timeout recovery.
+"""
+
+import asyncio
+import pytest
+from unittest.mock import MagicMock, AsyncMock, patch
+
+from trustgraph.base.subscriber import Subscriber
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
def _make_subscriber(max_size=10, backpressure_strategy="block",
                     drain_timeout=5.0):
    """Construct a Subscriber wired to mocks instead of a real backend."""
    sub = Subscriber(
        backend=MagicMock(),
        topic="test-topic",
        subscription="test-sub",
        consumer_name="test",
        max_size=max_size,
        backpressure_strategy=backpressure_strategy,
        drain_timeout=drain_timeout,
    )
    # Replace the consumer so ack/nack calls can be asserted on.
    sub.consumer = MagicMock()
    return sub
+
+
+def _make_msg(id=None, value="test-value"):
+ """Create a mock message with optional properties."""
+ msg = MagicMock()
+ if id is not None:
+ msg.properties.return_value = {"id": id}
+ else:
+ msg.properties.side_effect = KeyError("id")
+ msg.value.return_value = value
+ return msg
+
+
+# ---------------------------------------------------------------------------
+# Message property extraction resilience
+# ---------------------------------------------------------------------------
+
class TestMessagePropertyResilience:
    """Property-extraction failures must not crash message processing."""

    @pytest.mark.asyncio
    async def test_missing_id_property_handled(self):
        """Messages without 'id' property should not crash."""
        sub = _make_subscriber()
        broken = MagicMock()
        broken.properties.side_effect = Exception("no properties")
        broken.value.return_value = "some-value"

        # Must not raise despite the broken properties accessor.
        await sub._process_message(broken)

        # The message is still acknowledged so it is not redelivered forever.
        sub.consumer.acknowledge.assert_called_once_with(broken)

    @pytest.mark.asyncio
    async def test_message_with_valid_id_delivered(self):
        """Messages with matching subscriber ID should be delivered."""
        sub = _make_subscriber()
        queue = await sub.subscribe("req-1")

        await sub._process_message(_make_msg(id="req-1", value="response-data"))

        assert not queue.empty()
        assert queue.get_nowait() == "response-data"
        sub.consumer.acknowledge.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# Orphaned message handling
+# ---------------------------------------------------------------------------
+
class TestOrphanedMessages:
    """Messages whose id matches no waiter are acked and dropped."""

    @pytest.mark.asyncio
    async def test_orphaned_message_acknowledged(self):
        """Messages with no matching waiter should still be acknowledged."""
        sub = _make_subscriber()
        orphan = _make_msg(id="unknown-id", value="orphan")

        await sub._process_message(orphan)

        # Ack (not nack): an orphan is discarded, not retried.
        sub.consumer.acknowledge.assert_called_once_with(orphan)

    @pytest.mark.asyncio
    async def test_orphaned_message_not_queued(self):
        """Orphaned messages should not appear in any subscriber queue."""
        sub = _make_subscriber()
        queue = await sub.subscribe("req-1")

        await sub._process_message(_make_msg(id="different-id", value="orphan"))

        assert queue.empty()
+
+
+# ---------------------------------------------------------------------------
+# Backpressure strategies
+# ---------------------------------------------------------------------------
+
class TestBackpressureStrategies:
    """Behaviour of each backpressure strategy when the queue fills."""

    @pytest.mark.asyncio
    async def test_drop_new_rejects_when_full(self):
        """drop_new strategy should reject new messages when queue is full."""
        sub = _make_subscriber(max_size=1, backpressure_strategy="drop_new")
        queue = await sub.subscribe("req-1")

        # First message fills the single-slot queue.
        await sub._process_message(_make_msg(id="req-1", value="first"))
        assert queue.qsize() == 1

        # Second message arrives while full and is discarded.
        await sub._process_message(_make_msg(id="req-1", value="second"))

        assert queue.qsize() == 1
        assert queue.get_nowait() == "first"

    @pytest.mark.asyncio
    async def test_drop_oldest_evicts_when_full(self):
        """drop_oldest strategy should evict oldest message when full."""
        sub = _make_subscriber(max_size=1, backpressure_strategy="drop_oldest")
        queue = await sub.subscribe("req-1")

        await sub._process_message(_make_msg(id="req-1", value="first"))
        await sub._process_message(_make_msg(id="req-1", value="second"))

        # The older entry was evicted in favour of the newer one.
        assert queue.qsize() == 1
        assert queue.get_nowait() == "second"

    @pytest.mark.asyncio
    async def test_block_strategy_delivers(self):
        """block strategy should deliver messages normally."""
        sub = _make_subscriber(max_size=10, backpressure_strategy="block")
        queue = await sub.subscribe("req-1")

        await sub._process_message(_make_msg(id="req-1", value="data"))

        assert queue.get_nowait() == "data"
+
+
+# ---------------------------------------------------------------------------
+# Full subscribers (subscribe_all)
+# ---------------------------------------------------------------------------
+
class TestFullSubscribers:
    """subscribe_all listeners receive every message regardless of id."""

    @pytest.mark.asyncio
    async def test_subscribe_all_receives_all_messages(self):
        sub = _make_subscriber()
        queue = await sub.subscribe_all("listener-1")

        await sub._process_message(_make_msg(id="any-id", value="broadcast"))

        assert queue.get_nowait() == "broadcast"

    @pytest.mark.asyncio
    async def test_multiple_full_subscribers_all_receive(self):
        sub = _make_subscriber()
        first = await sub.subscribe_all("l1")
        second = await sub.subscribe_all("l2")

        await sub._process_message(_make_msg(id="any", value="data"))

        # Both listeners see the same message.
        assert first.get_nowait() == "data"
        assert second.get_nowait() == "data"
+
+
+# ---------------------------------------------------------------------------
+# Subscribe / unsubscribe lifecycle
+# ---------------------------------------------------------------------------
+
class TestSubscribeLifecycle:
    """subscribe/unsubscribe bookkeeping for both per-id and full listeners."""

    @pytest.mark.asyncio
    async def test_unsubscribe_removes_queue(self):
        sub = _make_subscriber()
        await sub.subscribe("req-1")
        await sub.unsubscribe("req-1")

        assert "req-1" not in sub.q

    @pytest.mark.asyncio
    async def test_unsubscribe_nonexistent_is_noop(self):
        # Unsubscribing an unknown id must not raise.
        await _make_subscriber().unsubscribe("nonexistent")

    @pytest.mark.asyncio
    async def test_unsubscribe_all_removes_queue(self):
        sub = _make_subscriber()
        await sub.subscribe_all("l1")
        await sub.unsubscribe_all("l1")

        assert "l1" not in sub.full
+
+
+# ---------------------------------------------------------------------------
+# Pending ack tracking
+# ---------------------------------------------------------------------------
+
class TestPendingAckTracking:
    """Processed messages must be removed from the pending-ack set."""

    @pytest.mark.asyncio
    async def test_processed_message_cleared_from_pending(self):
        sub = _make_subscriber()

        await sub._process_message(_make_msg(id="req-1", value="data"))

        # Once acked, nothing should remain pending.
        assert len(sub.pending_acks) == 0
diff --git a/tests/unit/test_structured_data/__init__.py b/tests/unit/test_structured_data/__init__.py
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/tests/unit/test_structured_data/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/unit/test_structured_data/test_row_embeddings_query.py b/tests/unit/test_structured_data/test_row_embeddings_query.py
new file mode 100644
index 00000000..3222ec83
--- /dev/null
+++ b/tests/unit/test_structured_data/test_row_embeddings_query.py
@@ -0,0 +1,296 @@
+"""
+Tests for row embeddings query service: collection naming, query execution,
+index filtering, result conversion, and error handling.
+"""
+
+import pytest
+from unittest.mock import MagicMock, AsyncMock, patch
+
+from trustgraph.schema import (
+ RowEmbeddingsRequest, RowEmbeddingsResponse,
+ RowIndexMatch, Error,
+)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
def _make_processor(qdrant_client=None):
    """Create a Processor without running the full FlowProcessor __init__."""
    from trustgraph.query.row_embeddings.qdrant.service import Processor
    instance = Processor.__new__(Processor)  # bypass constructor wiring
    instance.qdrant = qdrant_client or MagicMock()
    return instance
+
+
def _make_request(vector=None, user="test-user", collection="test-col",
                  schema_name="customers", limit=10, index_name=None):
    """Build a RowEmbeddingsRequest with overridable defaults.

    Fix: default *vector* via an explicit ``is None`` check. The previous
    ``vector or [0.1, 0.2, 0.3]`` silently replaced an explicitly passed
    empty list with the default, so ``_make_request(vector=[])`` (used by
    the empty-vector test) never actually produced an empty vector.
    """
    if vector is None:
        vector = [0.1, 0.2, 0.3]
    return RowEmbeddingsRequest(
        vector=vector,
        user=user,
        collection=collection,
        schema_name=schema_name,
        limit=limit,
        index_name=index_name or "",
    )
+
+
+def _make_search_point(index_name, index_value, text, score):
+ point = MagicMock()
+ point.payload = {
+ "index_name": index_name,
+ "index_value": index_value,
+ "text": text,
+ }
+ point.score = score
+ return point
+
+
+# ---------------------------------------------------------------------------
+# sanitize_name
+# ---------------------------------------------------------------------------
+
class TestSanitizeName:
    """Collection-name sanitisation: lowercase, underscores, digit prefix."""

    def test_simple_name(self):
        assert _make_processor().sanitize_name("customers") == "customers"

    def test_special_chars_replaced(self):
        assert _make_processor().sanitize_name("my-schema.v2") == "my_schema_v2"

    def test_leading_digit_prefixed(self):
        sanitized = _make_processor().sanitize_name("123schema")
        # Leading digit forces an "r_" prefix; original name is retained.
        assert sanitized.startswith("r_")
        assert "123schema" in sanitized

    def test_uppercase_lowercased(self):
        assert _make_processor().sanitize_name("MySchema") == "myschema"

    def test_spaces_replaced(self):
        assert _make_processor().sanitize_name("my schema") == "my_schema"
+
+
+# ---------------------------------------------------------------------------
+# find_collection
+# ---------------------------------------------------------------------------
+
class TestFindCollection:
    """find_collection resolves the dimension-suffixed collection name."""

    def _processor_listing(self, name):
        """Processor whose Qdrant reports a single collection called *name*."""
        proc = _make_processor()
        coll = MagicMock()
        coll.name = name
        listing = MagicMock()
        listing.collections = [coll]
        proc.qdrant.get_collections.return_value = listing
        return proc

    def test_finds_matching_collection(self):
        proc = self._processor_listing("rows_test_user_test_col_customers_384")

        found = proc.find_collection("test-user", "test-col", "customers")

        # Matches on prefix rows_test_user_test_col_customers_
        assert found == "rows_test_user_test_col_customers_384"

    def test_returns_none_when_no_match(self):
        proc = self._processor_listing("rows_other_user_other_col_schema_768")

        assert proc.find_collection("test-user", "test-col", "customers") is None

    def test_returns_none_on_error(self):
        proc = _make_processor()
        proc.qdrant.get_collections.side_effect = Exception("connection error")

        # Lookup failures degrade to "not found" rather than raising.
        assert proc.find_collection("user", "col", "schema") is None
+
+
+# ---------------------------------------------------------------------------
+# query_row_embeddings
+# ---------------------------------------------------------------------------
+
class TestQueryRowEmbeddings:
    """Behaviour of Processor.query_row_embeddings: early exits, result
    conversion, index filtering, and backend-error propagation."""

    @pytest.mark.asyncio
    async def test_empty_vector_returns_empty(self):
        proc = _make_processor()
        # NOTE(review): _make_request defaults via "vector or [...]", so
        # vector=[] is silently replaced by the default vector — verify this
        # test actually exercises the empty-vector early-return path.
        request = _make_request(vector=[])

        result = await proc.query_row_embeddings(request)
        assert result == []

    @pytest.mark.asyncio
    async def test_no_collection_returns_empty(self):
        proc = _make_processor()
        # No matching collection means there is nothing to search.
        proc.find_collection = MagicMock(return_value=None)
        request = _make_request()

        result = await proc.query_row_embeddings(request)
        assert result == []

    @pytest.mark.asyncio
    async def test_successful_query_returns_matches(self):
        proc = _make_processor()
        proc.find_collection = MagicMock(return_value="rows_u_c_s_384")

        points = [
            _make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
            _make_search_point("address", ["123 Main St"], "123 Main St", 0.82),
        ]
        mock_result = MagicMock()
        mock_result.points = points
        proc.qdrant.query_points.return_value = mock_result

        request = _make_request()
        result = await proc.query_row_embeddings(request)

        # Each Qdrant point becomes a RowIndexMatch, order preserved.
        assert len(result) == 2
        assert isinstance(result[0], RowIndexMatch)
        assert result[0].index_name == "name"
        assert result[0].index_value == ["Alice Smith"]
        assert result[0].score == 0.95
        assert result[1].index_name == "address"

    @pytest.mark.asyncio
    async def test_index_name_filter_applied(self):
        """When index_name is specified, a Qdrant filter should be used."""
        proc = _make_processor()
        proc.find_collection = MagicMock(return_value="rows_u_c_s_384")

        mock_result = MagicMock()
        mock_result.points = []
        proc.qdrant.query_points.return_value = mock_result

        request = _make_request(index_name="address")
        await proc.query_row_embeddings(request)

        # Inspect the keyword arguments of the single query_points call.
        call_kwargs = proc.qdrant.query_points.call_args[1]
        assert call_kwargs["query_filter"] is not None

    @pytest.mark.asyncio
    async def test_no_index_name_no_filter(self):
        """When index_name is empty, no filter should be applied."""
        proc = _make_processor()
        proc.find_collection = MagicMock(return_value="rows_u_c_s_384")

        mock_result = MagicMock()
        mock_result.points = []
        proc.qdrant.query_points.return_value = mock_result

        request = _make_request(index_name="")
        await proc.query_row_embeddings(request)

        call_kwargs = proc.qdrant.query_points.call_args[1]
        assert call_kwargs["query_filter"] is None

    @pytest.mark.asyncio
    async def test_missing_payload_fields_default(self):
        """Points with missing payload fields should use defaults."""
        proc = _make_processor()
        proc.find_collection = MagicMock(return_value="rows_u_c_s_384")

        point = MagicMock()
        point.payload = {}  # Empty payload
        point.score = 0.5

        mock_result = MagicMock()
        mock_result.points = [point]
        proc.qdrant.query_points.return_value = mock_result

        request = _make_request()
        result = await proc.query_row_embeddings(request)

        # Absent payload keys fall back to empty-string / empty-list defaults.
        assert len(result) == 1
        assert result[0].index_name == ""
        assert result[0].index_value == []
        assert result[0].text == ""

    @pytest.mark.asyncio
    async def test_qdrant_error_propagates(self):
        proc = _make_processor()
        proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
        proc.qdrant.query_points.side_effect = Exception("qdrant down")

        request = _make_request()

        # Backend failures are not swallowed by the query method itself.
        with pytest.raises(Exception, match="qdrant down"):
            await proc.query_row_embeddings(request)
+
+
+# ---------------------------------------------------------------------------
+# on_message handler
+# ---------------------------------------------------------------------------
+
class TestOnMessage:
    """on_message handler: response publication, in-band error wrapping,
    and correlation-id propagation."""

    @pytest.mark.asyncio
    async def test_successful_message_sends_response(self):
        proc = _make_processor()
        proc.query_row_embeddings = AsyncMock(return_value=[
            RowIndexMatch(index_name="name", index_value=["Alice"],
                          text="Alice", score=0.9),
        ])

        mock_pub = AsyncMock()
        # Every named producer the handler asks for resolves to mock_pub.
        flow = lambda name: mock_pub

        msg = MagicMock()
        msg.value.return_value = _make_request()
        msg.properties.return_value = {"id": "req-1"}

        await proc.on_message(msg, MagicMock(), flow)

        # First positional argument of send() is the response message.
        sent = mock_pub.send.call_args[0][0]
        assert isinstance(sent, RowEmbeddingsResponse)
        assert sent.error is None
        assert len(sent.matches) == 1

    @pytest.mark.asyncio
    async def test_error_sends_error_response(self):
        proc = _make_processor()
        proc.query_row_embeddings = AsyncMock(
            side_effect=Exception("query failed")
        )

        mock_pub = AsyncMock()
        flow = lambda name: mock_pub

        msg = MagicMock()
        msg.value.return_value = _make_request()
        msg.properties.return_value = {"id": "req-2"}

        await proc.on_message(msg, MagicMock(), flow)

        # Failures are reported in-band as an error response, not raised.
        sent = mock_pub.send.call_args[0][0]
        assert sent.error is not None
        assert sent.error.type == "row-embeddings-query-error"
        assert "query failed" in sent.error.message
        assert sent.matches == []

    @pytest.mark.asyncio
    async def test_message_id_preserved(self):
        proc = _make_processor()
        proc.query_row_embeddings = AsyncMock(return_value=[])

        mock_pub = AsyncMock()
        flow = lambda name: mock_pub

        msg = MagicMock()
        msg.value.return_value = _make_request()
        msg.properties.return_value = {"id": "unique-42"}

        await proc.on_message(msg, MagicMock(), flow)

        # The request correlation id must round-trip into the response props.
        props = mock_pub.send.call_args[1]["properties"]
        assert props["id"] == "unique-42"
diff --git a/tests/unit/test_structured_data/test_type_detector.py b/tests/unit/test_structured_data/test_type_detector.py
new file mode 100644
index 00000000..7ce7060a
--- /dev/null
+++ b/tests/unit/test_structured_data/test_type_detector.py
@@ -0,0 +1,235 @@
+"""
+Tests for structured data type detection: CSV, JSON, XML format detection,
+CSV option detection (delimiter, header), and helper functions.
+"""
+
+import pytest
+
+from trustgraph.retrieval.structured_diag.type_detector import (
+ detect_data_type,
+ _check_json_format,
+ _check_xml_format,
+ _check_csv_format,
+ _check_csv_with_delimiter,
+ detect_csv_options,
+ _is_numeric,
+)
+
+
+# ---------------------------------------------------------------------------
+# detect_data_type (top-level dispatcher)
+# ---------------------------------------------------------------------------
+
class TestDetectDataType:
    """Top-level format dispatcher: JSON, then XML, then CSV fallback.

    Fix: the XML sample literals in the two XML tests were garbled by
    markup stripping (the angle-bracketed content was removed entirely);
    they are reconstructed here as well-formed XML matching each test's
    stated intent.
    """

    def test_empty_string_returns_none(self):
        detected, confidence = detect_data_type("")
        assert detected is None
        assert confidence == 0.0

    def test_whitespace_only_returns_none(self):
        detected, confidence = detect_data_type(" \n \t ")
        assert detected is None
        assert confidence == 0.0

    def test_none_returns_none(self):
        detected, confidence = detect_data_type(None)
        assert detected is None
        assert confidence == 0.0

    def test_json_object_detected(self):
        detected, confidence = detect_data_type('{"name": "Alice"}')
        assert detected == "json"
        assert confidence > 0.5

    def test_json_array_detected(self):
        detected, confidence = detect_data_type('[{"id": 1}, {"id": 2}]')
        assert detected == "json"
        assert confidence > 0.5

    def test_xml_with_declaration_detected(self):
        detected, confidence = detect_data_type(
            '<?xml version="1.0"?><root><item>val</item></root>'
        )
        assert detected == "xml"
        assert confidence > 0.5

    def test_xml_without_declaration_detected(self):
        detected, confidence = detect_data_type('<root><item>val</item></root>')
        assert detected == "xml"
        assert confidence > 0.5

    def test_csv_detected(self):
        data = "name,age,city\nAlice,30,NYC\nBob,25,LA"
        detected, confidence = detect_data_type(data)
        assert detected == "csv"
        assert confidence > 0.5

    def test_plain_text_falls_through_to_csv(self):
        """Non-JSON/XML text defaults to CSV detection."""
        detected, confidence = detect_data_type("just some text")
        assert detected == "csv"
+
+# ---------------------------------------------------------------------------
+# _check_json_format
+# ---------------------------------------------------------------------------
+
class TestCheckJsonFormat:
    """Confidence scoring rules for JSON detection."""

    def test_valid_json_object(self):
        assert _check_json_format('{"key": "value"}') > 0.9

    def test_valid_json_array_of_objects(self):
        assert _check_json_format('[{"id": 1}, {"id": 2}]') >= 0.9

    def test_valid_json_array_of_primitives(self):
        score = _check_json_format('[1, 2, 3]')
        # Arrays of non-objects parse, but score below object arrays.
        assert 0.5 < score < 0.9

    def test_empty_json_object(self):
        assert _check_json_format('{}') > 0.5

    def test_invalid_json(self):
        assert _check_json_format('{invalid json}') == 0.0

    def test_non_json_starting_char(self):
        assert _check_json_format('hello world') == 0.0

    def test_empty_array(self):
        # Parses successfully even though it holds nothing.
        assert _check_json_format('[]') > 0.0
+
+
+# ---------------------------------------------------------------------------
+# _check_xml_format
+# ---------------------------------------------------------------------------
+
class TestCheckXmlFormat:
    """Confidence scoring rules for XML detection.

    Fix: the XML string literals in this class were garbled by markup
    stripping (angle-bracketed content removed); reconstructed here to
    match each test's stated intent and comments.
    """

    def test_valid_xml(self):
        assert _check_xml_format('<root><item>val</item></root>') == 0.9

    def test_xml_with_declaration(self):
        xml = '<?xml version="1.0"?><root><item>test</item></root>'
        assert _check_xml_format(xml) == 0.9

    def test_malformed_xml(self):
        score = _check_xml_format('<root><unclosed>')
        # Has < and > but the check fails since there is no closing tag
        assert score < 0.9

    def test_not_xml(self):
        assert _check_xml_format('just text') == 0.0

    def test_incomplete_xml_tag(self):
        score = _check_xml_format('<incomplete')
        # Starts with < but no closing tag
        assert score <= 0.1
+
+
+# ---------------------------------------------------------------------------
+# _check_csv_format and _check_csv_with_delimiter
+# ---------------------------------------------------------------------------
+
class TestCheckCsvFormat:
    """CSV confidence scoring across delimiters and row shapes."""

    def test_valid_csv_comma(self):
        assert _check_csv_format("name,age,city\nAlice,30,NYC\nBob,25,LA") > 0.7

    def test_valid_csv_semicolon(self):
        assert _check_csv_format("name;age;city\nAlice;30;NYC\nBob;25;LA") > 0.7

    def test_valid_csv_tab(self):
        assert _check_csv_format("name\tage\tcity\nAlice\t30\tNYC\nBob\t25\tLA") > 0.7

    def test_valid_csv_pipe(self):
        assert _check_csv_format("name|age|city\nAlice|30|NYC\nBob|25|LA") > 0.7

    def test_single_line_not_csv(self):
        assert _check_csv_format("just one line") == 0.0

    def test_single_column_not_csv(self):
        assert _check_csv_with_delimiter("a\nb\nc", ",") == 0.0

    def test_inconsistent_columns_low_score(self):
        # Rows of 3, 2 and 4 columns: inconsistency lowers the score.
        assert _check_csv_with_delimiter("a,b,c\n1,2\n3,4,5,6", ",") < 0.7

    def test_many_rows_higher_score(self):
        body = [f"person{i},{20 + i},city{i}" for i in range(20)]
        data = "\n".join(["name,age,city"] + body)
        assert _check_csv_format(data) > 0.8
+
+
+# ---------------------------------------------------------------------------
+# detect_csv_options
+# ---------------------------------------------------------------------------
+
class TestDetectCsvOptions:
    """Delimiter, header, and encoding detection for CSV input."""

    def test_comma_delimiter_detected(self):
        opts = detect_csv_options("name,age,city\nAlice,30,NYC\nBob,25,LA")
        assert opts["delimiter"] == ","

    def test_semicolon_delimiter_detected(self):
        opts = detect_csv_options("name;age;city\nAlice;30;NYC\nBob;25;LA")
        assert opts["delimiter"] == ";"

    def test_tab_delimiter_detected(self):
        opts = detect_csv_options("name\tage\tcity\nAlice\t30\tNYC\nBob\t25\tLA")
        assert opts["delimiter"] == "\t"

    def test_header_detected_when_first_row_text(self):
        opts = detect_csv_options("name,age,salary\nAlice,30,50000\nBob,25,45000")
        assert opts["has_header"] is True

    def test_no_header_when_all_numeric(self):
        # An all-numeric first row cannot be a header.
        opts = detect_csv_options("1,2,3\n4,5,6\n7,8,9")
        assert opts["has_header"] is False

    def test_single_line_returns_defaults(self):
        opts = detect_csv_options("just one line")
        assert opts["delimiter"] == ","
        assert opts["has_header"] is True

    def test_encoding_default(self):
        assert detect_csv_options("a,b\n1,2")["encoding"] == "utf-8"
+
+
+# ---------------------------------------------------------------------------
+# _is_numeric helper
+# ---------------------------------------------------------------------------
+
class TestIsNumeric:
    """_is_numeric accepts int/float strings, tolerating surrounding space."""

    def test_integer(self):
        assert _is_numeric("42") is True

    def test_float(self):
        assert _is_numeric("3.14") is True

    def test_negative(self):
        assert _is_numeric("-10") is True

    def test_text(self):
        assert _is_numeric("hello") is False

    def test_empty(self):
        assert _is_numeric("") is False

    def test_whitespace_padded(self):
        assert _is_numeric(" 42 ") is True
diff --git a/tests/unit/test_text_completion/test_azure_openai_streaming.py b/tests/unit/test_text_completion/test_azure_openai_streaming.py
new file mode 100644
index 00000000..b2f5a003
--- /dev/null
+++ b/tests/unit/test_text_completion/test_azure_openai_streaming.py
@@ -0,0 +1,182 @@
+"""
+Tests for Azure OpenAI streaming: model/temperature override during streaming,
+RateLimitError → TooManyRequests conversion, chunk iteration, and final token
+count emission.
+"""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+from unittest import IsolatedAsyncioTestCase
+
+from trustgraph.model.text_completion.azure_openai.llm import Processor
+from trustgraph.base import LlmChunk
+from trustgraph.exceptions import TooManyRequests
+
+
def _make_processor(mock_azure_openai_class, model="gpt-4"):
    """Create a Processor with mocked base classes.

    ``mock_azure_openai_class`` is accepted (and currently unused here) so
    call sites can pass the patched AzureOpenAI class straight through; the
    mocked client itself is wired up at each call site.
    """
    # Stub out the heavyweight base-class initialisers so the Processor can
    # be constructed in isolation (no messaging infrastructure is touched).
    with patch('trustgraph.base.async_processor.AsyncProcessor.__init__',
               return_value=None), \
         patch('trustgraph.base.llm_service.LlmService.__init__',
               return_value=None):
        proc = Processor(
            endpoint="https://test.openai.azure.com/",
            token="test-token",
            api_version="2024-12-01-preview",
            model=model,
            temperature=0.0,
            max_output=4192,
            concurrency=1,
            taskgroup=AsyncMock(),
            id="test-processor",
        )
    return proc
+
+
+def _make_stream_chunk(content=None, usage=None):
+ """Create a mock streaming chunk."""
+ chunk = MagicMock()
+ if content:
+ chunk.choices = [MagicMock()]
+ chunk.choices[0].delta.content = content
+ else:
+ chunk.choices = []
+ chunk.usage = usage
+ return chunk
+
+
class TestAzureOpenAIStreaming(IsolatedAsyncioTestCase):
    """Streaming behaviour of the Azure OpenAI text-completion Processor:
    chunk iteration, model/temperature override, rate-limit conversion,
    error propagation, and stream-option plumbing."""

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    async def test_streaming_yields_chunks(self, mock_azure_class):
        # Two content deltas plus a usage-only chunk should surface as two
        # non-final chunks and one final chunk carrying the token counts.
        mock_client = MagicMock()
        mock_azure_class.return_value = mock_client
        proc = _make_processor(mock_azure_class)

        usage = MagicMock()
        usage.prompt_tokens = 10
        usage.completion_tokens = 5

        stream_data = [
            _make_stream_chunk(content="Hello"),
            _make_stream_chunk(content=" world"),
            _make_stream_chunk(usage=usage),
        ]
        mock_client.chat.completions.create.return_value = iter(stream_data)

        results = []
        async for chunk in proc.generate_content_stream("sys", "user"):
            results.append(chunk)

        assert len(results) == 3  # 2 content + 1 final
        assert results[0].text == "Hello"
        assert results[0].is_final is False
        assert results[1].text == " world"
        assert results[2].is_final is True
        assert results[2].in_token == 10
        assert results[2].out_token == 5

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    async def test_streaming_model_override(self, mock_azure_class):
        # Per-request model override should be reflected both on the yielded
        # chunks and in the underlying API call.
        mock_client = MagicMock()
        mock_azure_class.return_value = mock_client
        proc = _make_processor(mock_azure_class, model="gpt-4")

        usage = MagicMock()
        usage.prompt_tokens = 5
        usage.completion_tokens = 2

        stream_data = [
            _make_stream_chunk(content="ok"),
            _make_stream_chunk(usage=usage),
        ]
        mock_client.chat.completions.create.return_value = iter(stream_data)

        results = []
        async for chunk in proc.generate_content_stream(
            "sys", "user", model="gpt-4o"
        ):
            results.append(chunk)

        # All chunks carry overridden model
        for r in results:
            assert r.model == "gpt-4o"

        # Verify API call used overridden model
        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs["model"] == "gpt-4o"

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    async def test_streaming_temperature_override(self, mock_azure_class):
        # Per-request temperature override must be forwarded to the API call
        # (processor default is 0.0, see _make_processor).
        mock_client = MagicMock()
        mock_azure_class.return_value = mock_client
        proc = _make_processor(mock_azure_class)

        usage = MagicMock()
        usage.prompt_tokens = 5
        usage.completion_tokens = 2

        stream_data = [_make_stream_chunk(usage=usage)]
        mock_client.chat.completions.create.return_value = iter(stream_data)

        async for _ in proc.generate_content_stream(
            "sys", "user", temperature=0.7
        ):
            pass

        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs["temperature"] == 0.7

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    async def test_streaming_rate_limit_raises_too_many_requests(self, mock_azure_class):
        # openai.RateLimitError from the client should be translated into the
        # provider-neutral TooManyRequests exception.
        from openai import RateLimitError

        mock_client = MagicMock()
        mock_azure_class.return_value = mock_client
        proc = _make_processor(mock_azure_class)

        mock_client.chat.completions.create.side_effect = RateLimitError(
            "Rate limit exceeded", response=MagicMock(), body=None
        )

        with pytest.raises(TooManyRequests):
            async for _ in proc.generate_content_stream("sys", "user"):
                pass

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    async def test_streaming_generic_exception_propagates(self, mock_azure_class):
        # Non-rate-limit errors must propagate unchanged to the caller.
        mock_client = MagicMock()
        mock_azure_class.return_value = mock_client
        proc = _make_processor(mock_azure_class)

        mock_client.chat.completions.create.side_effect = Exception("API down")

        with pytest.raises(Exception, match="API down"):
            async for _ in proc.generate_content_stream("sys", "user"):
                pass

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    async def test_streaming_passes_stream_options(self, mock_azure_class):
        # stream=True and stream_options={"include_usage": True} are required
        # for the final usage chunk to be emitted by the API.
        mock_client = MagicMock()
        mock_azure_class.return_value = mock_client
        proc = _make_processor(mock_azure_class)

        usage = MagicMock()
        usage.prompt_tokens = 0
        usage.completion_tokens = 0
        stream_data = [_make_stream_chunk(usage=usage)]
        mock_client.chat.completions.create.return_value = iter(stream_data)

        async for _ in proc.generate_content_stream("sys", "user"):
            pass

        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs["stream"] is True
        assert call_kwargs["stream_options"] == {"include_usage": True}

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    async def test_supports_streaming(self, mock_azure_class):
        # This provider advertises streaming support.
        mock_client = MagicMock()
        mock_azure_class.return_value = mock_client
        proc = _make_processor(mock_azure_class)
        assert proc.supports_streaming() is True
diff --git a/tests/unit/test_text_completion/test_azure_streaming.py b/tests/unit/test_text_completion/test_azure_streaming.py
new file mode 100644
index 00000000..ff32e59d
--- /dev/null
+++ b/tests/unit/test_text_completion/test_azure_streaming.py
@@ -0,0 +1,199 @@
+"""
+Tests for Azure serverless endpoint streaming: model override during streaming,
+HTTP 429 during streaming, SSE chunk parsing, and final token count emission.
+"""
+
+import json
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+from unittest import IsolatedAsyncioTestCase
+
+from trustgraph.model.text_completion.azure.llm import Processor
+from trustgraph.base import LlmChunk
+from trustgraph.exceptions import TooManyRequests
+
+
def _make_processor(mock_requests, model="AzureAI", temperature=0.0):
    """Create a Processor with mocked base classes.

    ``mock_requests`` is accepted but unused here; call sites pass the
    patched ``requests`` module through for signature symmetry and wire up
    the HTTP responses themselves.
    """
    # Stub out the base-class initialisers so no messaging infrastructure is
    # touched when constructing the Processor.
    with patch('trustgraph.base.async_processor.AsyncProcessor.__init__',
               return_value=None), \
         patch('trustgraph.base.llm_service.LlmService.__init__',
               return_value=None):
        proc = Processor(
            endpoint="https://test.azure.com/v1/chat/completions",
            token="test-token",
            temperature=temperature,
            max_output=4192,
            model=model,
            concurrency=1,
            taskgroup=AsyncMock(),
            id="test-processor",
        )
    return proc
+
+
+def _sse_lines(*data_items):
+ """Build SSE byte lines from data items. '[DONE]' is appended."""
+ lines = []
+ for item in data_items:
+ if isinstance(item, dict):
+ lines.append(f"data: {json.dumps(item)}".encode())
+ else:
+ lines.append(f"data: {item}".encode())
+ lines.append(b"data: [DONE]")
+ return lines
+
+
class TestAzureServerlessStreaming(IsolatedAsyncioTestCase):
    """Streaming behaviour of the Azure serverless-endpoint Processor:
    SSE parsing, model/temperature override, HTTP error handling, and
    final token-count emission."""

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    async def test_streaming_yields_chunks(self, mock_requests):
        # Two content deltas plus a usage-only SSE event should surface as
        # two non-final chunks and one final chunk with token counts.
        proc = _make_processor(mock_requests)

        chunks = [
            {"choices": [{"delta": {"content": "Hello"}}]},
            {"choices": [{"delta": {"content": " world"}}]},
            {"usage": {"prompt_tokens": 10, "completion_tokens": 5}},
        ]

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.iter_lines.return_value = _sse_lines(*chunks)
        mock_requests.post.return_value = mock_response

        results = []
        async for chunk in proc.generate_content_stream("sys", "user"):
            results.append(chunk)

        # Content chunks + final chunk
        assert len(results) == 3
        assert results[0].text == "Hello"
        assert results[0].is_final is False
        assert results[1].text == " world"
        assert results[1].is_final is False
        assert results[2].is_final is True
        assert results[2].in_token == 10
        assert results[2].out_token == 5

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    async def test_streaming_model_override(self, mock_requests):
        # Per-request model override should show up on the yielded chunks
        # and in the JSON request body posted to the endpoint.
        proc = _make_processor(mock_requests, model="default-model")

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.iter_lines.return_value = _sse_lines(
            {"choices": [{"delta": {"content": "ok"}}]},
            {"usage": {"prompt_tokens": 5, "completion_tokens": 2}},
        )
        mock_requests.post.return_value = mock_response

        results = []
        async for chunk in proc.generate_content_stream(
            "sys", "user", model="override-model"
        ):
            results.append(chunk)

        # All chunks should carry the overridden model name
        for r in results:
            assert r.model == "override-model"

        # Verify the request body used the overridden model
        call_args = mock_requests.post.call_args
        body = json.loads(call_args[1]["data"])
        assert body["model"] == "override-model"

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    async def test_streaming_temperature_override(self, mock_requests):
        # Per-request temperature override must reach the request body
        # (processor default is 0.0).
        proc = _make_processor(mock_requests, temperature=0.0)

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.iter_lines.return_value = _sse_lines(
            {"choices": [{"delta": {"content": "ok"}}]},
            {"usage": {"prompt_tokens": 5, "completion_tokens": 2}},
        )
        mock_requests.post.return_value = mock_response

        results = []
        async for chunk in proc.generate_content_stream(
            "sys", "user", temperature=0.9
        ):
            results.append(chunk)

        call_args = mock_requests.post.call_args
        body = json.loads(call_args[1]["data"])
        assert body["temperature"] == 0.9

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    async def test_streaming_429_raises_too_many_requests(self, mock_requests):
        # HTTP 429 should be translated into TooManyRequests.
        proc = _make_processor(mock_requests)

        mock_response = MagicMock()
        mock_response.status_code = 429
        mock_requests.post.return_value = mock_response

        with pytest.raises(TooManyRequests):
            async for _ in proc.generate_content_stream("sys", "user"):
                pass

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    async def test_streaming_http_error_raises_runtime(self, mock_requests):
        # Any other non-200 status should raise a RuntimeError naming the
        # HTTP status code.
        proc = _make_processor(mock_requests)

        mock_response = MagicMock()
        mock_response.status_code = 503
        mock_response.text = "Service Unavailable"
        mock_requests.post.return_value = mock_response

        with pytest.raises(RuntimeError, match="HTTP 503"):
            async for _ in proc.generate_content_stream("sys", "user"):
                pass

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    async def test_streaming_includes_stream_options(self, mock_requests):
        """Verify stream=True and stream_options in request body."""
        proc = _make_processor(mock_requests)

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.iter_lines.return_value = _sse_lines(
            {"usage": {"prompt_tokens": 0, "completion_tokens": 0}},
        )
        mock_requests.post.return_value = mock_response

        async for _ in proc.generate_content_stream("sys", "user"):
            pass

        call_args = mock_requests.post.call_args
        body = json.loads(call_args[1]["data"])
        assert body["stream"] is True
        assert body["stream_options"]["include_usage"] is True

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    async def test_streaming_malformed_json_skipped(self, mock_requests):
        """Malformed JSON chunks should be skipped, not crash the stream."""
        proc = _make_processor(mock_requests)

        mock_response = MagicMock()
        mock_response.status_code = 200
        # Hand-built SSE lines (not via _sse_lines) so the first line can be
        # deliberately invalid JSON.
        lines = [
            b"data: {not valid json}",
            f'data: {json.dumps({"choices": [{"delta": {"content": "ok"}}]})}'.encode(),
            f'data: {json.dumps({"usage": {"prompt_tokens": 1, "completion_tokens": 1}})}'.encode(),
            b"data: [DONE]",
        ]
        mock_response.iter_lines.return_value = lines
        mock_requests.post.return_value = mock_response

        results = []
        async for chunk in proc.generate_content_stream("sys", "user"):
            results.append(chunk)

        # Should get the valid content chunk + final chunk
        assert any(r.text == "ok" for r in results)
        assert results[-1].is_final is True

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    async def test_streaming_supports_streaming_flag(self, mock_requests):
        # This provider advertises streaming support.
        proc = _make_processor(mock_requests)
        assert proc.supports_streaming() is True
diff --git a/tests/unit/test_text_completion/test_rate_limit_contract.py b/tests/unit/test_text_completion/test_rate_limit_contract.py
new file mode 100644
index 00000000..c9df217b
--- /dev/null
+++ b/tests/unit/test_text_completion/test_rate_limit_contract.py
@@ -0,0 +1,140 @@
+"""
+Cross-provider rate limit contract tests: verify that every LLM provider
+that handles rate limits converts its provider-specific exception to
+TooManyRequests consistently.
+
+Also tests the client-side error translation in the base client.
+"""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+from unittest import IsolatedAsyncioTestCase
+
+from trustgraph.exceptions import TooManyRequests
+
+
class TestAzureServerless429(IsolatedAsyncioTestCase):
    """Azure serverless endpoint: HTTP 429 → TooManyRequests"""

    async def test_http_429_raises_too_many_requests(self):
        from trustgraph.model.text_completion.azure.llm import Processor

        # Same patches as the decorator form, expressed as context managers.
        with patch('trustgraph.model.text_completion.azure.llm.requests') as mock_requests, \
             patch('trustgraph.base.async_processor.AsyncProcessor.__init__',
                   return_value=None), \
             patch('trustgraph.base.llm_service.LlmService.__init__',
                   return_value=None):
            proc = Processor(
                endpoint="https://test.azure.com/v1/chat",
                token="t", concurrency=1, taskgroup=AsyncMock(), id="t",
            )
            rate_limited = MagicMock()
            rate_limited.status_code = 429
            mock_requests.post.return_value = rate_limited

            with pytest.raises(TooManyRequests):
                await proc.generate_content("sys", "prompt")
+
+
class TestAzureOpenAIRateLimit(IsolatedAsyncioTestCase):
    """Azure OpenAI: openai.RateLimitError → TooManyRequests"""

    # Note: @patch decorators inject mocks bottom-up, so the argument order
    # is (_llm, _async, mock_cls).
    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
    @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
    async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cls):
        from openai import RateLimitError
        from trustgraph.model.text_completion.azure_openai.llm import Processor
        mock_client = MagicMock()
        mock_cls.return_value = mock_client
        proc = Processor(
            endpoint="https://test.openai.azure.com/", token="t",
            model="gpt-4", concurrency=1, taskgroup=AsyncMock(), id="t",
        )
        # The real RateLimitError is raised so the Processor's except clause
        # can match it by type.
        mock_client.chat.completions.create.side_effect = RateLimitError(
            "rate limited", response=MagicMock(), body=None
        )

        with pytest.raises(TooManyRequests):
            await proc.generate_content("sys", "prompt")
+
+
class TestOpenAIRateLimit(IsolatedAsyncioTestCase):
    """OpenAI: openai.RateLimitError → TooManyRequests"""

    # Note: @patch decorators inject mocks bottom-up, so the argument order
    # is (_llm, _async, mock_cls).
    @patch('trustgraph.model.text_completion.openai.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
    @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
    async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cls):
        from openai import RateLimitError
        from trustgraph.model.text_completion.openai.llm import Processor
        mock_client = MagicMock()
        mock_cls.return_value = mock_client
        proc = Processor(
            api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
        )
        # The real RateLimitError is raised so the Processor's except clause
        # can match it by type.
        mock_client.chat.completions.create.side_effect = RateLimitError(
            "rate limited", response=MagicMock(), body=None
        )

        with pytest.raises(TooManyRequests):
            await proc.generate_content("sys", "prompt")
+
+
class TestClaudeRateLimit(IsolatedAsyncioTestCase):
    """Claude/Anthropic: anthropic.RateLimitError → TooManyRequests"""

    # The whole anthropic module is mocked here, so a stand-in
    # RateLimitError class is installed on the mock before the call.
    @patch('trustgraph.model.text_completion.claude.llm.anthropic')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
    @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
    async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_anthropic):
        from trustgraph.model.text_completion.claude.llm import Processor

        mock_client = MagicMock()
        mock_anthropic.Anthropic.return_value = mock_client

        proc = Processor(
            api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
        )

        # Must be a real Exception subclass (not a MagicMock attribute) so
        # the Processor's `except anthropic.RateLimitError` can catch it.
        mock_anthropic.RateLimitError = type("RateLimitError", (Exception,), {})
        mock_client.messages.create.side_effect = mock_anthropic.RateLimitError(
            "rate limited"
        )

        with pytest.raises(TooManyRequests):
            await proc.generate_content("sys", "prompt")
+
+
class TestCohereRateLimit(IsolatedAsyncioTestCase):
    """Cohere: cohere.TooManyRequestsError → TooManyRequests"""

    # The whole cohere module is mocked here, so a stand-in
    # TooManyRequestsError class is installed on the mock before the call.
    @patch('trustgraph.model.text_completion.cohere.llm.cohere')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
    @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
    async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cohere):
        from trustgraph.model.text_completion.cohere.llm import Processor

        mock_client = MagicMock()
        mock_cohere.Client.return_value = mock_client

        proc = Processor(
            api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
        )

        # Must be a real Exception subclass so the Processor's except clause
        # for cohere.TooManyRequestsError can catch it.
        mock_cohere.TooManyRequestsError = type(
            "TooManyRequestsError", (Exception,), {}
        )
        mock_client.chat.side_effect = mock_cohere.TooManyRequestsError(
            "rate limited"
        )

        with pytest.raises(TooManyRequests):
            await proc.generate_content("sys", "prompt")
+
+
class TestClientSideRateLimitTranslation:
    """Client base class: error type 'too-many-requests' → TooManyRequests"""

    def test_error_type_mapping(self):
        """The wire format error type string is 'too-many-requests'."""
        from trustgraph.schema import Error

        wire_error = Error(type="too-many-requests", message="slow down")
        expected_type = "too-many-requests"
        assert wire_error.type == expected_type