Updated test suite for explainability & provenance (#696)

* Provenance tests

* Embeddings tests

* Test librarian

* Test triples stream

* Test concurrency

* Entity centric graph writes

* Agent tool service tests

* Structured data tests

* RDF tests

* Addition LLM tests

* Reliability tests
This commit is contained in:
cybermaggedon 2026-03-13 14:27:42 +00:00 committed by GitHub
parent e6623fc915
commit 29b4300808
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 8799 additions and 0 deletions

View file

@ -0,0 +1,624 @@
"""
Tests for tool service lifecycle, invoke contract, streaming responses,
multi-tenancy, and error propagation.
Tests the actual DynamicToolService, ToolService, and ToolServiceClient
classes rather than plain dicts.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.schema import (
ToolServiceRequest, ToolServiceResponse, Error,
ToolRequest, ToolResponse,
)
from trustgraph.exceptions import TooManyRequests
# ---------------------------------------------------------------------------
# DynamicToolService tests
# ---------------------------------------------------------------------------
class TestDynamicToolServiceInvokeContract:
@pytest.mark.asyncio
async def test_base_invoke_raises_not_implemented(self):
"""Base class invoke() should raise NotImplementedError."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
with pytest.raises(NotImplementedError):
await svc.invoke("user", {}, {})
@pytest.mark.asyncio
async def test_on_request_calls_invoke_with_parsed_args(self):
"""on_request should JSON-parse config/arguments and pass to invoke."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
calls = []
async def tracking_invoke(user, config, arguments):
calls.append({"user": user, "config": config, "arguments": arguments})
return "ok"
svc.invoke = tracking_invoke
# Ensure the class-level metric exists
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(
user="alice",
config='{"style": "pun"}',
arguments='{"topic": "cats"}',
)
msg.properties.return_value = {"id": "req-1"}
await svc.on_request(msg, MagicMock(), None)
assert len(calls) == 1
assert calls[0]["user"] == "alice"
assert calls[0]["config"] == {"style": "pun"}
assert calls[0]["arguments"] == {"topic": "cats"}
@pytest.mark.asyncio
async def test_on_request_empty_user_defaults_to_trustgraph(self):
"""Empty user field should default to 'trustgraph'."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
received_user = None
async def capture_invoke(user, config, arguments):
nonlocal received_user
received_user = user
return "ok"
svc.invoke = capture_invoke
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="", config="", arguments="")
msg.properties.return_value = {"id": "req-2"}
await svc.on_request(msg, MagicMock(), None)
assert received_user == "trustgraph"
@pytest.mark.asyncio
async def test_on_request_string_response_sent_directly(self):
"""String return from invoke → response field is the string."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
async def string_invoke(user, config, arguments):
return "hello world"
svc.invoke = string_invoke
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.properties.return_value = {"id": "r1"}
await svc.on_request(msg, MagicMock(), None)
sent = svc.producer.send.call_args[0][0]
assert isinstance(sent, ToolServiceResponse)
assert sent.response == "hello world"
assert sent.end_of_stream is True
assert sent.error is None
@pytest.mark.asyncio
async def test_on_request_dict_response_json_encoded(self):
"""Dict return from invoke → response field is JSON-encoded."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
async def dict_invoke(user, config, arguments):
return {"result": 42}
svc.invoke = dict_invoke
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.properties.return_value = {"id": "r2"}
await svc.on_request(msg, MagicMock(), None)
sent = svc.producer.send.call_args[0][0]
assert json.loads(sent.response) == {"result": 42}
@pytest.mark.asyncio
async def test_on_request_error_sends_error_response(self):
"""Exception in invoke → error response sent."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
async def failing_invoke(user, config, arguments):
raise ValueError("bad input")
svc.invoke = failing_invoke
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.properties.return_value = {"id": "r3"}
await svc.on_request(msg, MagicMock(), None)
sent = svc.producer.send.call_args[0][0]
assert sent.error is not None
assert sent.error.type == "tool-service-error"
assert "bad input" in sent.error.message
assert sent.response == ""
@pytest.mark.asyncio
async def test_on_request_too_many_requests_propagates(self):
"""TooManyRequests should propagate (not caught as error response)."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
async def rate_limited_invoke(user, config, arguments):
raise TooManyRequests("rate limited")
svc.invoke = rate_limited_invoke
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.properties.return_value = {"id": "r4"}
with pytest.raises(TooManyRequests):
await svc.on_request(msg, MagicMock(), None)
@pytest.mark.asyncio
async def test_on_request_preserves_message_id(self):
"""Response should include the original message id in properties."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
async def ok_invoke(user, config, arguments):
return "ok"
svc.invoke = ok_invoke
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.properties.return_value = {"id": "unique-42"}
await svc.on_request(msg, MagicMock(), None)
props = svc.producer.send.call_args[1]["properties"]
assert props["id"] == "unique-42"
# ---------------------------------------------------------------------------
# ToolService (flow-based) tests
# ---------------------------------------------------------------------------
class TestToolServiceOnRequest:
@pytest.mark.asyncio
async def test_string_response_sent_as_text(self):
"""String return from invoke_tool → ToolResponse.text is set."""
from trustgraph.base.tool_service import ToolService
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def mock_invoke(name, params):
return "tool result"
svc.invoke_tool = mock_invoke
if not hasattr(ToolService, "tool_invocation_metric"):
ToolService.tool_invocation_metric = MagicMock()
mock_response_pub = AsyncMock()
flow = MagicMock()
flow.name = "test-flow"
def flow_callable(name):
if name == "response":
return mock_response_pub
return MagicMock()
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}')
msg.properties.return_value = {"id": "t1"}
await svc.on_request(msg, MagicMock(), flow_callable)
sent = mock_response_pub.send.call_args[0][0]
assert isinstance(sent, ToolResponse)
assert sent.text == "tool result"
assert sent.object is None
@pytest.mark.asyncio
async def test_dict_response_sent_as_json_object(self):
"""Dict return from invoke_tool → ToolResponse.object is JSON."""
from trustgraph.base.tool_service import ToolService
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def mock_invoke(name, params):
return {"data": [1, 2, 3]}
svc.invoke_tool = mock_invoke
if not hasattr(ToolService, "tool_invocation_metric"):
ToolService.tool_invocation_metric = MagicMock()
mock_response_pub = AsyncMock()
flow = MagicMock()
def flow_callable(name):
if name == "response":
return mock_response_pub
return MagicMock()
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
msg.properties.return_value = {"id": "t2"}
await svc.on_request(msg, MagicMock(), flow_callable)
sent = mock_response_pub.send.call_args[0][0]
assert sent.text is None
assert json.loads(sent.object) == {"data": [1, 2, 3]}
@pytest.mark.asyncio
async def test_error_sends_error_response(self):
"""Exception in invoke_tool → error response via flow producer."""
from trustgraph.base.tool_service import ToolService
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def failing_invoke(name, params):
raise RuntimeError("tool broke")
svc.invoke_tool = failing_invoke
mock_response_pub = AsyncMock()
flow = MagicMock()
def flow_callable(name):
return MagicMock()
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
msg.properties.return_value = {"id": "t3"}
await svc.on_request(msg, MagicMock(), flow_callable)
sent = mock_response_pub.send.call_args[0][0]
assert sent.error is not None
assert sent.error.type == "tool-error"
assert "tool broke" in sent.error.message
@pytest.mark.asyncio
async def test_too_many_requests_propagates(self):
"""TooManyRequests should propagate from ToolService.on_request."""
from trustgraph.base.tool_service import ToolService
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def rate_limited(name, params):
raise TooManyRequests("slow down")
svc.invoke_tool = rate_limited
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
msg.properties.return_value = {"id": "t4"}
flow = MagicMock()
flow.producer = {"response": AsyncMock()}
flow.name = "test-flow"
with pytest.raises(TooManyRequests):
await svc.on_request(msg, MagicMock(), flow)
@pytest.mark.asyncio
async def test_parameters_json_parsed(self):
"""Parameters should be JSON-parsed before passing to invoke_tool."""
from trustgraph.base.tool_service import ToolService
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
received = {}
async def capture_invoke(name, params):
received["name"] = name
received["params"] = params
return "ok"
svc.invoke_tool = capture_invoke
if not hasattr(ToolService, "tool_invocation_metric"):
ToolService.tool_invocation_metric = MagicMock()
mock_pub = AsyncMock()
flow = lambda name: mock_pub
flow.producer = {"response": mock_pub}
flow.name = "f"
msg = MagicMock()
msg.value.return_value = ToolRequest(
name="search",
parameters='{"query": "test", "limit": 10}',
)
msg.properties.return_value = {"id": "t5"}
await svc.on_request(msg, MagicMock(), flow)
assert received["name"] == "search"
assert received["params"] == {"query": "test", "limit": 10}
# ---------------------------------------------------------------------------
# ToolServiceClient tests
# ---------------------------------------------------------------------------
class TestToolServiceClientCall:
@pytest.mark.asyncio
async def test_call_sends_request_and_returns_response(self):
"""call() should send ToolServiceRequest and return response string."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
client.request = AsyncMock(return_value=ToolServiceResponse(
error=None, response="joke result", end_of_stream=True,
))
result = await client.call(
user="alice",
config={"style": "pun"},
arguments={"topic": "cats"},
)
assert result == "joke result"
req = client.request.call_args[0][0]
assert isinstance(req, ToolServiceRequest)
assert req.user == "alice"
assert json.loads(req.config) == {"style": "pun"}
assert json.loads(req.arguments) == {"topic": "cats"}
@pytest.mark.asyncio
async def test_call_raises_on_error(self):
"""call() should raise RuntimeError when response has error."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
client.request = AsyncMock(return_value=ToolServiceResponse(
error=Error(type="tool-service-error", message="service down"),
response="",
))
with pytest.raises(RuntimeError, match="service down"):
await client.call(user="u", config={}, arguments={})
@pytest.mark.asyncio
async def test_call_empty_config_sends_empty_json(self):
"""Empty config/arguments should be sent as '{}'."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
client.request = AsyncMock(return_value=ToolServiceResponse(
error=None, response="ok",
))
await client.call(user="u", config=None, arguments=None)
req = client.request.call_args[0][0]
assert req.config == "{}"
assert req.arguments == "{}"
@pytest.mark.asyncio
async def test_call_passes_timeout(self):
"""call() should forward timeout to underlying request."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
client.request = AsyncMock(return_value=ToolServiceResponse(
error=None, response="ok",
))
await client.call(user="u", config={}, arguments={}, timeout=30)
_, kwargs = client.request.call_args
assert kwargs["timeout"] == 30
class TestToolServiceClientStreaming:
@pytest.mark.asyncio
async def test_call_streaming_collects_chunks(self):
"""call_streaming should accumulate chunks and return full result."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
# Simulate streaming: request() calls recipient with each chunk
chunks = [
ToolServiceResponse(error=None, response="chunk1", end_of_stream=False),
ToolServiceResponse(error=None, response="chunk2", end_of_stream=True),
]
async def mock_request(req, timeout=600, recipient=None):
for chunk in chunks:
done = await recipient(chunk)
if done:
break
client.request = mock_request
received = []
async def callback(text):
received.append(text)
result = await client.call_streaming(
user="u", config={}, arguments={}, callback=callback,
)
assert result == "chunk1chunk2"
assert received == ["chunk1", "chunk2"]
@pytest.mark.asyncio
async def test_call_streaming_raises_on_error(self):
"""call_streaming should raise RuntimeError on error chunk."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
async def mock_request(req, timeout=600, recipient=None):
error_resp = ToolServiceResponse(
error=Error(type="tool-service-error", message="stream failed"),
response="",
end_of_stream=True,
)
await recipient(error_resp)
client.request = mock_request
with pytest.raises(RuntimeError, match="stream failed"):
await client.call_streaming(
user="u", config={}, arguments={},
callback=AsyncMock(),
)
@pytest.mark.asyncio
async def test_call_streaming_skips_empty_response(self):
"""Empty response chunks should not be added to result."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
chunks = [
ToolServiceResponse(error=None, response="", end_of_stream=False),
ToolServiceResponse(error=None, response="data", end_of_stream=True),
]
async def mock_request(req, timeout=600, recipient=None):
for chunk in chunks:
done = await recipient(chunk)
if done:
break
client.request = mock_request
received = []
async def callback(text):
received.append(text)
result = await client.call_streaming(
user="u", config={}, arguments={}, callback=callback,
)
# Empty response is falsy, so callback shouldn't be called for it
assert result == "data"
assert received == ["data"]
# ---------------------------------------------------------------------------
# Multi-tenancy
# ---------------------------------------------------------------------------
class TestMultiTenancy:
@pytest.mark.asyncio
async def test_user_propagated_to_invoke(self):
"""User from request should reach the invoke method."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test"
svc.producer = AsyncMock()
users_seen = []
async def tracking(user, config, arguments):
users_seen.append(user)
return "ok"
svc.invoke = tracking
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
for u in ["tenant-a", "tenant-b", "tenant-c"]:
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(
user=u, config="{}", arguments="{}",
)
msg.properties.return_value = {"id": f"req-{u}"}
await svc.on_request(msg, MagicMock(), None)
assert users_seen == ["tenant-a", "tenant-b", "tenant-c"]
@pytest.mark.asyncio
async def test_client_sends_user_in_request(self):
"""ToolServiceClient.call should include user in request."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
client.request = AsyncMock(return_value=ToolServiceResponse(
error=None, response="ok",
))
await client.call(user="isolated-tenant", config={}, arguments={})
req = client.request.call_args[0][0]
assert req.user == "isolated-tenant"

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,286 @@
"""
Tests for Consumer concurrency: TaskGroup-based concurrent message processing,
rate-limit retry with backpressure, and message acknowledgement.
"""
import asyncio
import time
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.base.consumer import Consumer
from trustgraph.exceptions import TooManyRequests
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_consumer(
concurrency=1,
handler=None,
rate_limit_retry_time=0.01,
rate_limit_timeout=1,
):
"""Create a Consumer with mocked infrastructure."""
taskgroup = MagicMock()
flow = MagicMock()
backend = MagicMock()
schema = MagicMock()
handler = handler or AsyncMock()
consumer = Consumer(
taskgroup=taskgroup,
flow=flow,
backend=backend,
topic="test-topic",
subscriber="test-sub",
schema=schema,
handler=handler,
rate_limit_retry_time=rate_limit_retry_time,
rate_limit_timeout=rate_limit_timeout,
concurrency=concurrency,
)
return consumer
def _make_msg():
"""Create a mock Pulsar message."""
return MagicMock()
# ---------------------------------------------------------------------------
# Concurrency configuration tests
# ---------------------------------------------------------------------------
class TestConcurrencyConfiguration:
def test_default_concurrency_is_1(self):
consumer = _make_consumer()
assert consumer.concurrency == 1
def test_custom_concurrency(self):
consumer = _make_consumer(concurrency=10)
assert consumer.concurrency == 10
def test_concurrency_stored(self):
for n in [1, 5, 20, 100]:
consumer = _make_consumer(concurrency=n)
assert consumer.concurrency == n
class TestTaskGroupConcurrency:
@pytest.mark.asyncio
async def test_creates_n_concurrent_tasks(self):
"""consumer_run should create exactly N concurrent consume_from_queue tasks."""
concurrency = 5
consumer = _make_consumer(concurrency=concurrency)
# Track how many consume_from_queue calls are made
call_count = 0
original_running = True
async def mock_consume():
nonlocal call_count
call_count += 1
# Wait a bit to let all tasks start, then signal stop
await asyncio.sleep(0.05)
consumer.running = False
consumer.consume_from_queue = mock_consume
# Mock the backend.create_consumer
consumer.backend.create_consumer = MagicMock(return_value=MagicMock())
# Run consumer_run - it will create TaskGroup with N tasks
consumer.running = True
await consumer.consumer_run()
assert call_count == concurrency
@pytest.mark.asyncio
async def test_single_concurrency_creates_one_task(self):
"""With concurrency=1, only one consume_from_queue task is created."""
consumer = _make_consumer(concurrency=1)
call_count = 0
async def mock_consume():
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
consumer.running = False
consumer.consume_from_queue = mock_consume
consumer.backend.create_consumer = MagicMock(return_value=MagicMock())
consumer.running = True
await consumer.consumer_run()
assert call_count == 1
# ---------------------------------------------------------------------------
# Rate-limit retry tests
# ---------------------------------------------------------------------------
class TestRateLimitRetry:
@pytest.mark.asyncio
async def test_rate_limit_retries_then_succeeds(self):
"""TooManyRequests should cause retry, then succeed on next attempt."""
call_count = 0
async def handler_with_retry(msg, consumer_ref, flow):
nonlocal call_count
call_count += 1
if call_count == 1:
raise TooManyRequests("rate limited")
# Second call succeeds
consumer = _make_consumer(
handler=handler_with_retry,
rate_limit_retry_time=0.01,
)
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
assert call_count == 2
consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
@pytest.mark.asyncio
async def test_rate_limit_timeout_negative_acks(self):
"""If rate limit retries exhaust the timeout, message is negative-acked."""
async def always_rate_limited(msg, consumer_ref, flow):
raise TooManyRequests("rate limited")
consumer = _make_consumer(
handler=always_rate_limited,
rate_limit_retry_time=0.01,
rate_limit_timeout=0.05,
)
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
consumer.consumer.negative_acknowledge.assert_called_with(mock_msg)
consumer.consumer.acknowledge.assert_not_called()
@pytest.mark.asyncio
async def test_non_rate_limit_error_negative_acks_immediately(self):
"""Non-TooManyRequests errors should negative-ack immediately (no retry)."""
call_count = 0
async def failing_handler(msg, consumer_ref, flow):
nonlocal call_count
call_count += 1
raise ValueError("bad data")
consumer = _make_consumer(handler=failing_handler)
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
assert call_count == 1
consumer.consumer.negative_acknowledge.assert_called_once_with(mock_msg)
@pytest.mark.asyncio
async def test_successful_message_acknowledged(self):
"""Successfully processed messages are acknowledged."""
consumer = _make_consumer(handler=AsyncMock())
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
# ---------------------------------------------------------------------------
# Metrics integration
# ---------------------------------------------------------------------------
class TestMetricsIntegration:
@pytest.mark.asyncio
async def test_success_metric_on_success(self):
consumer = _make_consumer(handler=AsyncMock())
mock_msg = _make_msg()
consumer.consumer = MagicMock()
mock_metrics = MagicMock()
mock_metrics.record_time.return_value.__enter__ = MagicMock()
mock_metrics.record_time.return_value.__exit__ = MagicMock()
consumer.metrics = mock_metrics
await consumer.handle_one_from_queue(mock_msg)
mock_metrics.process.assert_called_once_with("success")
@pytest.mark.asyncio
async def test_error_metric_on_failure(self):
async def failing(msg, c, f):
raise ValueError("fail")
consumer = _make_consumer(handler=failing)
mock_msg = _make_msg()
consumer.consumer = MagicMock()
mock_metrics = MagicMock()
consumer.metrics = mock_metrics
await consumer.handle_one_from_queue(mock_msg)
mock_metrics.process.assert_called_once_with("error")
@pytest.mark.asyncio
async def test_rate_limit_metric_on_too_many_requests(self):
call_count = 0
async def handler(msg, c, f):
nonlocal call_count
call_count += 1
if call_count == 1:
raise TooManyRequests("limited")
consumer = _make_consumer(
handler=handler,
rate_limit_retry_time=0.01,
)
mock_msg = _make_msg()
consumer.consumer = MagicMock()
mock_metrics = MagicMock()
mock_metrics.record_time.return_value.__enter__ = MagicMock()
mock_metrics.record_time.return_value.__exit__ = MagicMock(return_value=False)
consumer.metrics = mock_metrics
await consumer.handle_one_from_queue(mock_msg)
mock_metrics.rate_limit.assert_called_once()
# ---------------------------------------------------------------------------
# Stop / running flag
# ---------------------------------------------------------------------------
class TestStopBehaviour:
@pytest.mark.asyncio
async def test_stop_sets_running_false(self):
consumer = _make_consumer()
consumer.running = True
await consumer.stop()
assert consumer.running is False
def test_initial_running_state(self):
consumer = _make_consumer()
assert consumer.running is True

View file

@ -0,0 +1,136 @@
"""
Tests for MessageDispatcher semaphore-based concurrency enforcement.
Verifies that the dispatcher limits concurrent message processing to
max_workers via asyncio.Semaphore.
"""
import asyncio
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.rev_gateway.dispatcher import MessageDispatcher
class TestSemaphoreEnforcement:
@pytest.mark.asyncio
async def test_semaphore_limits_concurrent_processing(self):
"""Only max_workers messages should be processed concurrently."""
max_workers = 2
dispatcher = MessageDispatcher(max_workers=max_workers)
concurrent_count = 0
max_concurrent = 0
processing_event = asyncio.Event()
async def slow_process(message):
nonlocal concurrent_count, max_concurrent
concurrent_count += 1
max_concurrent = max(max_concurrent, concurrent_count)
await asyncio.sleep(0.05)
concurrent_count -= 1
return {"id": message.get("id"), "response": {"ok": True}}
dispatcher._process_message = slow_process
# Launch more tasks than max_workers
messages = [
{"id": f"msg-{i}", "service": "test", "request": {}}
for i in range(5)
]
tasks = [
asyncio.create_task(dispatcher.handle_message(m))
for m in messages
]
await asyncio.gather(*tasks)
# At no point should more than max_workers have been active
assert max_concurrent <= max_workers
@pytest.mark.asyncio
async def test_semaphore_value_matches_max_workers(self):
for n in [1, 5, 20]:
dispatcher = MessageDispatcher(max_workers=n)
assert dispatcher.semaphore._value == n
@pytest.mark.asyncio
async def test_active_tasks_tracked(self):
"""Active tasks should be added/removed during processing."""
dispatcher = MessageDispatcher(max_workers=5)
task_was_tracked = False
original_process = dispatcher._process_message
async def tracking_process(message):
nonlocal task_was_tracked
# During processing, our task should be in active_tasks
if len(dispatcher.active_tasks) > 0:
task_was_tracked = True
return {"id": message.get("id"), "response": {"ok": True}}
dispatcher._process_message = tracking_process
await dispatcher.handle_message(
{"id": "test", "service": "test", "request": {}}
)
assert task_was_tracked
# After completion, task should be discarded
assert len(dispatcher.active_tasks) == 0
@pytest.mark.asyncio
async def test_semaphore_released_on_error(self):
"""Semaphore should be released even if processing raises."""
dispatcher = MessageDispatcher(max_workers=2)
async def failing_process(message):
raise RuntimeError("process failed")
dispatcher._process_message = failing_process
# Should not deadlock — semaphore must be released on error
with pytest.raises(RuntimeError):
await dispatcher.handle_message(
{"id": "test", "service": "test", "request": {}}
)
# Semaphore should be back at max
assert dispatcher.semaphore._value == 2
@pytest.mark.asyncio
async def test_single_worker_serializes_processing(self):
"""With max_workers=1, messages are processed one at a time."""
dispatcher = MessageDispatcher(max_workers=1)
order = []
async def ordered_process(message):
msg_id = message["id"]
order.append(f"start-{msg_id}")
await asyncio.sleep(0.02)
order.append(f"end-{msg_id}")
return {"id": msg_id, "response": {"ok": True}}
dispatcher._process_message = ordered_process
messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)]
tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages]
await asyncio.gather(*tasks)
# With semaphore=1, each message should complete before next starts
# Check that no two "start" entries appear without an intervening "end"
active = 0
max_active = 0
for event in order:
if event.startswith("start"):
active += 1
max_active = max(max_active, active)
elif event.startswith("end"):
active -= 1
assert max_active == 1

View file

@ -0,0 +1,268 @@
"""
Tests for Graph RAG concurrent query execution.
Covers: execute_batch_triple_queries concurrent task spawning,
exception handling in gather, and result aggregation.
"""
import asyncio
import pytest
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_query(
triples_client=None,
entity_limit=50,
triple_limit=30,
max_subgraph_size=1000,
max_path_length=2,
):
"""Create a Query object with mocked rag dependencies."""
rag = MagicMock()
rag.triples_client = triples_client or AsyncMock()
rag.label_cache = LRUCacheWithTTL()
query = Query(
rag=rag,
user="test-user",
collection="test-collection",
verbose=False,
entity_limit=entity_limit,
triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
)
return query
def _make_triple(s, p, o):
"""Create a simple mock triple."""
t = MagicMock()
t.s = s
t.p = p
t.o = o
return t
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestBatchTripleQueries:
@pytest.mark.asyncio
async def test_three_queries_per_entity(self):
"""Each entity should generate 3 concurrent queries (s, p, o positions)."""
client = AsyncMock()
client.query_stream = AsyncMock(return_value=[])
query = _make_query(triples_client=client)
entities = ["entity-1"]
await query.execute_batch_triple_queries(entities, limit_per_entity=10)
assert client.query_stream.call_count == 3
@pytest.mark.asyncio
async def test_multiple_entities_multiply_queries(self):
"""N entities should produce N*3 concurrent queries."""
client = AsyncMock()
client.query_stream = AsyncMock(return_value=[])
query = _make_query(triples_client=client)
entities = ["e1", "e2", "e3"]
await query.execute_batch_triple_queries(entities, limit_per_entity=10)
assert client.query_stream.call_count == 9 # 3 * 3
@pytest.mark.asyncio
async def test_queries_executed_concurrently(self):
"""All queries should run concurrently via asyncio.gather."""
concurrent_count = 0
max_concurrent = 0
async def tracking_query(**kwargs):
nonlocal concurrent_count, max_concurrent
concurrent_count += 1
max_concurrent = max(max_concurrent, concurrent_count)
await asyncio.sleep(0.02)
concurrent_count -= 1
return []
client = AsyncMock()
client.query_stream = tracking_query
query = _make_query(triples_client=client)
entities = ["e1", "e2", "e3"]
await query.execute_batch_triple_queries(entities, limit_per_entity=5)
# All 9 queries should have run concurrently
assert max_concurrent == 9
@pytest.mark.asyncio
async def test_results_aggregated(self):
"""Results from all queries should be combined into a single list."""
triple_a = _make_triple("a", "p", "b")
triple_b = _make_triple("c", "p", "d")
call_count = 0
async def alternating_results(**kwargs):
nonlocal call_count
call_count += 1
if call_count % 2 == 0:
return [triple_a]
return [triple_b]
client = AsyncMock()
client.query_stream = alternating_results
query = _make_query(triples_client=client)
result = await query.execute_batch_triple_queries(
["e1"], limit_per_entity=10
)
# 3 queries, alternating results
assert len(result) == 3
@pytest.mark.asyncio
async def test_exception_in_one_query_does_not_block_others(self):
"""If one query raises, other results are still collected."""
good_triple = _make_triple("a", "p", "b")
call_count = 0
async def mixed_results(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 2:
raise RuntimeError("query failed")
return [good_triple]
client = AsyncMock()
client.query_stream = mixed_results
query = _make_query(triples_client=client)
result = await query.execute_batch_triple_queries(
["e1"], limit_per_entity=10
)
# 3 queries: 2 succeed, 1 fails → 2 triples
assert len(result) == 2
@pytest.mark.asyncio
async def test_none_results_filtered(self):
"""None results from queries should be filtered out."""
call_count = 0
async def sometimes_none(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return None
return [_make_triple("a", "p", "b")]
client = AsyncMock()
client.query_stream = sometimes_none
query = _make_query(triples_client=client)
result = await query.execute_batch_triple_queries(
["e1"], limit_per_entity=10
)
# 3 queries: 1 returns None, 2 return triples
assert len(result) == 2
@pytest.mark.asyncio
async def test_empty_entities_no_queries(self):
"""Empty entity list should produce no queries."""
client = AsyncMock()
client.query_stream = AsyncMock(return_value=[])
query = _make_query(triples_client=client)
result = await query.execute_batch_triple_queries([], limit_per_entity=10)
assert result == []
client.query_stream.assert_not_called()
@pytest.mark.asyncio
async def test_query_params_correct(self):
"""Each query should use correct s/p/o positions and params."""
client = AsyncMock()
client.query_stream = AsyncMock(return_value=[])
query = _make_query(triples_client=client)
entities = ["ent-1"]
await query.execute_batch_triple_queries(entities, limit_per_entity=15)
calls = client.query_stream.call_args_list
assert len(calls) == 3
# First call: s=entity, p=None, o=None
assert calls[0].kwargs["s"] == "ent-1"
assert calls[0].kwargs["p"] is None
assert calls[0].kwargs["o"] is None
assert calls[0].kwargs["limit"] == 15
assert calls[0].kwargs["user"] == "test-user"
assert calls[0].kwargs["collection"] == "test-collection"
assert calls[0].kwargs["batch_size"] == 20
# Second call: s=None, p=entity, o=None
assert calls[1].kwargs["s"] is None
assert calls[1].kwargs["p"] == "ent-1"
assert calls[1].kwargs["o"] is None
# Third call: s=None, p=None, o=entity
assert calls[2].kwargs["s"] is None
assert calls[2].kwargs["p"] is None
assert calls[2].kwargs["o"] == "ent-1"
class TestLRUCacheWithTTL:
def test_put_and_get(self):
cache = LRUCacheWithTTL(max_size=10, ttl=60)
cache.put("key1", "value1")
assert cache.get("key1") == "value1"
def test_get_missing_returns_none(self):
cache = LRUCacheWithTTL()
assert cache.get("nonexistent") is None
def test_max_size_eviction(self):
cache = LRUCacheWithTTL(max_size=2, ttl=60)
cache.put("a", 1)
cache.put("b", 2)
cache.put("c", 3) # Should evict "a"
assert cache.get("a") is None
assert cache.get("b") == 2
assert cache.get("c") == 3
def test_lru_order(self):
cache = LRUCacheWithTTL(max_size=2, ttl=60)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a") # Access "a" — now "b" is LRU
cache.put("c", 3) # Should evict "b"
assert cache.get("a") == 1
assert cache.get("b") is None
assert cache.get("c") == 3
def test_ttl_expiration(self):
cache = LRUCacheWithTTL(max_size=10, ttl=0) # TTL=0 means instant expiry
cache.put("key", "value")
# With TTL=0, any time check > 0 means expired
import time
time.sleep(0.01)
assert cache.get("key") is None
def test_update_existing_key(self):
cache = LRUCacheWithTTL(max_size=10, ttl=60)
cache.put("key", "v1")
cache.put("key", "v2")
assert cache.get("key") == "v2"

View file

@ -0,0 +1,441 @@
"""
Tests for entity-centric KG write amplification, delete collection batching,
in-partition filtering, and term type metadata round-trips.
Complements test_entity_centric_kg.py with deeper verification of the
2-table schema mechanics.
"""
import pytest
from unittest.mock import MagicMock, patch, call
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_cassandra():
"""Provide mocked Cassandra cluster, session, and BatchStatement."""
with patch('trustgraph.direct.cassandra_kg.Cluster') as mock_cls, \
patch('trustgraph.direct.cassandra_kg.BatchStatement') as mock_batch_cls:
mock_cluster = MagicMock()
mock_session = MagicMock()
mock_cluster.connect.return_value = mock_session
mock_cls.return_value = mock_cluster
# Track batch.add calls per batch instance
batches = []
def make_batch():
batch = MagicMock()
batch._adds = []
original_add = batch.add
def tracking_add(stmt, params):
batch._adds.append((stmt, params))
batch.add = tracking_add
batches.append(batch)
return batch
mock_batch_cls.side_effect = make_batch
yield {
"cluster_cls": mock_cls,
"cluster": mock_cluster,
"session": mock_session,
"batch_cls": mock_batch_cls,
"batches": batches,
}
@pytest.fixture
def entity_kg(mock_cassandra):
"""Create an EntityCentricKnowledgeGraph with mocked Cassandra."""
from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph
kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_ks')
return kg, mock_cassandra
# ---------------------------------------------------------------------------
# Write amplification: row count verification
# ---------------------------------------------------------------------------
class TestWriteAmplification:
def test_uri_object_produces_4_entity_rows_plus_collection(self, entity_kg):
"""URI object → S + P + O + G-if-non-default entity rows + 1 collection row."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/Alice',
p='http://ex.org/knows',
o='http://ex.org/Bob',
g='http://ex.org/g1',
otype='u',
)
# Should be exactly one batch
assert len(ctx["batches"]) == 1
batch = ctx["batches"][0]
# 4 entity rows (S, P, O, G) + 1 collection row = 5
assert len(batch._adds) == 5
# Check roles assigned
roles = [params[2] for _, params in batch._adds if len(params) == 10]
assert 'S' in roles
assert 'P' in roles
assert 'O' in roles
assert 'G' in roles
def test_literal_object_produces_3_entity_rows(self, entity_kg):
"""Literal object → S + P entity rows (no O row) + collection row."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/Alice',
p='http://ex.org/name',
o='Alice Smith',
g=None, # default graph
otype='l',
)
batch = ctx["batches"][0]
# S + P entity rows + 1 collection = 3 (no O row for literal, no G for default)
assert len(batch._adds) == 3
roles = [params[2] for _, params in batch._adds if len(params) == 10]
assert 'S' in roles
assert 'P' in roles
assert 'O' not in roles
assert 'G' not in roles
def test_triple_otype_gets_object_entity_row(self, entity_kg):
"""otype='t' (quoted triple) → object gets entity row like URI."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/s',
p='http://ex.org/p',
o='{"s":{},"p":{},"o":{}}',
g=None,
otype='t',
)
batch = ctx["batches"][0]
# S + P + O entity rows + collection = 4 (no G for default graph)
assert len(batch._adds) == 4
roles = [params[2] for _, params in batch._adds if len(params) == 10]
assert 'O' in roles
def test_default_graph_no_g_row(self, entity_kg):
"""Default graph (g=None) → no G entity row."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/s',
p='http://ex.org/p',
o='http://ex.org/o',
g=None,
otype='u',
)
batch = ctx["batches"][0]
# S + P + O entity rows + collection = 4 (no G)
assert len(batch._adds) == 4
roles = [params[2] for _, params in batch._adds if len(params) == 10]
assert 'G' not in roles
def test_non_default_graph_gets_g_row(self, entity_kg):
"""Non-default graph → gets G entity row."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/s',
p='http://ex.org/p',
o='http://ex.org/o',
g='http://ex.org/graph1',
otype='u',
)
batch = ctx["batches"][0]
# S + P + O + G entity rows + collection = 5
assert len(batch._adds) == 5
roles = [params[2] for _, params in batch._adds if len(params) == 10]
assert 'G' in roles
def test_dtype_and_lang_passed_to_all_rows(self, entity_kg):
"""dtype and lang should be stored in every entity row."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/s',
p='http://ex.org/label',
o='thing',
g=None,
otype='l',
dtype='xsd:string',
lang='en',
)
batch = ctx["batches"][0]
# Check entity rows carry dtype and lang
for _, params in batch._adds:
if len(params) == 10:
# Entity row: (collection, entity, role, p, otype, s, o, d, dtype, lang)
assert params[8] == 'xsd:string'
assert params[9] == 'en'
# ---------------------------------------------------------------------------
# In-partition filtering: get_os, get_spo
# ---------------------------------------------------------------------------
class TestInPartitionFiltering:
def test_get_os_filters_by_object(self, entity_kg):
"""get_os should filter results by matching object value."""
kg, ctx = entity_kg
# Simulate rows returned from subject partition (all have same s)
mock_rows = [
MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
d='', otype='u', dtype='', lang='',
s='http://ex.org/Alice'),
MagicMock(p='http://ex.org/likes', o='http://ex.org/Charlie',
d='', otype='u', dtype='', lang='',
s='http://ex.org/Alice'),
]
ctx["session"].execute.return_value = mock_rows
results = kg.get_os('col', 'http://ex.org/Bob', 'http://ex.org/Alice')
# Only the Bob row should pass the filter
assert len(results) == 1
assert results[0].o == 'http://ex.org/Bob'
assert results[0].p == 'http://ex.org/knows'
def test_get_os_returns_empty_when_no_match(self, entity_kg):
"""get_os should return empty list when object doesn't match any row."""
kg, ctx = entity_kg
mock_rows = [
MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
d='', otype='u', dtype='', lang='',
s='http://ex.org/Alice'),
]
ctx["session"].execute.return_value = mock_rows
results = kg.get_os('col', 'http://ex.org/Charlie', 'http://ex.org/Alice')
assert len(results) == 0
def test_get_spo_filters_by_object(self, entity_kg):
"""get_spo should filter results by matching object value."""
kg, ctx = entity_kg
mock_rows = [
MagicMock(o='http://ex.org/Bob', d='', otype='u', dtype='', lang=''),
MagicMock(o='http://ex.org/Charlie', d='', otype='u', dtype='', lang=''),
]
ctx["session"].execute.return_value = mock_rows
results = kg.get_spo(
'col', 'http://ex.org/Alice', 'http://ex.org/knows',
'http://ex.org/Bob',
)
assert len(results) == 1
assert results[0].o == 'http://ex.org/Bob'
def test_get_os_with_graph_filter(self, entity_kg):
"""get_os with specific graph should filter both object and graph."""
kg, ctx = entity_kg
mock_rows = [
MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
d='http://ex.org/g1', otype='u', dtype='', lang='',
s='http://ex.org/Alice'),
MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
d='http://ex.org/g2', otype='u', dtype='', lang='',
s='http://ex.org/Alice'),
]
ctx["session"].execute.return_value = mock_rows
results = kg.get_os(
'col', 'http://ex.org/Bob', 'http://ex.org/Alice',
g='http://ex.org/g1',
)
assert len(results) == 1
assert results[0].g == 'http://ex.org/g1'
# ---------------------------------------------------------------------------
# Delete collection batching
# ---------------------------------------------------------------------------
class TestDeleteCollectionBatching:
def test_extracts_unique_entities_from_quads(self, entity_kg):
"""delete_collection should extract s, p, and URI o as entities."""
kg, ctx = entity_kg
mock_rows = [
MagicMock(d='', s='http://ex.org/A', p='http://ex.org/knows',
o='http://ex.org/B', otype='u', dtype='', lang=''),
MagicMock(d='', s='http://ex.org/A', p='http://ex.org/name',
o='Alice', otype='l', dtype='', lang=''),
]
ctx["session"].execute.return_value = mock_rows
ctx["batches"].clear()
kg.delete_collection('col')
# Unique entities: A, knows, B, name (literal 'Alice' excluded)
# The batches should include entity partition deletes
all_adds = []
for batch in ctx["batches"]:
all_adds.extend(batch._adds)
# We expect entity deletes + collection row deletes + metadata delete
# Just verify the function completes and calls execute
assert ctx["session"].execute.called
def test_literal_objects_not_treated_as_entities(self, entity_kg):
"""Literal objects (otype='l') should not get entity partition deletes."""
kg, ctx = entity_kg
mock_rows = [
MagicMock(d='', s='http://ex.org/A', p='http://ex.org/name',
o='Alice', otype='l', dtype='', lang=''),
]
ctx["session"].execute.return_value = mock_rows
ctx["batches"].clear()
kg.delete_collection('col')
# Entity partition deletes should only include A and name, not Alice
entity_deletes = []
for batch in ctx["batches"]:
for _, params in batch._adds:
if len(params) == 2: # delete_entity_partition takes (collection, entity)
entity_deletes.append(params[1])
assert 'http://ex.org/A' in entity_deletes
assert 'http://ex.org/name' in entity_deletes
assert 'Alice' not in entity_deletes
def test_non_default_graph_treated_as_entity(self, entity_kg):
"""Non-default graphs should get entity partition deletes."""
kg, ctx = entity_kg
mock_rows = [
MagicMock(d='http://ex.org/g1', s='http://ex.org/A',
p='http://ex.org/p', o='http://ex.org/B',
otype='u', dtype='', lang=''),
]
ctx["session"].execute.return_value = mock_rows
ctx["batches"].clear()
kg.delete_collection('col')
entity_deletes = []
for batch in ctx["batches"]:
for _, params in batch._adds:
if len(params) == 2:
entity_deletes.append(params[1])
assert 'http://ex.org/g1' in entity_deletes
def test_empty_collection_delete_completes(self, entity_kg):
"""Deleting an empty collection should not error."""
kg, ctx = entity_kg
ctx["session"].execute.return_value = []
ctx["batches"].clear()
# Should not raise
kg.delete_collection('empty-col')
# ---------------------------------------------------------------------------
# Term type metadata round-trip
# ---------------------------------------------------------------------------
class TestTermTypeMetadata:
def test_query_results_include_otype(self, entity_kg):
"""Query results should include otype from Cassandra rows."""
kg, ctx = entity_kg
from trustgraph.direct.cassandra_kg import QuadResult
mock_rows = [
MagicMock(p='http://ex.org/name', o='Alice',
d='', otype='l', dtype='xsd:string', lang='en',
s='http://ex.org/Alice'),
]
ctx["session"].execute.return_value = mock_rows
results = kg.get_s('col', 'http://ex.org/Alice')
assert len(results) == 1
assert results[0].otype == 'l'
assert results[0].dtype == 'xsd:string'
assert results[0].lang == 'en'
def test_auto_detect_otype_uri(self, entity_kg):
"""Auto-detect should classify http:// as URI."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/s',
p='http://ex.org/p',
o='http://ex.org/o',
)
batch = ctx["batches"][0]
# Check otype in entity rows (position 4)
for _, params in batch._adds:
if len(params) == 10:
assert params[4] == 'u'
def test_auto_detect_otype_literal(self, entity_kg):
"""Auto-detect should classify non-http:// as literal."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/s',
p='http://ex.org/p',
o='plain text',
)
batch = ctx["batches"][0]
for _, params in batch._adds:
if len(params) == 10:
assert params[4] == 'l'

View file

@ -0,0 +1,164 @@
"""
Tests for document embeddings processor single-chunk embedding via batch API.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.embeddings.document_embeddings.embeddings import Processor
from trustgraph.schema import (
Chunk, DocumentEmbeddings, ChunkEmbeddings,
EmbeddingsRequest, EmbeddingsResponse, Metadata,
)
@pytest.fixture
def processor():
return Processor(
taskgroup=AsyncMock(),
id="test-doc-embeddings",
)
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1",
user="test", collection="default"):
metadata = Metadata(id=doc_id, user=user, collection=collection)
value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id)
msg = MagicMock()
msg.value.return_value = value
return msg
class TestDocumentEmbeddingsProcessor:
@pytest.mark.asyncio
async def test_sends_single_text_as_list(self, processor):
"""Document embeddings should wrap single chunk in a list for the API."""
msg = _make_chunk_message("test chunk text")
mock_request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[[0.1, 0.2, 0.3]]
))
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(request=mock_request)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
# Should send EmbeddingsRequest with texts=[chunk]
mock_request.assert_called_once()
req = mock_request.call_args[0][0]
assert isinstance(req, EmbeddingsRequest)
assert req.texts == ["test chunk text"]
@pytest.mark.asyncio
async def test_extracts_first_vector(self, processor):
"""Should use vectors[0] from the response."""
msg = _make_chunk_message("chunk")
mock_request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[[1.0, 2.0, 3.0]]
))
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(request=mock_request)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
result = mock_output.send.call_args[0][0]
assert isinstance(result, DocumentEmbeddings)
assert len(result.chunks) == 1
assert result.chunks[0].vector == [1.0, 2.0, 3.0]
@pytest.mark.asyncio
async def test_empty_vectors_response(self, processor):
"""Should handle empty vectors response gracefully."""
msg = _make_chunk_message("chunk")
mock_request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[]
))
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(request=mock_request)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
result = mock_output.send.call_args[0][0]
assert result.chunks[0].vector == []
@pytest.mark.asyncio
async def test_chunk_id_is_document_id(self, processor):
"""ChunkEmbeddings should use document_id as chunk_id."""
msg = _make_chunk_message(doc_id="my-doc-42")
mock_request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[[0.0]]
))
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(request=mock_request)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
result = mock_output.send.call_args[0][0]
assert result.chunks[0].chunk_id == "my-doc-42"
@pytest.mark.asyncio
async def test_metadata_preserved(self, processor):
"""Output should carry the original metadata."""
msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1")
mock_request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[[0.0]]
))
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(request=mock_request)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
result = mock_output.send.call_args[0][0]
assert result.metadata.user == "alice"
assert result.metadata.collection == "reports"
assert result.metadata.id == "d1"
@pytest.mark.asyncio
async def test_error_propagates(self, processor):
"""Embedding errors should propagate for retry."""
msg = _make_chunk_message()
mock_request = AsyncMock(side_effect=RuntimeError("service down"))
def flow(name):
if name == "embeddings-request":
return MagicMock(request=mock_request)
return MagicMock()
with pytest.raises(RuntimeError, match="service down"):
await processor.on_message(msg, MagicMock(), flow)

View file

@ -0,0 +1,109 @@
"""
Tests for EmbeddingsClient the client interface for batch embeddings.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.base.embeddings_client import EmbeddingsClient
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
class TestEmbeddingsClient:
@pytest.mark.asyncio
async def test_embed_sends_request_and_returns_vectors(self):
"""embed() should send an EmbeddingsRequest and return vectors."""
client = EmbeddingsClient.__new__(EmbeddingsClient)
client.request = AsyncMock(return_value=EmbeddingsResponse(
error=None,
vectors=[[0.1, 0.2], [0.3, 0.4]],
))
result = await client.embed(texts=["hello", "world"])
assert result == [[0.1, 0.2], [0.3, 0.4]]
client.request.assert_called_once()
req = client.request.call_args[0][0]
assert isinstance(req, EmbeddingsRequest)
assert req.texts == ["hello", "world"]
@pytest.mark.asyncio
async def test_embed_single_text(self):
"""embed() should work with a single text."""
client = EmbeddingsClient.__new__(EmbeddingsClient)
client.request = AsyncMock(return_value=EmbeddingsResponse(
error=None,
vectors=[[1.0, 2.0, 3.0]],
))
result = await client.embed(texts=["single"])
assert result == [[1.0, 2.0, 3.0]]
@pytest.mark.asyncio
async def test_embed_raises_on_error_response(self):
"""embed() should raise RuntimeError when response contains an error."""
client = EmbeddingsClient.__new__(EmbeddingsClient)
client.request = AsyncMock(return_value=EmbeddingsResponse(
error=Error(type="embeddings-error", message="model not found"),
vectors=[],
))
with pytest.raises(RuntimeError, match="model not found"):
await client.embed(texts=["test"])
@pytest.mark.asyncio
async def test_embed_passes_timeout(self):
"""embed() should pass timeout to the underlying request."""
client = EmbeddingsClient.__new__(EmbeddingsClient)
client.request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[[0.0]],
))
await client.embed(texts=["test"], timeout=60)
_, kwargs = client.request.call_args
assert kwargs["timeout"] == 60
@pytest.mark.asyncio
async def test_embed_default_timeout(self):
"""embed() should use 300s default timeout."""
client = EmbeddingsClient.__new__(EmbeddingsClient)
client.request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[[0.0]],
))
await client.embed(texts=["test"])
_, kwargs = client.request.call_args
assert kwargs["timeout"] == 300
@pytest.mark.asyncio
async def test_embed_empty_texts(self):
"""embed() with empty list should still make the request."""
client = EmbeddingsClient.__new__(EmbeddingsClient)
client.request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[],
))
result = await client.embed(texts=[])
assert result == []
@pytest.mark.asyncio
async def test_embed_large_batch(self):
"""embed() should handle large batches."""
client = EmbeddingsClient.__new__(EmbeddingsClient)
n = 100
vectors = [[float(i)] for i in range(n)]
client.request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=vectors,
))
texts = [f"text {i}" for i in range(n)]
result = await client.embed(texts=texts)
assert len(result) == n
req = client.request.call_args[0][0]
assert len(req.texts) == n

View file

@ -0,0 +1,135 @@
"""
Tests for EmbeddingsService.on_request the request handler that dispatches
to on_embeddings and sends responses.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.base import EmbeddingsService
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
from trustgraph.exceptions import TooManyRequests
class StubEmbeddingsService(EmbeddingsService):
"""Minimal concrete implementation for testing on_request."""
def __init__(self, embed_result=None, embed_error=None):
# Skip super().__init__ to avoid taskgroup/registration
self.embed_result = embed_result or [[0.1, 0.2]]
self.embed_error = embed_error
async def on_embeddings(self, texts, model=None):
if self.embed_error:
raise self.embed_error
return self.embed_result
def _make_msg(texts, msg_id="req-1"):
request = EmbeddingsRequest(texts=texts)
msg = MagicMock()
msg.value.return_value = request
msg.properties.return_value = {"id": msg_id}
return msg
def _make_flow(model="test-model"):
mock_response_producer = AsyncMock()
mock_flow = MagicMock()
def flow_callable(name):
if name == "model":
return model
if name == "response":
return mock_response_producer
return MagicMock()
flow_callable.producer = {"response": mock_response_producer}
return flow_callable, mock_response_producer
class TestEmbeddingsServiceOnRequest:
@pytest.mark.asyncio
async def test_successful_request(self):
"""on_request should call on_embeddings and send response."""
service = StubEmbeddingsService(embed_result=[[0.1, 0.2], [0.3, 0.4]])
msg = _make_msg(["hello", "world"], msg_id="r1")
flow, mock_response = _make_flow(model="my-model")
await service.on_request(msg, MagicMock(), flow)
mock_response.send.assert_called_once()
resp = mock_response.send.call_args[0][0]
assert isinstance(resp, EmbeddingsResponse)
assert resp.error is None
assert resp.vectors == [[0.1, 0.2], [0.3, 0.4]]
# Check id is passed through
props = mock_response.send.call_args[1]["properties"]
assert props["id"] == "r1"
@pytest.mark.asyncio
async def test_passes_model_from_flow(self):
"""on_request should pass model parameter from flow to on_embeddings."""
calls = []
class TrackingService(EmbeddingsService):
def __init__(self):
pass
async def on_embeddings(self, texts, model=None):
calls.append({"texts": texts, "model": model})
return [[0.0]]
service = TrackingService()
msg = _make_msg(["test"])
flow, _ = _make_flow(model="custom-model-v2")
await service.on_request(msg, MagicMock(), flow)
assert len(calls) == 1
assert calls[0]["model"] == "custom-model-v2"
assert calls[0]["texts"] == ["test"]
@pytest.mark.asyncio
async def test_error_sends_error_response(self):
"""Non-rate-limit errors should send an error response."""
service = StubEmbeddingsService(
embed_error=ValueError("dimension mismatch")
)
msg = _make_msg(["test"], msg_id="r2")
flow, mock_response = _make_flow()
await service.on_request(msg, MagicMock(), flow)
mock_response.send.assert_called_once()
resp = mock_response.send.call_args[0][0]
assert resp.error is not None
assert resp.error.type == "embeddings-error"
assert "dimension mismatch" in resp.error.message
assert resp.vectors == []
@pytest.mark.asyncio
async def test_rate_limit_propagates(self):
"""TooManyRequests should propagate (not caught as error response)."""
service = StubEmbeddingsService(
embed_error=TooManyRequests("rate limited")
)
msg = _make_msg(["test"])
flow, _ = _make_flow()
with pytest.raises(TooManyRequests):
await service.on_request(msg, MagicMock(), flow)
@pytest.mark.asyncio
async def test_message_id_preserved(self):
"""The request message id should be forwarded in the response properties."""
service = StubEmbeddingsService()
msg = _make_msg(["test"], msg_id="unique-id-42")
flow, mock_response = _make_flow()
await service.on_request(msg, MagicMock(), flow)
props = mock_response.send.call_args[1]["properties"]
assert props["id"] == "unique-id-42"

View file

@ -0,0 +1,233 @@
"""
Tests for graph embeddings processor batch embedding of entity contexts.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.embeddings.graph_embeddings.embeddings import Processor
from trustgraph.schema import (
EntityContexts, EntityEmbeddings, GraphEmbeddings,
Term, IRI, Metadata,
)
@pytest.fixture
def processor():
return Processor(
taskgroup=AsyncMock(),
id="test-graph-embeddings",
batch_size=3,
)
def _make_entity_context(name, context, chunk_id="chunk-1"):
"""Create an entity context for testing."""
entity = Term(type=IRI, iri=f"urn:entity:{name}")
return MagicMock(entity=entity, context=context, chunk_id=chunk_id)
def _make_message(entities, doc_id="doc-1", user="test", collection="default"):
metadata = Metadata(id=doc_id, user=user, collection=collection)
value = EntityContexts(metadata=metadata, entities=entities)
msg = MagicMock()
msg.value.return_value = value
return msg
class TestGraphEmbeddingsInit:
def test_default_batch_size(self):
p = Processor(taskgroup=AsyncMock(), id="test")
assert p.batch_size == 5
def test_custom_batch_size(self):
p = Processor(taskgroup=AsyncMock(), id="test", batch_size=20)
assert p.batch_size == 20
class TestGraphEmbeddingsBatchProcessing:
@pytest.mark.asyncio
async def test_single_batch_call_for_all_entities(self, processor):
"""All entity contexts should be embedded in a single API call."""
entities = [
_make_entity_context("Alice", "Alice is a person"),
_make_entity_context("Bob", "Bob is a developer"),
_make_entity_context("Acme", "Acme is a company"),
]
msg = _make_message(entities)
mock_embed = AsyncMock(return_value=[
[0.1, 0.2], [0.3, 0.4], [0.5, 0.6],
])
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(embed=mock_embed)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
# Single batch call with all three texts
mock_embed.assert_called_once_with(
texts=["Alice is a person", "Bob is a developer", "Acme is a company"]
)
@pytest.mark.asyncio
async def test_vectors_paired_with_correct_entities(self, processor):
"""Each vector should be paired with its corresponding entity."""
entities = [
_make_entity_context("Alice", "ctx-A", chunk_id="c1"),
_make_entity_context("Bob", "ctx-B", chunk_id="c2"),
]
msg = _make_message(entities)
vectors = [[1.0, 2.0], [3.0, 4.0]]
mock_embed = AsyncMock(return_value=vectors)
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(embed=mock_embed)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
# With batch_size=3, all 2 entities fit in one output message
mock_output.send.assert_called_once()
result = mock_output.send.call_args[0][0]
assert isinstance(result, GraphEmbeddings)
assert len(result.entities) == 2
assert result.entities[0].vector == [1.0, 2.0]
assert result.entities[0].entity.iri == "urn:entity:Alice"
assert result.entities[0].chunk_id == "c1"
assert result.entities[1].vector == [3.0, 4.0]
assert result.entities[1].entity.iri == "urn:entity:Bob"
@pytest.mark.asyncio
async def test_output_batching(self, processor):
"""Output should be split into batches of batch_size."""
# batch_size=3, 7 entities -> 3 output messages (3+3+1)
entities = [
_make_entity_context(f"E{i}", f"context {i}")
for i in range(7)
]
msg = _make_message(entities)
vectors = [[float(i)] for i in range(7)]
mock_embed = AsyncMock(return_value=vectors)
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(embed=mock_embed)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
assert mock_output.send.call_count == 3
# First batch has 3 entities
batch1 = mock_output.send.call_args_list[0][0][0]
assert len(batch1.entities) == 3
# Second batch has 3 entities
batch2 = mock_output.send.call_args_list[1][0][0]
assert len(batch2.entities) == 3
# Third batch has 1 entity
batch3 = mock_output.send.call_args_list[2][0][0]
assert len(batch3.entities) == 1
@pytest.mark.asyncio
async def test_output_batches_preserve_metadata(self, processor):
"""Each output batch should carry the original metadata."""
entities = [
_make_entity_context(f"E{i}", f"ctx {i}")
for i in range(5)
]
msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main")
mock_embed = AsyncMock(return_value=[[0.0]] * 5)
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(embed=mock_embed)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
for call in mock_output.send.call_args_list:
result = call[0][0]
assert result.metadata.id == "doc-42"
assert result.metadata.user == "alice"
assert result.metadata.collection == "main"
@pytest.mark.asyncio
async def test_single_entity(self, processor):
"""Single entity should work with one embed call and one output."""
entities = [_make_entity_context("Solo", "solo context")]
msg = _make_message(entities)
mock_embed = AsyncMock(return_value=[[1.0, 2.0, 3.0]])
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(embed=mock_embed)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
mock_embed.assert_called_once_with(texts=["solo context"])
mock_output.send.assert_called_once()
@pytest.mark.asyncio
async def test_embed_error_propagates(self, processor):
"""Embedding service errors should propagate for retry."""
entities = [_make_entity_context("E", "ctx")]
msg = _make_message(entities)
mock_embed = AsyncMock(side_effect=RuntimeError("embedding failed"))
def flow(name):
if name == "embeddings-request":
return MagicMock(embed=mock_embed)
return MagicMock()
with pytest.raises(RuntimeError, match="embedding failed"):
await processor.on_message(msg, MagicMock(), flow)
@pytest.mark.asyncio
async def test_exact_batch_size(self, processor):
"""When entity count equals batch_size, exactly one output message."""
entities = [
_make_entity_context(f"E{i}", f"ctx {i}")
for i in range(3) # batch_size=3
]
msg = _make_message(entities)
mock_embed = AsyncMock(return_value=[[0.0]] * 3)
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(embed=mock_embed)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
mock_output.send.assert_called_once()
assert len(mock_output.send.call_args[0][0].entities) == 3

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,407 @@
"""
Tests for streaming triple and entity context batching in the definitions
KG extractor.
Covers: triples batch splitting, entity context batch splitting,
metadata preservation, provenance, and empty/null filtering.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.definitions.extract import (
Processor, default_triples_batch_size, default_entity_batch_size,
)
from trustgraph.schema import (
Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_processor(triples_batch_size=default_triples_batch_size,
entity_batch_size=default_entity_batch_size):
proc = Processor.__new__(Processor)
proc.triples_batch_size = triples_batch_size
proc.entity_batch_size = entity_batch_size
return proc
def _make_defn(entity, definition):
return {"entity": entity, "definition": definition}
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
user="user-1", collection="col-1", document_id=""):
chunk = Chunk(
metadata=Metadata(
id=meta_id, root=root, user=user, collection=collection,
),
chunk=text.encode("utf-8"),
document_id=document_id,
)
msg = MagicMock()
msg.value.return_value = chunk
return msg
def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock()
mock_ecs_pub = AsyncMock()
mock_prompt_client = AsyncMock()
mock_prompt_client.extract_definitions = AsyncMock(
return_value=prompt_result
)
def flow(name):
if name == "prompt-request":
return mock_prompt_client
if name == "triples":
return mock_triples_pub
if name == "entity-contexts":
return mock_ecs_pub
if name == "llm-model":
return llm_model
if name == "ontology":
return ontology_uri
return MagicMock()
return flow, mock_triples_pub, mock_ecs_pub, mock_prompt_client
def _sent_triples(mock_pub):
return [call.args[0] for call in mock_pub.send.call_args_list]
def _sent_ecs(mock_pub):
return [call.args[0] for call in mock_pub.send.call_args_list]
def _all_triples_flat(mock_pub):
result = []
for triples_msg in _sent_triples(mock_pub):
result.extend(triples_msg.triples)
return result
def _all_entities_flat(mock_pub):
result = []
for ecs_msg in _sent_ecs(mock_pub):
result.extend(ecs_msg.entities)
return result
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestDefaults:
def test_default_triples_batch_size(self):
assert default_triples_batch_size == 50
def test_default_entity_batch_size(self):
assert default_entity_batch_size == 5
class TestTriplesBatching:
@pytest.mark.asyncio
async def test_single_batch_when_under_limit(self):
proc = _make_processor(triples_batch_size=100)
defs = [_make_defn("Cat", "A feline animal")]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert triples_pub.send.call_count == 1
@pytest.mark.asyncio
async def test_multiple_triples_batches(self):
proc = _make_processor(triples_batch_size=2)
defs = [
_make_defn("Cat", "A feline"),
_make_defn("Dog", "A canine"),
]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
# 2 defs → 2 labels + 2 definitions = 4 triples + provenance
# With batch_size=2, should produce multiple batches
assert triples_pub.send.call_count > 1
@pytest.mark.asyncio
async def test_triples_batch_sizes_within_limit(self):
batch_size = 3
proc = _make_processor(triples_batch_size=batch_size)
defs = [
_make_defn("A", "def A"),
_make_defn("B", "def B"),
_make_defn("C", "def C"),
]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
for triples_msg in _sent_triples(triples_pub):
assert len(triples_msg.triples) <= batch_size
class TestEntityContextBatching:
@pytest.mark.asyncio
async def test_single_entity_batch_when_under_limit(self):
proc = _make_processor(entity_batch_size=100)
defs = [_make_defn("Cat", "A feline")]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
# 1 def → 2 entity contexts (name + definition)
assert ecs_pub.send.call_count == 1
@pytest.mark.asyncio
async def test_multiple_entity_batches(self):
proc = _make_processor(entity_batch_size=2)
defs = [
_make_defn("Cat", "A feline"),
_make_defn("Dog", "A canine"),
]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
# 2 defs → 4 entity contexts, batch_size=2 → 2 batches
assert ecs_pub.send.call_count == 2
@pytest.mark.asyncio
async def test_entity_batch_sizes_within_limit(self):
batch_size = 3
proc = _make_processor(entity_batch_size=batch_size)
defs = [
_make_defn("A", "def A"),
_make_defn("B", "def B"),
_make_defn("C", "def C"),
]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
for ecs_msg in _sent_ecs(ecs_pub):
assert len(ecs_msg.entities) <= batch_size
@pytest.mark.asyncio
async def test_entity_contexts_have_name_and_definition(self):
"""Each definition produces 2 entity contexts: name and definition."""
proc = _make_processor(entity_batch_size=100)
defs = [_make_defn("Cat", "A feline animal")]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
entities = _all_entities_flat(ecs_pub)
assert len(entities) == 2
contexts = {e.context for e in entities}
assert "Cat" in contexts
assert "A feline animal" in contexts
class TestMetadataPreservation:
@pytest.mark.asyncio
async def test_triples_metadata(self):
proc = _make_processor(triples_batch_size=2)
defs = [_make_defn("X", "def X")]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg(
"text", meta_id="c-1", root="r-1",
user="u-1", collection="coll-1",
)
await proc.on_message(msg, MagicMock(), flow)
for triples_msg in _sent_triples(triples_pub):
assert triples_msg.metadata.id == "c-1"
assert triples_msg.metadata.root == "r-1"
assert triples_msg.metadata.user == "u-1"
assert triples_msg.metadata.collection == "coll-1"
@pytest.mark.asyncio
async def test_entity_contexts_metadata(self):
proc = _make_processor(entity_batch_size=1)
defs = [_make_defn("X", "def X")]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg(
"text", meta_id="c-2", root="r-2",
user="u-2", collection="coll-2",
)
await proc.on_message(msg, MagicMock(), flow)
for ecs_msg in _sent_ecs(ecs_pub):
assert ecs_msg.metadata.id == "c-2"
assert ecs_msg.metadata.root == "r-2"
class TestEmptyAndNullFiltering:
@pytest.mark.asyncio
async def test_empty_entity_skipped(self):
proc = _make_processor()
defs = [
_make_defn("", "some definition"),
_make_defn("Valid", "a valid definition"),
]
flow, triples_pub, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(triples_pub)
all_e = _all_entities_flat(ecs_pub)
# Only "Valid" should be present
entity_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri")}
assert any("valid" in iri for iri in entity_iris)
assert len(all_e) == 2 # name + definition for "Valid" only
@pytest.mark.asyncio
async def test_empty_definition_skipped(self):
proc = _make_processor()
defs = [
_make_defn("Entity", ""),
_make_defn("Good", "good definition"),
]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(triples_pub)
entity_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri")}
assert any("good" in iri for iri in entity_iris)
# "Entity" with empty def should have been skipped
assert not any("entity" in iri and "good" not in iri for iri in entity_iris)
@pytest.mark.asyncio
async def test_none_fields_skipped(self):
proc = _make_processor()
defs = [
_make_defn(None, "some definition"),
_make_defn("Entity", None),
]
flow, triples_pub, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert triples_pub.send.call_count == 0
assert ecs_pub.send.call_count == 0
@pytest.mark.asyncio
async def test_all_filtered_no_output(self):
proc = _make_processor()
defs = [_make_defn("", ""), _make_defn(None, None)]
flow, triples_pub, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert triples_pub.send.call_count == 0
assert ecs_pub.send.call_count == 0
@pytest.mark.asyncio
async def test_empty_prompt_response(self):
proc = _make_processor()
flow, triples_pub, ecs_pub, _ = _make_flow([])
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert triples_pub.send.call_count == 0
assert ecs_pub.send.call_count == 0
class TestProvenanceInclusion:
@pytest.mark.asyncio
async def test_provenance_triples_present(self):
proc = _make_processor(triples_batch_size=200)
defs = [_make_defn("Cat", "A feline")]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(triples_pub)
# 1 def → 1 label + 1 definition = 2 content triples
# Provenance adds more
assert len(all_t) > 2
class TestErrorHandling:
@pytest.mark.asyncio
async def test_prompt_error_caught(self):
proc = _make_processor()
flow, triples_pub, ecs_pub, prompt = _make_flow([])
prompt.extract_definitions = AsyncMock(
side_effect=RuntimeError("LLM error")
)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert triples_pub.send.call_count == 0
assert ecs_pub.send.call_count == 0
@pytest.mark.asyncio
async def test_non_list_response_caught(self):
proc = _make_processor()
flow, triples_pub, ecs_pub, prompt = _make_flow("not a list")
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert triples_pub.send.call_count == 0
assert ecs_pub.send.call_count == 0
class TestDocumentIdProvenance:
@pytest.mark.asyncio
async def test_document_id_used_for_chunk_id(self):
"""When document_id is set, entity contexts should use it as chunk_id."""
proc = _make_processor(entity_batch_size=100)
defs = [_make_defn("Cat", "A feline")]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text", document_id="doc-123")
await proc.on_message(msg, MagicMock(), flow)
entities = _all_entities_flat(ecs_pub)
for e in entities:
assert e.chunk_id == "doc-123"
@pytest.mark.asyncio
async def test_metadata_id_fallback_for_chunk_id(self):
"""When document_id is empty, metadata.id is used as chunk_id."""
proc = _make_processor(entity_batch_size=100)
defs = [_make_defn("Cat", "A feline")]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text", meta_id="chunk-42", document_id="")
await proc.on_message(msg, MagicMock(), flow)
entities = _all_entities_flat(ecs_pub)
for e in entities:
assert e.chunk_id == "chunk-42"

View file

@ -0,0 +1,408 @@
"""
Tests for streaming triple batching in the relationships KG extractor.
Covers: batch size configuration, output splitting, metadata preservation,
provenance inclusion, empty/null filtering, and error propagation.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.extract.kg.relationships.extract import (
Processor, default_triples_batch_size,
)
from trustgraph.schema import (
Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_processor(triples_batch_size=default_triples_batch_size):
"""Create a Processor without triggering FlowProcessor.__init__."""
proc = Processor.__new__(Processor)
proc.triples_batch_size = triples_batch_size
return proc
def _make_rel(subject, predicate, obj, object_entity=True):
"""Build a relationship dict as returned by the prompt client."""
return {
"subject": subject,
"predicate": predicate,
"object": obj,
"object-entity": object_entity,
}
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
user="user-1", collection="col-1", document_id=""):
"""Build a mock message wrapping a Chunk."""
chunk = Chunk(
metadata=Metadata(
id=meta_id, root=root, user=user, collection=collection,
),
chunk=text.encode("utf-8"),
document_id=document_id,
)
msg = MagicMock()
msg.value.return_value = chunk
return msg
def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
"""Build a mock flow callable that provides prompt client, triples
producer, and parameter specs."""
mock_triples_pub = AsyncMock()
mock_prompt_client = AsyncMock()
mock_prompt_client.extract_relationships = AsyncMock(
return_value=prompt_result
)
def flow(name):
if name == "prompt-request":
return mock_prompt_client
if name == "triples":
return mock_triples_pub
if name == "llm-model":
return llm_model
if name == "ontology":
return ontology_uri
return MagicMock()
return flow, mock_triples_pub, mock_prompt_client
def _sent_triples(mock_pub):
"""Collect all Triples objects sent to a mock publisher."""
return [call.args[0] for call in mock_pub.send.call_args_list]
def _all_triples_flat(mock_pub):
"""Flatten all batches into one list of Triple objects."""
result = []
for triples_msg in _sent_triples(mock_pub):
result.extend(triples_msg.triples)
return result
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestDefaultBatchSize:
def test_default_is_50(self):
assert default_triples_batch_size == 50
def test_processor_uses_default(self):
proc = _make_processor()
assert proc.triples_batch_size == 50
class TestBatchSplitting:
@pytest.mark.asyncio
async def test_single_batch_when_under_limit(self):
"""Few triples → single send call."""
proc = _make_processor(triples_batch_size=50)
rels = [_make_rel("A", "knows", "B")]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("some text")
await proc.on_message(msg, MagicMock(), flow)
# One relationship produces: rel triple + 3 labels + provenance
# All should fit in one batch of 50
assert pub.send.call_count == 1
@pytest.mark.asyncio
async def test_multiple_batches_with_small_batch_size(self):
"""With batch_size=3 and many triples, multiple batches are sent."""
proc = _make_processor(triples_batch_size=3)
# 2 relationships → 2 rel triples + 6 labels = 8 triples + provenance
rels = [
_make_rel("A", "knows", "B"),
_make_rel("C", "likes", "D"),
]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("some text")
await proc.on_message(msg, MagicMock(), flow)
# Should have more than one batch
assert pub.send.call_count > 1
@pytest.mark.asyncio
async def test_batch_sizes_respect_limit(self):
"""No batch should exceed the configured batch size."""
batch_size = 3
proc = _make_processor(triples_batch_size=batch_size)
rels = [
_make_rel("A", "knows", "B"),
_make_rel("C", "likes", "D"),
_make_rel("E", "has", "F"),
]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
for triples_msg in _sent_triples(pub):
assert len(triples_msg.triples) <= batch_size
@pytest.mark.asyncio
async def test_all_triples_present_across_batches(self):
"""Total triples across batches equals expected count."""
proc = _make_processor(triples_batch_size=2)
# 1 relationship with object-entity=True → 1 rel + 3 labels = 4 triples
# + provenance triples
rels = [_make_rel("A", "knows", "B", object_entity=True)]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(pub)
# At minimum: 1 rel + 3 labels = 4 content triples
assert len(all_t) >= 4
@pytest.mark.asyncio
async def test_custom_batch_size(self):
"""Processor respects custom triples_batch_size parameter."""
proc = _make_processor(triples_batch_size=100)
assert proc.triples_batch_size == 100
class TestMetadataPreservation:
@pytest.mark.asyncio
async def test_metadata_forwarded_to_all_batches(self):
"""Every batch should carry the original chunk metadata."""
proc = _make_processor(triples_batch_size=2)
rels = [_make_rel("X", "rel", "Y")]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg(
"text", meta_id="c-1", root="r-1",
user="u-1", collection="coll-1",
)
await proc.on_message(msg, MagicMock(), flow)
for triples_msg in _sent_triples(pub):
assert triples_msg.metadata.id == "c-1"
assert triples_msg.metadata.root == "r-1"
assert triples_msg.metadata.user == "u-1"
assert triples_msg.metadata.collection == "coll-1"
class TestRelationshipTriples:
@pytest.mark.asyncio
async def test_entity_object_produces_iri(self):
"""object-entity=True → object is an IRI, with label triple."""
proc = _make_processor(triples_batch_size=200)
rels = [_make_rel("Alice", "knows", "Bob", object_entity=True)]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(pub)
# Find the relationship triple (not a label)
rel_triples = [
t for t in all_t
if t.o.type == IRI and "bob" in t.o.iri
]
assert len(rel_triples) >= 1
@pytest.mark.asyncio
async def test_literal_object_produces_literal(self):
"""object-entity=False → object is a LITERAL, no label for object."""
proc = _make_processor(triples_batch_size=200)
rels = [_make_rel("Alice", "age", "30", object_entity=False)]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(pub)
# Find the relationship triple with literal object
lit_triples = [
t for t in all_t
if t.o.type == LITERAL and t.o.value == "30"
]
assert len(lit_triples) == 1
@pytest.mark.asyncio
async def test_labels_emitted_for_subject_and_predicate(self):
"""Every relationship should produce label triples for s and p."""
proc = _make_processor(triples_batch_size=200)
rels = [_make_rel("Alice", "knows", "Bob")]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(pub)
label_triples = [
t for t in all_t
if t.p.type == IRI and "label" in t.p.iri.lower()
]
labels = {t.o.value for t in label_triples}
assert "Alice" in labels
assert "knows" in labels
assert "Bob" in labels # object-entity default is True
class TestEmptyAndNullFiltering:
@pytest.mark.asyncio
async def test_empty_string_fields_skipped(self):
"""Relationships with empty string s/p/o are skipped."""
proc = _make_processor(triples_batch_size=200)
rels = [
_make_rel("", "knows", "Bob"),
_make_rel("Alice", "", "Bob"),
_make_rel("Alice", "knows", ""),
_make_rel("Good", "triple", "Here"),
]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(pub)
# Only the "Good triple Here" relationship should produce content triples
rel_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri") and t.s.iri}
assert any("good" in iri for iri in rel_iris)
assert not any("alice" in iri for iri in rel_iris)
@pytest.mark.asyncio
async def test_none_fields_skipped(self):
"""Relationships with None s/p/o are skipped."""
proc = _make_processor(triples_batch_size=200)
rels = [
_make_rel(None, "knows", "Bob"),
_make_rel("Alice", None, "Bob"),
_make_rel("Alice", "knows", None),
_make_rel("Valid", "rel", "Here"),
]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(pub)
rel_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri") and t.s.iri}
assert any("valid" in iri for iri in rel_iris)
assert not any("alice" in iri for iri in rel_iris)
@pytest.mark.asyncio
async def test_all_filtered_produces_no_output(self):
"""If all relationships are empty/null, nothing is emitted."""
proc = _make_processor(triples_batch_size=200)
rels = [
_make_rel("", "", ""),
_make_rel(None, None, None),
]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert pub.send.call_count == 0
@pytest.mark.asyncio
async def test_empty_prompt_response_produces_no_output(self):
"""Empty relationship list from prompt → no triples emitted."""
proc = _make_processor()
flow, pub, _ = _make_flow([])
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert pub.send.call_count == 0
class TestProvenanceInclusion:
@pytest.mark.asyncio
async def test_provenance_triples_present(self):
"""Extracted relationships should include provenance triples."""
proc = _make_processor(triples_batch_size=200)
rels = [_make_rel("A", "knows", "B")]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(pub)
# Provenance triples use GRAPH_SOURCE graph context
# They contain terms referencing prov: namespace or subgraph URIs
# We just check that total count > 4 (1 rel + 3 labels)
assert len(all_t) > 4
@pytest.mark.asyncio
async def test_no_provenance_when_no_extracted_triples(self):
"""Empty relationships → no provenance generated."""
proc = _make_processor()
flow, pub, _ = _make_flow([_make_rel("", "x", "y")])
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert pub.send.call_count == 0
class TestErrorPropagation:
@pytest.mark.asyncio
async def test_prompt_error_is_caught(self):
"""Errors from the prompt client are caught (logged, not raised)."""
proc = _make_processor()
flow, pub, prompt = _make_flow([])
prompt.extract_relationships = AsyncMock(
side_effect=RuntimeError("LLM unavailable")
)
msg = _make_chunk_msg("text")
# The outer try/except in on_message catches and logs
await proc.on_message(msg, MagicMock(), flow)
assert pub.send.call_count == 0
@pytest.mark.asyncio
async def test_non_list_response_is_caught(self):
"""Non-list prompt response triggers RuntimeError, caught by handler."""
proc = _make_processor()
flow, pub, prompt = _make_flow("not a list")
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert pub.send.call_count == 0
class TestToUri:
def test_spaces_replaced_with_hyphens(self):
proc = _make_processor()
uri = proc.to_uri("hello world")
assert "hello-world" in uri
def test_lowercased(self):
proc = _make_processor()
uri = proc.to_uri("Hello World")
assert "hello-world" in uri
def test_special_chars_encoded(self):
proc = _make_processor()
# urllib.parse.quote keeps / as safe by default
uri = proc.to_uri("a/b")
assert "a/b" in uri
# Characters like spaces are encoded (handled via replace → hyphen)
uri2 = proc.to_uri("hello world")
assert " " not in uri2

View file

View file

@ -0,0 +1,716 @@
"""
Tests for librarian chunked upload operations:
begin_upload, upload_chunk, complete_upload, abort_upload, get_upload_status,
list_uploads, and stream_document.
"""
import base64
import json
import math
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.librarian.librarian import Librarian, DEFAULT_CHUNK_SIZE
from trustgraph.exceptions import RequestError
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_librarian(min_chunk_size=1):
"""Create a Librarian with mocked blob_store and table_store."""
lib = Librarian.__new__(Librarian)
lib.blob_store = MagicMock()
lib.table_store = AsyncMock()
lib.load_document = AsyncMock()
lib.min_chunk_size = min_chunk_size
return lib
def _make_doc_metadata(
doc_id="doc-1", kind="application/pdf", user="alice", title="Test Doc"
):
meta = MagicMock()
meta.id = doc_id
meta.kind = kind
meta.user = user
meta.title = title
meta.time = 1700000000
meta.comments = ""
meta.tags = []
return meta
def _make_begin_request(
doc_id="doc-1", kind="application/pdf", user="alice",
total_size=10_000_000, chunk_size=0
):
req = MagicMock()
req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, user=user)
req.total_size = total_size
req.chunk_size = chunk_size
return req
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, user="alice", content=b"data"):
req = MagicMock()
req.upload_id = upload_id
req.chunk_index = chunk_index
req.user = user
req.content = base64.b64encode(content)
return req
def _make_session(
user="alice", total_chunks=5, chunk_size=2_000_000,
total_size=10_000_000, chunks_received=None, object_id="obj-1",
s3_upload_id="s3-up-1", document_metadata=None, document_id="doc-1",
):
if chunks_received is None:
chunks_received = {}
if document_metadata is None:
document_metadata = json.dumps({
"id": document_id, "kind": "application/pdf",
"user": user, "title": "Test", "time": 1700000000,
"comments": "", "tags": [],
})
return {
"user": user,
"total_chunks": total_chunks,
"chunk_size": chunk_size,
"total_size": total_size,
"chunks_received": chunks_received,
"object_id": object_id,
"s3_upload_id": s3_upload_id,
"document_metadata": document_metadata,
"document_id": document_id,
}
# ---------------------------------------------------------------------------
# begin_upload
# ---------------------------------------------------------------------------
class TestBeginUpload:
@pytest.mark.asyncio
async def test_creates_session(self):
lib = _make_librarian()
lib.table_store.document_exists.return_value = False
lib.blob_store.create_multipart_upload.return_value = "s3-upload-id"
req = _make_begin_request(total_size=10_000_000)
resp = await lib.begin_upload(req)
assert resp.error is None
assert resp.upload_id is not None
assert resp.total_chunks == math.ceil(10_000_000 / DEFAULT_CHUNK_SIZE)
assert resp.chunk_size == DEFAULT_CHUNK_SIZE
@pytest.mark.asyncio
async def test_custom_chunk_size(self):
lib = _make_librarian()
lib.table_store.document_exists.return_value = False
lib.blob_store.create_multipart_upload.return_value = "s3-id"
req = _make_begin_request(total_size=10_000, chunk_size=3000)
resp = await lib.begin_upload(req)
assert resp.chunk_size == 3000
assert resp.total_chunks == math.ceil(10_000 / 3000)
@pytest.mark.asyncio
async def test_rejects_invalid_kind(self):
lib = _make_librarian()
req = _make_begin_request(kind="image/png")
with pytest.raises(RequestError, match="Invalid document kind"):
await lib.begin_upload(req)
@pytest.mark.asyncio
async def test_rejects_duplicate_document(self):
lib = _make_librarian()
lib.table_store.document_exists.return_value = True
req = _make_begin_request()
with pytest.raises(RequestError, match="already exists"):
await lib.begin_upload(req)
@pytest.mark.asyncio
async def test_rejects_zero_size(self):
lib = _make_librarian()
lib.table_store.document_exists.return_value = False
req = _make_begin_request(total_size=0)
with pytest.raises(RequestError, match="positive"):
await lib.begin_upload(req)
@pytest.mark.asyncio
async def test_rejects_chunk_below_minimum(self):
lib = _make_librarian(min_chunk_size=1024)
lib.table_store.document_exists.return_value = False
req = _make_begin_request(total_size=10_000, chunk_size=512)
with pytest.raises(RequestError, match="below minimum"):
await lib.begin_upload(req)
@pytest.mark.asyncio
async def test_calls_s3_create_multipart(self):
lib = _make_librarian()
lib.table_store.document_exists.return_value = False
lib.blob_store.create_multipart_upload.return_value = "s3-id"
req = _make_begin_request(kind="application/pdf")
await lib.begin_upload(req)
lib.blob_store.create_multipart_upload.assert_called_once()
# create_multipart_upload(object_id, kind) — positional args
args = lib.blob_store.create_multipart_upload.call_args[0]
assert args[1] == "application/pdf"
@pytest.mark.asyncio
async def test_stores_session_in_cassandra(self):
lib = _make_librarian()
lib.table_store.document_exists.return_value = False
lib.blob_store.create_multipart_upload.return_value = "s3-id"
req = _make_begin_request(total_size=5_000_000)
resp = await lib.begin_upload(req)
lib.table_store.create_upload_session.assert_called_once()
kwargs = lib.table_store.create_upload_session.call_args[1]
assert kwargs["upload_id"] == resp.upload_id
assert kwargs["total_size"] == 5_000_000
assert kwargs["total_chunks"] == resp.total_chunks
@pytest.mark.asyncio
async def test_accepts_text_plain(self):
lib = _make_librarian()
lib.table_store.document_exists.return_value = False
lib.blob_store.create_multipart_upload.return_value = "s3-id"
req = _make_begin_request(kind="text/plain", total_size=1000)
resp = await lib.begin_upload(req)
assert resp.error is None
# ---------------------------------------------------------------------------
# upload_chunk
# ---------------------------------------------------------------------------
class TestUploadChunk:
@pytest.mark.asyncio
async def test_successful_chunk_upload(self):
lib = _make_librarian()
session = _make_session(total_chunks=5, chunks_received={})
lib.table_store.get_upload_session.return_value = session
lib.blob_store.upload_part.return_value = "etag-1"
req = _make_upload_chunk_request(chunk_index=0, content=b"chunk data")
resp = await lib.upload_chunk(req)
assert resp.error is None
assert resp.chunk_index == 0
assert resp.total_chunks == 5
# The chunk is added to the dict (len=1), then +1 applied => 2
assert resp.chunks_received == 2
@pytest.mark.asyncio
async def test_s3_part_number_is_1_indexed(self):
lib = _make_librarian()
session = _make_session()
lib.table_store.get_upload_session.return_value = session
lib.blob_store.upload_part.return_value = "etag"
req = _make_upload_chunk_request(chunk_index=0)
await lib.upload_chunk(req)
kwargs = lib.blob_store.upload_part.call_args[1]
assert kwargs["part_number"] == 1 # 0-indexed chunk → 1-indexed part
@pytest.mark.asyncio
async def test_chunk_index_3_becomes_part_4(self):
lib = _make_librarian()
session = _make_session()
lib.table_store.get_upload_session.return_value = session
lib.blob_store.upload_part.return_value = "etag"
req = _make_upload_chunk_request(chunk_index=3)
await lib.upload_chunk(req)
kwargs = lib.blob_store.upload_part.call_args[1]
assert kwargs["part_number"] == 4
@pytest.mark.asyncio
async def test_rejects_expired_session(self):
lib = _make_librarian()
lib.table_store.get_upload_session.return_value = None
req = _make_upload_chunk_request()
with pytest.raises(RequestError, match="not found"):
await lib.upload_chunk(req)
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
lib.table_store.get_upload_session.return_value = session
req = _make_upload_chunk_request(user="bob")
with pytest.raises(RequestError, match="Not authorized"):
await lib.upload_chunk(req)
@pytest.mark.asyncio
async def test_rejects_negative_chunk_index(self):
lib = _make_librarian()
session = _make_session(total_chunks=5)
lib.table_store.get_upload_session.return_value = session
req = _make_upload_chunk_request(chunk_index=-1)
with pytest.raises(RequestError, match="Invalid chunk index"):
await lib.upload_chunk(req)
@pytest.mark.asyncio
async def test_rejects_out_of_range_chunk_index(self):
lib = _make_librarian()
session = _make_session(total_chunks=5)
lib.table_store.get_upload_session.return_value = session
req = _make_upload_chunk_request(chunk_index=5)
with pytest.raises(RequestError, match="Invalid chunk index"):
await lib.upload_chunk(req)
@pytest.mark.asyncio
async def test_progress_tracking(self):
lib = _make_librarian()
session = _make_session(
total_chunks=4, chunk_size=1000, total_size=3500,
chunks_received={0: "e1", 1: "e2"},
)
lib.table_store.get_upload_session.return_value = session
lib.blob_store.upload_part.return_value = "e3"
req = _make_upload_chunk_request(chunk_index=2)
resp = await lib.upload_chunk(req)
# Dict gets chunk 2 added (len=3), then +1 => 4
assert resp.chunks_received == 4
assert resp.total_chunks == 4
assert resp.total_bytes == 3500
@pytest.mark.asyncio
async def test_bytes_capped_at_total_size(self):
"""bytes_received should not exceed total_size for the final chunk."""
lib = _make_librarian()
session = _make_session(
total_chunks=2, chunk_size=3000, total_size=5000,
chunks_received={0: "e1"},
)
lib.table_store.get_upload_session.return_value = session
lib.blob_store.upload_part.return_value = "e2"
req = _make_upload_chunk_request(chunk_index=1)
resp = await lib.upload_chunk(req)
# 3 chunks × 3000 = 9000 > 5000, so capped
assert resp.bytes_received <= 5000
@pytest.mark.asyncio
async def test_base64_decodes_content(self):
lib = _make_librarian()
session = _make_session()
lib.table_store.get_upload_session.return_value = session
lib.blob_store.upload_part.return_value = "etag"
raw = b"hello world binary data"
req = _make_upload_chunk_request(content=raw)
await lib.upload_chunk(req)
kwargs = lib.blob_store.upload_part.call_args[1]
assert kwargs["data"] == raw
# ---------------------------------------------------------------------------
# complete_upload
# ---------------------------------------------------------------------------
class TestCompleteUpload:
@pytest.mark.asyncio
async def test_successful_completion(self):
lib = _make_librarian()
session = _make_session(
total_chunks=3,
chunks_received={0: "e1", 1: "e2", 2: "e3"},
)
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
resp = await lib.complete_upload(req)
assert resp.error is None
assert resp.document_id == "doc-1"
lib.blob_store.complete_multipart_upload.assert_called_once()
lib.table_store.add_document.assert_called_once()
lib.table_store.delete_upload_session.assert_called_once_with("up-1")
@pytest.mark.asyncio
async def test_parts_sorted_by_index(self):
lib = _make_librarian()
# Chunks received out of order
session = _make_session(
total_chunks=3,
chunks_received={2: "e3", 0: "e1", 1: "e2"},
)
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
await lib.complete_upload(req)
parts = lib.blob_store.complete_multipart_upload.call_args[1]["parts"]
part_numbers = [p[0] for p in parts]
assert part_numbers == [1, 2, 3] # Sorted, 1-indexed
@pytest.mark.asyncio
async def test_rejects_missing_chunks(self):
lib = _make_librarian()
session = _make_session(
total_chunks=3,
chunks_received={0: "e1", 2: "e3"}, # chunk 1 missing
)
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
with pytest.raises(RequestError, match="Missing chunks"):
await lib.complete_upload(req)
@pytest.mark.asyncio
async def test_rejects_expired_session(self):
lib = _make_librarian()
lib.table_store.get_upload_session.return_value = None
req = MagicMock()
req.upload_id = "up-gone"
req.user = "alice"
with pytest.raises(RequestError, match="not found"):
await lib.complete_upload(req)
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "bob"
with pytest.raises(RequestError, match="Not authorized"):
await lib.complete_upload(req)
# ---------------------------------------------------------------------------
# abort_upload
# ---------------------------------------------------------------------------
class TestAbortUpload:
@pytest.mark.asyncio
async def test_aborts_and_cleans_up(self):
lib = _make_librarian()
session = _make_session()
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
resp = await lib.abort_upload(req)
assert resp.error is None
lib.blob_store.abort_multipart_upload.assert_called_once_with(
object_id="obj-1", upload_id="s3-up-1"
)
lib.table_store.delete_upload_session.assert_called_once_with("up-1")
@pytest.mark.asyncio
async def test_rejects_expired_session(self):
lib = _make_librarian()
lib.table_store.get_upload_session.return_value = None
req = MagicMock()
req.upload_id = "up-gone"
req.user = "alice"
with pytest.raises(RequestError, match="not found"):
await lib.abort_upload(req)
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "bob"
with pytest.raises(RequestError, match="Not authorized"):
await lib.abort_upload(req)
# ---------------------------------------------------------------------------
# get_upload_status
# ---------------------------------------------------------------------------
class TestGetUploadStatus:
@pytest.mark.asyncio
async def test_in_progress_status(self):
lib = _make_librarian()
session = _make_session(
total_chunks=5, chunk_size=2000, total_size=10_000,
chunks_received={0: "e1", 2: "e3", 4: "e5"},
)
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
resp = await lib.get_upload_status(req)
assert resp.upload_state == "in-progress"
assert resp.chunks_received == 3
assert resp.total_chunks == 5
assert sorted(resp.received_chunks) == [0, 2, 4]
assert sorted(resp.missing_chunks) == [1, 3]
assert resp.total_bytes == 10_000
@pytest.mark.asyncio
async def test_expired_session(self):
lib = _make_librarian()
lib.table_store.get_upload_session.return_value = None
req = MagicMock()
req.upload_id = "up-expired"
req.user = "alice"
resp = await lib.get_upload_status(req)
assert resp.upload_state == "expired"
@pytest.mark.asyncio
async def test_all_chunks_received(self):
lib = _make_librarian()
session = _make_session(
total_chunks=3, chunk_size=1000, total_size=2500,
chunks_received={0: "e1", 1: "e2", 2: "e3"},
)
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
resp = await lib.get_upload_status(req)
assert resp.missing_chunks == []
assert resp.chunks_received == 3
# 3 * 1000 = 3000 > 2500, so capped
assert resp.bytes_received <= 2500
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "bob"
with pytest.raises(RequestError, match="Not authorized"):
await lib.get_upload_status(req)
# ---------------------------------------------------------------------------
# stream_document
# ---------------------------------------------------------------------------
class TestStreamDocument:
@pytest.mark.asyncio
async def test_streams_chunks_with_progress(self):
lib = _make_librarian()
lib.table_store.get_document_object_id.return_value = "obj-1"
lib.blob_store.get_size = AsyncMock(return_value=5000)
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000)
req = MagicMock()
req.user = "alice"
req.document_id = "doc-1"
req.chunk_size = 2000
chunks = []
async for resp in lib.stream_document(req):
chunks.append(resp)
assert len(chunks) == 3 # ceil(5000/2000)
assert chunks[0].chunk_index == 0
assert chunks[0].total_chunks == 3
assert chunks[0].is_final is False
assert chunks[-1].is_final is True
assert chunks[-1].chunk_index == 2
@pytest.mark.asyncio
async def test_single_chunk_document(self):
lib = _make_librarian()
lib.table_store.get_document_object_id.return_value = "obj-1"
lib.blob_store.get_size = AsyncMock(return_value=500)
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500)
req = MagicMock()
req.user = "alice"
req.document_id = "doc-1"
req.chunk_size = 2000
chunks = []
async for resp in lib.stream_document(req):
chunks.append(resp)
assert len(chunks) == 1
assert chunks[0].is_final is True
assert chunks[0].bytes_received == 500
assert chunks[0].total_bytes == 500
@pytest.mark.asyncio
async def test_byte_ranges_correct(self):
lib = _make_librarian()
lib.table_store.get_document_object_id.return_value = "obj-1"
lib.blob_store.get_size = AsyncMock(return_value=5000)
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100)
req = MagicMock()
req.user = "alice"
req.document_id = "doc-1"
req.chunk_size = 2000
chunks = []
async for resp in lib.stream_document(req):
chunks.append(resp)
# Verify the byte ranges passed to get_range
calls = lib.blob_store.get_range.call_args_list
assert calls[0][0] == ("obj-1", 0, 2000)
assert calls[1][0] == ("obj-1", 2000, 2000)
assert calls[2][0] == ("obj-1", 4000, 1000) # Last chunk: 5000-4000
@pytest.mark.asyncio
async def test_default_chunk_size(self):
lib = _make_librarian()
lib.table_store.get_document_object_id.return_value = "obj-1"
lib.blob_store.get_size = AsyncMock(return_value=2_000_000)
lib.blob_store.get_range = AsyncMock(return_value=b"x")
req = MagicMock()
req.user = "alice"
req.document_id = "doc-1"
req.chunk_size = 0 # Should use default 1MB
chunks = []
async for resp in lib.stream_document(req):
chunks.append(resp)
assert len(chunks) == 2 # ceil(2MB / 1MB)
@pytest.mark.asyncio
async def test_content_is_base64_encoded(self):
lib = _make_librarian()
lib.table_store.get_document_object_id.return_value = "obj-1"
lib.blob_store.get_size = AsyncMock(return_value=100)
raw = b"hello world"
lib.blob_store.get_range = AsyncMock(return_value=raw)
req = MagicMock()
req.user = "alice"
req.document_id = "doc-1"
req.chunk_size = 1000
chunks = []
async for resp in lib.stream_document(req):
chunks.append(resp)
assert chunks[0].content == base64.b64encode(raw)
@pytest.mark.asyncio
async def test_rejects_chunk_below_minimum(self):
lib = _make_librarian(min_chunk_size=1024)
lib.table_store.get_document_object_id.return_value = "obj-1"
lib.blob_store.get_size = AsyncMock(return_value=5000)
req = MagicMock()
req.user = "alice"
req.document_id = "doc-1"
req.chunk_size = 512
with pytest.raises(RequestError, match="below minimum"):
async for _ in lib.stream_document(req):
pass
# ---------------------------------------------------------------------------
# list_uploads
# ---------------------------------------------------------------------------
class TestListUploads:
@pytest.mark.asyncio
async def test_returns_sessions(self):
lib = _make_librarian()
lib.table_store.list_upload_sessions.return_value = [
{
"upload_id": "up-1",
"document_id": "doc-1",
"document_metadata": '{"id":"doc-1"}',
"total_size": 10000,
"chunk_size": 2000,
"total_chunks": 5,
"chunks_received": {0: "e1", 1: "e2"},
"created_at": "2024-01-01",
},
]
req = MagicMock()
req.user = "alice"
resp = await lib.list_uploads(req)
assert resp.error is None
assert len(resp.upload_sessions) == 1
assert resp.upload_sessions[0].upload_id == "up-1"
assert resp.upload_sessions[0].total_chunks == 5
@pytest.mark.asyncio
async def test_empty_uploads(self):
lib = _make_librarian()
lib.table_store.list_upload_sessions.return_value = []
req = MagicMock()
req.user = "alice"
resp = await lib.list_uploads(req)
assert resp.upload_sessions == []

View file

View file

@ -0,0 +1,324 @@
"""
Tests for agent provenance triple builder functions.
"""
import json
import pytest
from trustgraph.schema import Triple, Term, IRI, LITERAL
from trustgraph.provenance.agent import (
agent_session_triples,
agent_iteration_triples,
agent_final_triples,
)
from trustgraph.provenance.namespaces import (
RDF_TYPE, RDFS_LABEL,
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
TG_AGENT_QUESTION,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def find_triple(triples, predicate, subject=None):
for t in triples:
if t.p.iri == predicate:
if subject is None or t.s.iri == subject:
return t
return None
def find_triples(triples, predicate, subject=None):
return [
t for t in triples
if t.p.iri == predicate and (subject is None or t.s.iri == subject)
]
def has_type(triples, subject, rdf_type):
for t in triples:
if (t.s.iri == subject and t.p.iri == RDF_TYPE
and t.o.type == IRI and t.o.iri == rdf_type):
return True
return False
# ---------------------------------------------------------------------------
# agent_session_triples
# ---------------------------------------------------------------------------
class TestAgentSessionTriples:
SESSION_URI = "urn:trustgraph:agent:test-session"
def test_session_types(self):
triples = agent_session_triples(
self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z"
)
assert has_type(triples, self.SESSION_URI, PROV_ACTIVITY)
assert has_type(triples, self.SESSION_URI, TG_QUESTION)
assert has_type(triples, self.SESSION_URI, TG_AGENT_QUESTION)
def test_session_query_text(self):
triples = agent_session_triples(
self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z"
)
query = find_triple(triples, TG_QUERY, self.SESSION_URI)
assert query is not None
assert query.o.value == "What is X?"
def test_session_timestamp(self):
triples = agent_session_triples(
self.SESSION_URI, "Q", "2024-06-15T10:00:00Z"
)
ts = find_triple(triples, PROV_STARTED_AT_TIME, self.SESSION_URI)
assert ts is not None
assert ts.o.value == "2024-06-15T10:00:00Z"
def test_session_default_timestamp(self):
triples = agent_session_triples(self.SESSION_URI, "Q")
ts = find_triple(triples, PROV_STARTED_AT_TIME, self.SESSION_URI)
assert ts is not None
assert len(ts.o.value) > 0
def test_session_label(self):
triples = agent_session_triples(
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
)
label = find_triple(triples, RDFS_LABEL, self.SESSION_URI)
assert label is not None
assert label.o.value == "Agent Question"
def test_session_triple_count(self):
triples = agent_session_triples(
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
)
assert len(triples) == 6
# ---------------------------------------------------------------------------
# agent_iteration_triples
# ---------------------------------------------------------------------------
class TestAgentIterationTriples:
ITER_URI = "urn:trustgraph:agent:test-session/i1"
PARENT_URI = "urn:trustgraph:agent:test-session"
def test_iteration_types(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
thought="thinking", action="search", observation="found it",
)
assert has_type(triples, self.ITER_URI, PROV_ENTITY)
assert has_type(triples, self.ITER_URI, TG_ANALYSIS)
def test_iteration_derived_from_parent(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
action="search",
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
assert derived is not None
assert derived.o.iri == self.PARENT_URI
def test_iteration_label_includes_action(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
action="graph-rag-query",
)
label = find_triple(triples, RDFS_LABEL, self.ITER_URI)
assert label is not None
assert "graph-rag-query" in label.o.value
def test_iteration_thought_inline(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
thought="I need to search for info",
action="search",
)
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
assert thought is not None
assert thought.o.value == "I need to search for info"
def test_iteration_thought_document_preferred(self):
"""When thought_document_id is provided, inline thought is not stored."""
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
thought="inline thought",
action="search",
thought_document_id="urn:doc:thought-1",
)
thought_doc = find_triple(triples, TG_THOUGHT_DOCUMENT, self.ITER_URI)
assert thought_doc is not None
assert thought_doc.o.iri == "urn:doc:thought-1"
thought_inline = find_triple(triples, TG_THOUGHT, self.ITER_URI)
assert thought_inline is None
def test_iteration_observation_inline(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
action="search",
observation="Found 3 results",
)
obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert obs is not None
assert obs.o.value == "Found 3 results"
def test_iteration_observation_document_preferred(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
action="search",
observation="inline obs",
observation_document_id="urn:doc:obs-1",
)
obs_doc = find_triple(triples, TG_OBSERVATION_DOCUMENT, self.ITER_URI)
assert obs_doc is not None
assert obs_doc.o.iri == "urn:doc:obs-1"
obs_inline = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert obs_inline is None
def test_iteration_action_recorded(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
action="graph-rag-query",
)
action = find_triple(triples, TG_ACTION, self.ITER_URI)
assert action is not None
assert action.o.value == "graph-rag-query"
def test_iteration_arguments_json_encoded(self):
args = {"query": "test query", "limit": 10}
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
action="search",
arguments=args,
)
arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI)
assert arguments is not None
parsed = json.loads(arguments.o.value)
assert parsed == args
def test_iteration_default_arguments_empty_dict(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
action="search",
)
arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI)
assert arguments is not None
parsed = json.loads(arguments.o.value)
assert parsed == {}
def test_iteration_no_thought_or_observation(self):
"""Minimal iteration with just action — no thought or observation triples."""
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
action="noop",
)
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert thought is None
assert obs is None
def test_iteration_chaining(self):
"""Second iteration derives from first iteration, not session."""
iter1_uri = "urn:trustgraph:agent:sess/i1"
iter2_uri = "urn:trustgraph:agent:sess/i2"
triples1 = agent_iteration_triples(
iter1_uri, self.PARENT_URI, action="step1",
)
triples2 = agent_iteration_triples(
iter2_uri, iter1_uri, action="step2",
)
derived1 = find_triple(triples1, PROV_WAS_DERIVED_FROM, iter1_uri)
assert derived1.o.iri == self.PARENT_URI
derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri)
assert derived2.o.iri == iter1_uri
# ---------------------------------------------------------------------------
# agent_final_triples
# ---------------------------------------------------------------------------
class TestAgentFinalTriples:
FINAL_URI = "urn:trustgraph:agent:test-session/final"
PARENT_URI = "urn:trustgraph:agent:test-session/i3"
def test_final_types(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="42"
)
assert has_type(triples, self.FINAL_URI, PROV_ENTITY)
assert has_type(triples, self.FINAL_URI, TG_CONCLUSION)
def test_final_derived_from_parent(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="42"
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived is not None
assert derived.o.iri == self.PARENT_URI
def test_final_label(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="42"
)
label = find_triple(triples, RDFS_LABEL, self.FINAL_URI)
assert label is not None
assert label.o.value == "Conclusion"
def test_final_inline_answer(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="The answer is 42"
)
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
assert answer is not None
assert answer.o.value == "The answer is 42"
def test_final_document_reference(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI,
document_id="urn:trustgraph:agent:sess/answer",
)
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
assert doc is not None
assert doc.o.type == IRI
assert doc.o.iri == "urn:trustgraph:agent:sess/answer"
def test_final_document_takes_precedence(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI,
answer="inline",
document_id="urn:doc:123",
)
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
assert doc is not None
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
assert answer is None
def test_final_no_answer_or_document(self):
triples = agent_final_triples(self.FINAL_URI, self.PARENT_URI)
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
assert answer is None
assert doc is None
def test_final_derives_from_session_when_no_iterations(self):
"""When agent answers immediately, final derives from session."""
session_uri = "urn:trustgraph:agent:test-session"
triples = agent_final_triples(
self.FINAL_URI, session_uri, answer="direct answer"
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived.o.iri == session_uri

View file

@ -0,0 +1,563 @@
"""
Tests for the explainability API (entity parsing, wire format conversion,
and ExplainabilityClient).
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.api.explainability import (
EdgeSelection,
ExplainEntity,
Question,
Exploration,
Focus,
Synthesis,
Analysis,
Conclusion,
parse_edge_selection_triples,
extract_term_value,
wire_triples_to_tuples,
ExplainabilityClient,
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_CONTENT, TG_DOCUMENT, TG_CHUNK_COUNT,
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION,
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
RDF_TYPE, RDFS_LABEL,
)
# ---------------------------------------------------------------------------
# Entity from_triples parsing
# ---------------------------------------------------------------------------
class TestExplainEntityFromTriples:
"""Test ExplainEntity.from_triples dispatches to correct subclass."""
def test_graphrag_question(self):
triples = [
("urn:q:1", RDF_TYPE, TG_QUESTION),
("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
("urn:q:1", TG_QUERY, "What is AI?"),
("urn:q:1", PROV_STARTED_AT_TIME, "2024-01-01T00:00:00Z"),
]
entity = ExplainEntity.from_triples("urn:q:1", triples)
assert isinstance(entity, Question)
assert entity.query == "What is AI?"
assert entity.timestamp == "2024-01-01T00:00:00Z"
assert entity.question_type == "graph-rag"
def test_docrag_question(self):
triples = [
("urn:q:2", RDF_TYPE, TG_QUESTION),
("urn:q:2", RDF_TYPE, TG_DOC_RAG_QUESTION),
("urn:q:2", TG_QUERY, "Find info"),
]
entity = ExplainEntity.from_triples("urn:q:2", triples)
assert isinstance(entity, Question)
assert entity.question_type == "document-rag"
def test_agent_question(self):
triples = [
("urn:q:3", RDF_TYPE, TG_QUESTION),
("urn:q:3", RDF_TYPE, TG_AGENT_QUESTION),
("urn:q:3", TG_QUERY, "Agent query"),
]
entity = ExplainEntity.from_triples("urn:q:3", triples)
assert isinstance(entity, Question)
assert entity.question_type == "agent"
def test_exploration(self):
triples = [
("urn:exp:1", RDF_TYPE, TG_EXPLORATION),
("urn:exp:1", TG_EDGE_COUNT, "15"),
]
entity = ExplainEntity.from_triples("urn:exp:1", triples)
assert isinstance(entity, Exploration)
assert entity.edge_count == 15
def test_exploration_with_chunk_count(self):
triples = [
("urn:exp:2", RDF_TYPE, TG_EXPLORATION),
("urn:exp:2", TG_CHUNK_COUNT, "5"),
]
entity = ExplainEntity.from_triples("urn:exp:2", triples)
assert isinstance(entity, Exploration)
assert entity.chunk_count == 5
def test_exploration_invalid_count(self):
triples = [
("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
("urn:exp:3", TG_EDGE_COUNT, "not-a-number"),
]
entity = ExplainEntity.from_triples("urn:exp:3", triples)
assert isinstance(entity, Exploration)
assert entity.edge_count == 0
def test_focus(self):
triples = [
("urn:foc:1", RDF_TYPE, TG_FOCUS),
("urn:foc:1", TG_SELECTED_EDGE, "urn:edge:1"),
("urn:foc:1", TG_SELECTED_EDGE, "urn:edge:2"),
]
entity = ExplainEntity.from_triples("urn:foc:1", triples)
assert isinstance(entity, Focus)
assert len(entity.selected_edge_uris) == 2
assert "urn:edge:1" in entity.selected_edge_uris
assert "urn:edge:2" in entity.selected_edge_uris
def test_synthesis_with_content(self):
triples = [
("urn:syn:1", RDF_TYPE, TG_SYNTHESIS),
("urn:syn:1", TG_CONTENT, "The answer is 42"),
]
entity = ExplainEntity.from_triples("urn:syn:1", triples)
assert isinstance(entity, Synthesis)
assert entity.content == "The answer is 42"
assert entity.document_uri == ""
def test_synthesis_with_document(self):
triples = [
("urn:syn:2", RDF_TYPE, TG_SYNTHESIS),
("urn:syn:2", TG_DOCUMENT, "urn:doc:answer-1"),
]
entity = ExplainEntity.from_triples("urn:syn:2", triples)
assert isinstance(entity, Synthesis)
assert entity.document_uri == "urn:doc:answer-1"
def test_analysis(self):
triples = [
("urn:ana:1", RDF_TYPE, TG_ANALYSIS),
("urn:ana:1", TG_THOUGHT, "I should search"),
("urn:ana:1", TG_ACTION, "graph-rag-query"),
("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'),
("urn:ana:1", TG_OBSERVATION, "Found results"),
]
entity = ExplainEntity.from_triples("urn:ana:1", triples)
assert isinstance(entity, Analysis)
assert entity.thought == "I should search"
assert entity.action == "graph-rag-query"
assert entity.arguments == '{"query": "test"}'
assert entity.observation == "Found results"
def test_analysis_with_document_refs(self):
triples = [
("urn:ana:2", RDF_TYPE, TG_ANALYSIS),
("urn:ana:2", TG_ACTION, "search"),
("urn:ana:2", TG_THOUGHT_DOCUMENT, "urn:doc:thought-1"),
("urn:ana:2", TG_OBSERVATION_DOCUMENT, "urn:doc:obs-1"),
]
entity = ExplainEntity.from_triples("urn:ana:2", triples)
assert isinstance(entity, Analysis)
assert entity.thought_document_uri == "urn:doc:thought-1"
assert entity.observation_document_uri == "urn:doc:obs-1"
def test_conclusion_with_answer(self):
triples = [
("urn:conc:1", RDF_TYPE, TG_CONCLUSION),
("urn:conc:1", TG_ANSWER, "The final answer"),
]
entity = ExplainEntity.from_triples("urn:conc:1", triples)
assert isinstance(entity, Conclusion)
assert entity.answer == "The final answer"
def test_conclusion_with_document(self):
triples = [
("urn:conc:2", RDF_TYPE, TG_CONCLUSION),
("urn:conc:2", TG_DOCUMENT, "urn:doc:final"),
]
entity = ExplainEntity.from_triples("urn:conc:2", triples)
assert isinstance(entity, Conclusion)
assert entity.document_uri == "urn:doc:final"
def test_unknown_type(self):
triples = [
("urn:x:1", RDF_TYPE, "http://example.com/UnknownType"),
]
entity = ExplainEntity.from_triples("urn:x:1", triples)
assert isinstance(entity, ExplainEntity)
assert entity.entity_type == "unknown"
# ---------------------------------------------------------------------------
# parse_edge_selection_triples
# ---------------------------------------------------------------------------
class TestParseEdgeSelectionTriples:
def test_with_edge_and_reasoning(self):
triples = [
("urn:edge:1", TG_EDGE, {"s": "Alice", "p": "knows", "o": "Bob"}),
("urn:edge:1", TG_REASONING, "Alice and Bob are connected"),
]
result = parse_edge_selection_triples(triples)
assert isinstance(result, EdgeSelection)
assert result.uri == "urn:edge:1"
assert result.edge == {"s": "Alice", "p": "knows", "o": "Bob"}
assert result.reasoning == "Alice and Bob are connected"
def test_with_edge_only(self):
triples = [
("urn:edge:2", TG_EDGE, {"s": "A", "p": "r", "o": "B"}),
]
result = parse_edge_selection_triples(triples)
assert result.edge is not None
assert result.reasoning == ""
def test_with_reasoning_only(self):
triples = [
("urn:edge:3", TG_REASONING, "some reason"),
]
result = parse_edge_selection_triples(triples)
assert result.edge is None
assert result.reasoning == "some reason"
def test_empty_triples(self):
result = parse_edge_selection_triples([])
assert result.uri == ""
assert result.edge is None
assert result.reasoning == ""
def test_edge_must_be_dict(self):
"""Non-dict values for TG_EDGE should not be treated as edges."""
triples = [
("urn:edge:4", TG_EDGE, "not-a-dict"),
]
result = parse_edge_selection_triples(triples)
assert result.edge is None
# ---------------------------------------------------------------------------
# extract_term_value
# ---------------------------------------------------------------------------
class TestExtractTermValue:
def test_iri_short_format(self):
assert extract_term_value({"t": "i", "i": "urn:test"}) == "urn:test"
def test_iri_long_format(self):
assert extract_term_value({"type": "i", "iri": "urn:test"}) == "urn:test"
def test_literal_short_format(self):
assert extract_term_value({"t": "l", "v": "hello"}) == "hello"
def test_literal_long_format(self):
assert extract_term_value({"type": "l", "value": "hello"}) == "hello"
def test_quoted_triple(self):
term = {
"t": "t",
"tr": {
"s": {"t": "i", "i": "urn:s"},
"p": {"t": "i", "i": "urn:p"},
"o": {"t": "i", "i": "urn:o"},
}
}
result = extract_term_value(term)
assert result == {"s": "urn:s", "p": "urn:p", "o": "urn:o"}
def test_quoted_triple_long_format(self):
term = {
"type": "t",
"triple": {
"s": {"type": "i", "iri": "urn:s"},
"p": {"type": "i", "iri": "urn:p"},
"o": {"type": "l", "value": "val"},
}
}
result = extract_term_value(term)
assert result == {"s": "urn:s", "p": "urn:p", "o": "val"}
def test_unknown_type_fallback(self):
result = extract_term_value({"t": "x", "i": "urn:fallback"})
assert result == "urn:fallback"
# ---------------------------------------------------------------------------
# wire_triples_to_tuples
# ---------------------------------------------------------------------------
class TestWireTriplesToTuples:
def test_basic_conversion(self):
wire = [
{
"s": {"t": "i", "i": "urn:s1"},
"p": {"t": "i", "i": "urn:p1"},
"o": {"t": "l", "v": "value1"},
},
]
result = wire_triples_to_tuples(wire)
assert len(result) == 1
assert result[0] == ("urn:s1", "urn:p1", "value1")
def test_multiple_triples(self):
wire = [
{
"s": {"t": "i", "i": "urn:s1"},
"p": {"t": "i", "i": "urn:p1"},
"o": {"t": "l", "v": "v1"},
},
{
"s": {"t": "i", "i": "urn:s2"},
"p": {"t": "i", "i": "urn:p2"},
"o": {"t": "i", "i": "urn:o2"},
},
]
result = wire_triples_to_tuples(wire)
assert len(result) == 2
assert result[0] == ("urn:s1", "urn:p1", "v1")
assert result[1] == ("urn:s2", "urn:p2", "urn:o2")
def test_empty_list(self):
assert wire_triples_to_tuples([]) == []
def test_missing_fields(self):
wire = [{"s": {}, "p": {}, "o": {}}]
result = wire_triples_to_tuples(wire)
assert len(result) == 1
# ---------------------------------------------------------------------------
# ExplainabilityClient
# ---------------------------------------------------------------------------
def _make_wire_triples(tuples):
"""Convert (s, p, o) tuples to wire format for mocking."""
result = []
for s, p, o in tuples:
entry = {
"s": {"t": "i", "i": s},
"p": {"t": "i", "i": p},
}
if o.startswith("urn:") or o.startswith("http"):
entry["o"] = {"t": "i", "i": o}
else:
entry["o"] = {"t": "l", "v": o}
result.append(entry)
return result
class TestExplainabilityClientFetchEntity:
def test_fetch_question_entity(self):
wire = _make_wire_triples([
("urn:q:1", RDF_TYPE, TG_QUESTION),
("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
("urn:q:1", TG_QUERY, "What is AI?"),
("urn:q:1", PROV_STARTED_AT_TIME, "2024-01-01T00:00:00Z"),
])
mock_flow = MagicMock()
# Return same results twice for quiescence
mock_flow.triples_query.side_effect = [wire, wire]
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
entity = client.fetch_entity("urn:q:1", graph="urn:graph:retrieval")
assert isinstance(entity, Question)
assert entity.query == "What is AI?"
assert entity.question_type == "graph-rag"
def test_fetch_returns_none_when_no_data(self):
mock_flow = MagicMock()
mock_flow.triples_query.return_value = []
client = ExplainabilityClient(mock_flow, retry_delay=0.0, max_retries=2)
entity = client.fetch_entity("urn:nonexistent")
assert entity is None
def test_fetch_retries_on_empty_results(self):
wire = _make_wire_triples([
("urn:q:1", RDF_TYPE, TG_QUESTION),
("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
("urn:q:1", TG_QUERY, "Q"),
])
mock_flow = MagicMock()
# Empty, then data, then same data (stable)
mock_flow.triples_query.side_effect = [[], wire, wire]
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
entity = client.fetch_entity("urn:q:1")
assert isinstance(entity, Question)
assert mock_flow.triples_query.call_count == 3
class TestExplainabilityClientResolveLabel:
def test_resolve_label_found(self):
mock_flow = MagicMock()
mock_flow.triples_query.return_value = _make_wire_triples([
("urn:entity:1", RDFS_LABEL, "Entity One"),
])
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
label = client.resolve_label("urn:entity:1")
assert label == "Entity One"
def test_resolve_label_not_found(self):
mock_flow = MagicMock()
mock_flow.triples_query.return_value = []
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
label = client.resolve_label("urn:entity:1")
assert label == "urn:entity:1"
def test_resolve_label_cached(self):
mock_flow = MagicMock()
mock_flow.triples_query.return_value = _make_wire_triples([
("urn:entity:1", RDFS_LABEL, "Entity One"),
])
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
client.resolve_label("urn:entity:1")
client.resolve_label("urn:entity:1")
# Only one query should be made
assert mock_flow.triples_query.call_count == 1
def test_resolve_label_non_uri(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
assert client.resolve_label("plain text") == "plain text"
assert client.resolve_label("") == ""
mock_flow.triples_query.assert_not_called()
def test_resolve_edge_labels(self):
mock_flow = MagicMock()
def mock_query(s=None, p=None, **kwargs):
labels = {
"urn:e:Alice": "Alice",
"urn:r:knows": "knows",
"urn:e:Bob": "Bob",
}
if s in labels:
return _make_wire_triples([(s, RDFS_LABEL, labels[s])])
return []
mock_flow.triples_query.side_effect = mock_query
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
s, p, o = client.resolve_edge_labels(
{"s": "urn:e:Alice", "p": "urn:r:knows", "o": "urn:e:Bob"}
)
assert s == "Alice"
assert p == "knows"
assert o == "Bob"
class TestExplainabilityClientContentFetching:
def test_fetch_synthesis_inline_content(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
synthesis = Synthesis(uri="urn:syn:1", content="inline answer")
result = client.fetch_synthesis_content(synthesis, api=None)
assert result == "inline answer"
def test_fetch_synthesis_truncated_content(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
long_content = "x" * 20000
synthesis = Synthesis(uri="urn:syn:1", content=long_content)
result = client.fetch_synthesis_content(synthesis, api=None, max_content=100)
assert len(result) < 20000
assert result.endswith("... [truncated]")
def test_fetch_synthesis_from_librarian(self):
mock_flow = MagicMock()
mock_api = MagicMock()
mock_library = MagicMock()
mock_api.library.return_value = mock_library
mock_library.get_document_content.return_value = b"librarian content"
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
synthesis = Synthesis(
uri="urn:syn:1",
document_uri="urn:document:abc123"
)
result = client.fetch_synthesis_content(synthesis, api=mock_api)
assert result == "librarian content"
def test_fetch_synthesis_no_content_or_document(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
synthesis = Synthesis(uri="urn:syn:1")
result = client.fetch_synthesis_content(synthesis, api=None)
assert result == ""
def test_fetch_conclusion_inline(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
conclusion = Conclusion(uri="urn:conc:1", answer="42")
result = client.fetch_conclusion_content(conclusion, api=None)
assert result == "42"
def test_fetch_analysis_content_from_librarian(self):
mock_flow = MagicMock()
mock_api = MagicMock()
mock_library = MagicMock()
mock_api.library.return_value = mock_library
mock_library.get_document_content.side_effect = [
b"thought content",
b"observation content",
]
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
analysis = Analysis(
uri="urn:ana:1",
action="search",
thought_document_uri="urn:doc:thought",
observation_document_uri="urn:doc:obs",
)
client.fetch_analysis_content(analysis, api=mock_api)
assert analysis.thought == "thought content"
assert analysis.observation == "observation content"
def test_fetch_analysis_skips_when_inline_exists(self):
mock_flow = MagicMock()
mock_api = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
analysis = Analysis(
uri="urn:ana:1",
action="search",
thought="already have thought",
observation="already have observation",
thought_document_uri="urn:doc:thought",
observation_document_uri="urn:doc:obs",
)
client.fetch_analysis_content(analysis, api=mock_api)
# Should not call librarian since inline content exists
mock_api.library.assert_not_called()
class TestExplainabilityClientDetectSessionType:
def test_detect_agent_from_uri(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
assert client.detect_session_type("urn:trustgraph:agent:abc") == "agent"
def test_detect_graphrag_from_uri(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
assert client.detect_session_type("urn:trustgraph:question:abc") == "graphrag"
def test_detect_docrag_from_uri(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
assert client.detect_session_type("urn:trustgraph:docrag:abc") == "docrag"

View file

@ -0,0 +1,785 @@
"""
Tests for provenance triple builder functions (extraction-time and query-time).
"""
import pytest
from unittest.mock import patch
from trustgraph.schema import Triple, Term, IRI, LITERAL, TRIPLE
from trustgraph.provenance.triples import (
set_graph,
document_triples,
derived_entity_triples,
subgraph_provenance_triples,
question_triples,
exploration_triples,
focus_triples,
synthesis_triples,
docrag_question_triples,
docrag_exploration_triples,
docrag_synthesis_triples,
)
from trustgraph.provenance.namespaces import (
RDF_TYPE, RDFS_LABEL,
PROV_ENTITY, PROV_ACTIVITY, PROV_AGENT,
PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
PROV_USED, PROV_WAS_ASSOCIATED_WITH, PROV_STARTED_AT_TIME,
DC_TITLE, DC_SOURCE, DC_DATE, DC_CREATOR,
TG_PAGE_COUNT, TG_MIME_TYPE, TG_PAGE_NUMBER,
TG_CHUNK_INDEX, TG_CHAR_OFFSET, TG_CHAR_LENGTH,
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS,
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_CONTENT, TG_DOCUMENT,
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
GRAPH_SOURCE, GRAPH_RETRIEVAL,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def find_triple(triples, predicate, subject=None):
"""Find first triple matching predicate (and optionally subject)."""
for t in triples:
if t.p.iri == predicate:
if subject is None or t.s.iri == subject:
return t
return None
def find_triples(triples, predicate, subject=None):
"""Find all triples matching predicate (and optionally subject)."""
return [
t for t in triples
if t.p.iri == predicate and (subject is None or t.s.iri == subject)
]
def has_type(triples, subject, rdf_type):
"""Check if subject has rdf:type rdf_type."""
for t in triples:
if (t.s.iri == subject and t.p.iri == RDF_TYPE
and t.o.type == IRI and t.o.iri == rdf_type):
return True
return False
# ---------------------------------------------------------------------------
# set_graph
# ---------------------------------------------------------------------------
class TestSetGraph:
def test_sets_graph_on_all_triples(self):
triples = [
Triple(
s=Term(type=IRI, iri="urn:s1"),
p=Term(type=IRI, iri="urn:p1"),
o=Term(type=LITERAL, value="v1"),
),
Triple(
s=Term(type=IRI, iri="urn:s2"),
p=Term(type=IRI, iri="urn:p2"),
o=Term(type=LITERAL, value="v2"),
),
]
result = set_graph(triples, GRAPH_RETRIEVAL)
assert len(result) == 2
for t in result:
assert t.g == GRAPH_RETRIEVAL
def test_does_not_modify_originals(self):
original = Triple(
s=Term(type=IRI, iri="urn:s"),
p=Term(type=IRI, iri="urn:p"),
o=Term(type=LITERAL, value="v"),
)
result = set_graph([original], "urn:graph:test")
assert original.g is None
assert result[0].g == "urn:graph:test"
def test_empty_list(self):
result = set_graph([], GRAPH_SOURCE)
assert result == []
def test_preserves_spo(self):
original = Triple(
s=Term(type=IRI, iri="urn:s"),
p=Term(type=IRI, iri="urn:p"),
o=Term(type=LITERAL, value="hello"),
)
result = set_graph([original], "urn:g")[0]
assert result.s.iri == "urn:s"
assert result.p.iri == "urn:p"
assert result.o.value == "hello"
# ---------------------------------------------------------------------------
# document_triples
# ---------------------------------------------------------------------------
class TestDocumentTriples:
DOC_URI = "https://example.com/doc/abc"
def test_minimal_document(self):
triples = document_triples(self.DOC_URI)
assert has_type(triples, self.DOC_URI, PROV_ENTITY)
assert has_type(triples, self.DOC_URI, TG_DOCUMENT_TYPE)
assert len(triples) == 2
def test_with_title(self):
triples = document_triples(self.DOC_URI, title="My Doc")
title_t = find_triple(triples, DC_TITLE)
assert title_t is not None
assert title_t.o.value == "My Doc"
# Title also creates an rdfs:label
label_t = find_triple(triples, RDFS_LABEL)
assert label_t is not None
assert label_t.o.value == "My Doc"
def test_with_source(self):
triples = document_triples(self.DOC_URI, source="https://source.com/f.pdf")
source_t = find_triple(triples, DC_SOURCE)
assert source_t is not None
assert source_t.o.type == IRI
assert source_t.o.iri == "https://source.com/f.pdf"
def test_with_date(self):
triples = document_triples(self.DOC_URI, date="2024-01-15")
date_t = find_triple(triples, DC_DATE)
assert date_t is not None
assert date_t.o.value == "2024-01-15"
def test_with_creator(self):
triples = document_triples(self.DOC_URI, creator="Alice")
creator_t = find_triple(triples, DC_CREATOR)
assert creator_t is not None
assert creator_t.o.value == "Alice"
def test_with_page_count(self):
triples = document_triples(self.DOC_URI, page_count=42)
pc_t = find_triple(triples, TG_PAGE_COUNT)
assert pc_t is not None
assert pc_t.o.value == "42"
def test_with_page_count_zero(self):
triples = document_triples(self.DOC_URI, page_count=0)
pc_t = find_triple(triples, TG_PAGE_COUNT)
assert pc_t is not None
assert pc_t.o.value == "0"
def test_with_mime_type(self):
triples = document_triples(self.DOC_URI, mime_type="application/pdf")
mt_t = find_triple(triples, TG_MIME_TYPE)
assert mt_t is not None
assert mt_t.o.value == "application/pdf"
def test_all_metadata(self):
triples = document_triples(
self.DOC_URI,
title="Test",
source="https://s.com",
date="2024-01-01",
creator="Bob",
page_count=10,
mime_type="application/pdf",
)
# 2 type triples + title + label + source + date + creator + page_count + mime_type
assert len(triples) == 9
def test_subject_is_doc_uri(self):
triples = document_triples(self.DOC_URI, title="T")
for t in triples:
assert t.s.iri == self.DOC_URI
# ---------------------------------------------------------------------------
# derived_entity_triples
# ---------------------------------------------------------------------------
class TestDerivedEntityTriples:
ENTITY_URI = "https://example.com/doc/abc/p1"
PARENT_URI = "https://example.com/doc/abc"
def test_page_entity_has_page_type(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
page_number=1,
timestamp="2024-01-01T00:00:00Z",
)
assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
def test_chunk_entity_has_chunk_type(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"chunker", "1.0",
chunk_index=0,
timestamp="2024-01-01T00:00:00Z",
)
assert has_type(triples, self.ENTITY_URI, TG_CHUNK_TYPE)
def test_no_specific_type_without_page_or_chunk(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"component", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
assert not has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
assert not has_type(triples, self.ENTITY_URI, TG_CHUNK_TYPE)
def test_was_derived_from_parent(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ENTITY_URI)
assert derived is not None
assert derived.o.iri == self.PARENT_URI
def test_activity_created(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
# Entity was generated by an activity
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ENTITY_URI)
assert gen is not None
act_uri = gen.o.iri
# Activity has correct type and metadata
assert has_type(triples, act_uri, PROV_ACTIVITY)
# Activity used the parent
used = find_triple(triples, PROV_USED, act_uri)
assert used is not None
assert used.o.iri == self.PARENT_URI
# Activity has component version
version = find_triple(triples, TG_COMPONENT_VERSION, act_uri)
assert version is not None
assert version.o.value == "1.0"
def test_agent_created(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
# Find the agent URI via wasAssociatedWith
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ENTITY_URI)
act_uri = gen.o.iri
assoc = find_triple(triples, PROV_WAS_ASSOCIATED_WITH, act_uri)
assert assoc is not None
agt_uri = assoc.o.iri
assert has_type(triples, agt_uri, PROV_AGENT)
label = find_triple(triples, RDFS_LABEL, agt_uri)
assert label is not None
assert label.o.value == "pdf-extractor"
def test_timestamp_recorded(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
timestamp="2024-06-15T12:30:00Z",
)
ts = find_triple(triples, PROV_STARTED_AT_TIME)
assert ts is not None
assert ts.o.value == "2024-06-15T12:30:00Z"
def test_default_timestamp_generated(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
)
ts = find_triple(triples, PROV_STARTED_AT_TIME)
assert ts is not None
assert len(ts.o.value) > 0
def test_optional_label(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
label="Page 1",
timestamp="2024-01-01T00:00:00Z",
)
label = find_triple(triples, RDFS_LABEL, self.ENTITY_URI)
assert label is not None
assert label.o.value == "Page 1"
def test_page_number_recorded(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
page_number=3,
timestamp="2024-01-01T00:00:00Z",
)
pn = find_triple(triples, TG_PAGE_NUMBER, self.ENTITY_URI)
assert pn is not None
assert pn.o.value == "3"
def test_chunk_metadata_recorded(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"chunker", "2.0",
chunk_index=5,
char_offset=1000,
char_length=500,
chunk_size=512,
chunk_overlap=64,
timestamp="2024-01-01T00:00:00Z",
)
ci = find_triple(triples, TG_CHUNK_INDEX, self.ENTITY_URI)
assert ci is not None and ci.o.value == "5"
co = find_triple(triples, TG_CHAR_OFFSET, self.ENTITY_URI)
assert co is not None and co.o.value == "1000"
cl = find_triple(triples, TG_CHAR_LENGTH, self.ENTITY_URI)
assert cl is not None and cl.o.value == "500"
# chunk_size and chunk_overlap are on the activity, not the entity
cs = find_triple(triples, TG_CHUNK_SIZE)
assert cs is not None and cs.o.value == "512"
ov = find_triple(triples, TG_CHUNK_OVERLAP)
assert ov is not None and ov.o.value == "64"
# ---------------------------------------------------------------------------
# subgraph_provenance_triples
# ---------------------------------------------------------------------------
class TestSubgraphProvenanceTriples:
SG_URI = "https://trustgraph.ai/subgraph/test-sg"
CHUNK_URI = "https://example.com/doc/abc/p1/c0"
def _make_extracted_triple(self, s="urn:e:Alice", p="urn:r:knows", o="urn:e:Bob"):
return Triple(
s=Term(type=IRI, iri=s),
p=Term(type=IRI, iri=p),
o=Term(type=IRI, iri=o),
)
def test_contains_quoted_triples(self):
extracted = [self._make_extracted_triple()]
triples = subgraph_provenance_triples(
self.SG_URI, extracted, self.CHUNK_URI,
"kg-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
contains = find_triples(triples, TG_CONTAINS, self.SG_URI)
assert len(contains) == 1
assert contains[0].o.type == TRIPLE
assert contains[0].o.triple.s.iri == "urn:e:Alice"
assert contains[0].o.triple.p.iri == "urn:r:knows"
assert contains[0].o.triple.o.iri == "urn:e:Bob"
def test_multiple_extracted_triples(self):
extracted = [
self._make_extracted_triple("urn:e:A", "urn:r:x", "urn:e:B"),
self._make_extracted_triple("urn:e:C", "urn:r:y", "urn:e:D"),
self._make_extracted_triple("urn:e:E", "urn:r:z", "urn:e:F"),
]
triples = subgraph_provenance_triples(
self.SG_URI, extracted, self.CHUNK_URI,
"kg-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
contains = find_triples(triples, TG_CONTAINS, self.SG_URI)
assert len(contains) == 3
def test_empty_extracted_triples(self):
triples = subgraph_provenance_triples(
self.SG_URI, [], self.CHUNK_URI,
"kg-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
contains = find_triples(triples, TG_CONTAINS, self.SG_URI)
assert len(contains) == 0
# Should still have subgraph provenance metadata
assert has_type(triples, self.SG_URI, TG_SUBGRAPH_TYPE)
def test_subgraph_has_correct_types(self):
triples = subgraph_provenance_triples(
self.SG_URI, [], self.CHUNK_URI,
"kg-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
assert has_type(triples, self.SG_URI, PROV_ENTITY)
assert has_type(triples, self.SG_URI, TG_SUBGRAPH_TYPE)
def test_derived_from_chunk(self):
triples = subgraph_provenance_triples(
self.SG_URI, [], self.CHUNK_URI,
"kg-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SG_URI)
assert derived is not None
assert derived.o.iri == self.CHUNK_URI
def test_activity_and_agent(self):
triples = subgraph_provenance_triples(
self.SG_URI, [], self.CHUNK_URI,
"kg-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.SG_URI)
assert gen is not None
act_uri = gen.o.iri
assert has_type(triples, act_uri, PROV_ACTIVITY)
used = find_triple(triples, PROV_USED, act_uri)
assert used is not None
assert used.o.iri == self.CHUNK_URI
version = find_triple(triples, TG_COMPONENT_VERSION, act_uri)
assert version is not None
assert version.o.value == "1.0"
def test_optional_llm_model(self):
triples = subgraph_provenance_triples(
self.SG_URI, [], self.CHUNK_URI,
"kg-extractor", "1.0",
llm_model="claude-3-opus",
timestamp="2024-01-01T00:00:00Z",
)
llm = find_triple(triples, TG_LLM_MODEL)
assert llm is not None
assert llm.o.value == "claude-3-opus"
def test_no_llm_model_when_omitted(self):
triples = subgraph_provenance_triples(
self.SG_URI, [], self.CHUNK_URI,
"kg-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
llm = find_triple(triples, TG_LLM_MODEL)
assert llm is None
def test_optional_ontology(self):
triples = subgraph_provenance_triples(
self.SG_URI, [], self.CHUNK_URI,
"kg-extractor", "1.0",
ontology_uri="https://example.com/ontology/v1",
timestamp="2024-01-01T00:00:00Z",
)
ont = find_triple(triples, TG_ONTOLOGY)
assert ont is not None
assert ont.o.type == IRI
assert ont.o.iri == "https://example.com/ontology/v1"
# ---------------------------------------------------------------------------
# GraphRAG query-time triples
# ---------------------------------------------------------------------------
class TestQuestionTriples:
Q_URI = "urn:trustgraph:question:test-session"
def test_question_types(self):
triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z")
assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
assert has_type(triples, self.Q_URI, TG_QUESTION)
assert has_type(triples, self.Q_URI, TG_GRAPH_RAG_QUESTION)
def test_question_query_text(self):
triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z")
query = find_triple(triples, TG_QUERY, self.Q_URI)
assert query is not None
assert query.o.value == "What is AI?"
def test_question_timestamp(self):
triples = question_triples(self.Q_URI, "Q", "2024-06-15T10:00:00Z")
ts = find_triple(triples, PROV_STARTED_AT_TIME, self.Q_URI)
assert ts is not None
assert ts.o.value == "2024-06-15T10:00:00Z"
def test_question_default_timestamp(self):
triples = question_triples(self.Q_URI, "Q")
ts = find_triple(triples, PROV_STARTED_AT_TIME, self.Q_URI)
assert ts is not None
assert len(ts.o.value) > 0
def test_question_label(self):
triples = question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z")
label = find_triple(triples, RDFS_LABEL, self.Q_URI)
assert label is not None
assert label.o.value == "GraphRAG Question"
def test_question_triple_count(self):
triples = question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z")
assert len(triples) == 6
class TestExplorationTriples:
EXP_URI = "urn:trustgraph:prov:exploration:test-session"
Q_URI = "urn:trustgraph:question:test-session"
def test_exploration_types(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
def test_exploration_generated_by_question(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI)
assert gen is not None
assert gen.o.iri == self.Q_URI
def test_exploration_edge_count(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
assert ec is not None
assert ec.o.value == "15"
def test_exploration_zero_edges(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 0)
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
assert ec is not None
assert ec.o.value == "0"
def test_exploration_triple_count(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 10)
assert len(triples) == 5
class TestFocusTriples:
FOC_URI = "urn:trustgraph:prov:focus:test-session"
EXP_URI = "urn:trustgraph:prov:exploration:test-session"
SESSION_ID = "test-session"
def test_focus_types(self):
triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID)
assert has_type(triples, self.FOC_URI, PROV_ENTITY)
assert has_type(triples, self.FOC_URI, TG_FOCUS)
def test_focus_derived_from_exploration(self):
triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FOC_URI)
assert derived is not None
assert derived.o.iri == self.EXP_URI
def test_focus_no_edges(self):
triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID)
selected = find_triples(triples, TG_SELECTED_EDGE)
assert len(selected) == 0
def test_focus_with_edges_and_reasoning(self):
edges = [
{
"edge": ("urn:e:Alice", "urn:r:knows", "urn:e:Bob"),
"reasoning": "Alice is connected to Bob",
},
{
"edge": ("urn:e:Bob", "urn:r:worksAt", "urn:e:Acme"),
"reasoning": "Bob works at Acme",
},
]
triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
# Two selectedEdge links
selected = find_triples(triples, TG_SELECTED_EDGE, self.FOC_URI)
assert len(selected) == 2
# Each edge selection has a quoted triple
edge_triples = find_triples(triples, TG_EDGE)
assert len(edge_triples) == 2
for et in edge_triples:
assert et.o.type == TRIPLE
# Each edge selection has reasoning
reasoning_triples = find_triples(triples, TG_REASONING)
assert len(reasoning_triples) == 2
def test_focus_edge_without_reasoning(self):
edges = [
{"edge": ("urn:e:A", "urn:r:x", "urn:e:B"), "reasoning": ""},
]
triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
reasoning = find_triples(triples, TG_REASONING)
assert len(reasoning) == 0
def test_focus_edge_without_edge_data(self):
edges = [
{"edge": None, "reasoning": "some reasoning"},
]
triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
selected = find_triples(triples, TG_SELECTED_EDGE)
assert len(selected) == 0
def test_focus_quoted_triple_content(self):
edges = [
{
"edge": ("urn:e:Alice", "urn:r:knows", "urn:e:Bob"),
"reasoning": "test",
},
]
triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
edge_t = find_triple(triples, TG_EDGE)
qt = edge_t.o.triple
assert qt.s.iri == "urn:e:Alice"
assert qt.p.iri == "urn:r:knows"
assert qt.o.iri == "urn:e:Bob"
class TestSynthesisTriples:
SYN_URI = "urn:trustgraph:prov:synthesis:test-session"
FOC_URI = "urn:trustgraph:prov:focus:test-session"
def test_synthesis_types(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
assert has_type(triples, self.SYN_URI, PROV_ENTITY)
assert has_type(triples, self.SYN_URI, TG_SYNTHESIS)
def test_synthesis_derived_from_focus(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI)
assert derived is not None
assert derived.o.iri == self.FOC_URI
def test_synthesis_with_inline_content(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI, answer_text="The answer is 42")
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
assert content is not None
assert content.o.value == "The answer is 42"
def test_synthesis_with_document_reference(self):
triples = synthesis_triples(
self.SYN_URI, self.FOC_URI,
document_id="urn:trustgraph:question:abc/answer",
)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc is not None
assert doc.o.type == IRI
assert doc.o.iri == "urn:trustgraph:question:abc/answer"
def test_synthesis_document_takes_precedence(self):
"""When both document_id and answer_text are provided, document_id wins."""
triples = synthesis_triples(
self.SYN_URI, self.FOC_URI,
answer_text="inline",
document_id="urn:doc:123",
)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc is not None
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
assert content is None
def test_synthesis_no_content_or_document(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert content is None
assert doc is None
# ---------------------------------------------------------------------------
# DocumentRAG query-time triples
# ---------------------------------------------------------------------------
class TestDocRagQuestionTriples:
Q_URI = "urn:trustgraph:docrag:test-session"
def test_docrag_question_types(self):
triples = docrag_question_triples(self.Q_URI, "Find info", "2024-01-01T00:00:00Z")
assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
assert has_type(triples, self.Q_URI, TG_QUESTION)
assert has_type(triples, self.Q_URI, TG_DOC_RAG_QUESTION)
def test_docrag_question_label(self):
triples = docrag_question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z")
label = find_triple(triples, RDFS_LABEL, self.Q_URI)
assert label.o.value == "DocumentRAG Question"
def test_docrag_question_query_text(self):
triples = docrag_question_triples(self.Q_URI, "search query", "2024-01-01T00:00:00Z")
query = find_triple(triples, TG_QUERY, self.Q_URI)
assert query.o.value == "search query"
class TestDocRagExplorationTriples:
EXP_URI = "urn:trustgraph:docrag:test/exploration"
Q_URI = "urn:trustgraph:docrag:test"
def test_docrag_exploration_types(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5)
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
def test_docrag_exploration_generated_by(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI)
assert gen.o.iri == self.Q_URI
def test_docrag_exploration_chunk_count(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 7)
cc = find_triple(triples, TG_CHUNK_COUNT, self.EXP_URI)
assert cc.o.value == "7"
def test_docrag_exploration_without_chunk_ids(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3)
chunks = find_triples(triples, TG_SELECTED_CHUNK)
assert len(chunks) == 0
def test_docrag_exploration_with_chunk_ids(self):
chunk_ids = ["urn:chunk:1", "urn:chunk:2", "urn:chunk:3"]
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3, chunk_ids)
chunks = find_triples(triples, TG_SELECTED_CHUNK, self.EXP_URI)
assert len(chunks) == 3
chunk_uris = {t.o.iri for t in chunks}
assert chunk_uris == set(chunk_ids)
class TestDocRagSynthesisTriples:
SYN_URI = "urn:trustgraph:docrag:test/synthesis"
EXP_URI = "urn:trustgraph:docrag:test/exploration"
def test_docrag_synthesis_types(self):
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
assert has_type(triples, self.SYN_URI, PROV_ENTITY)
assert has_type(triples, self.SYN_URI, TG_SYNTHESIS)
def test_docrag_synthesis_derived_from_exploration(self):
"""DocRAG skips the focus step — synthesis derives from exploration."""
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI)
assert derived.o.iri == self.EXP_URI
def test_docrag_synthesis_with_inline(self):
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI, answer_text="answer")
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
assert content.o.value == "answer"
def test_docrag_synthesis_with_document(self):
triples = docrag_synthesis_triples(
self.SYN_URI, self.EXP_URI, document_id="urn:doc:ans"
)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc.o.iri == "urn:doc:ans"
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
assert content is None

View file

@ -0,0 +1,292 @@
"""
Tests for provenance URI generation functions.
"""
import pytest
from unittest.mock import patch
from trustgraph.provenance.uris import (
TRUSTGRAPH_BASE,
_encode_id,
document_uri,
page_uri,
chunk_uri_from_page,
chunk_uri_from_doc,
activity_uri,
subgraph_uri,
agent_uri,
question_uri,
exploration_uri,
focus_uri,
synthesis_uri,
edge_selection_uri,
agent_session_uri,
agent_iteration_uri,
agent_final_uri,
docrag_question_uri,
docrag_exploration_uri,
docrag_synthesis_uri,
)
class TestEncodeId:
"""Tests for the _encode_id helper."""
def test_plain_string(self):
assert _encode_id("abc123") == "abc123"
def test_string_with_spaces(self):
assert _encode_id("hello world") == "hello%20world"
def test_string_with_slashes(self):
assert _encode_id("a/b/c") == "a%2Fb%2Fc"
def test_integer_input(self):
assert _encode_id(42) == "42"
def test_empty_string(self):
assert _encode_id("") == ""
def test_special_characters(self):
result = _encode_id("name@domain.com")
assert "@" not in result or result == "name%40domain.com"
class TestDocumentUris:
"""Tests for document, page, and chunk URI generation."""
def test_document_uri_passthrough(self):
iri = "https://example.com/doc/123"
assert document_uri(iri) == iri
def test_page_uri_format(self):
result = page_uri("https://example.com/doc/123", 5)
assert result == "https://example.com/doc/123/p5"
def test_page_uri_page_zero(self):
result = page_uri("https://example.com/doc/123", 0)
assert result == "https://example.com/doc/123/p0"
def test_chunk_uri_from_page_format(self):
result = chunk_uri_from_page("https://example.com/doc/123", 2, 3)
assert result == "https://example.com/doc/123/p2/c3"
def test_chunk_uri_from_doc_format(self):
result = chunk_uri_from_doc("https://example.com/doc/123", 7)
assert result == "https://example.com/doc/123/c7"
def test_page_uri_preserves_doc_iri(self):
doc = "urn:isbn:978-3-16-148410-0"
result = page_uri(doc, 1)
assert result.startswith(doc)
def test_chunk_from_page_hierarchy(self):
"""Chunk URI should contain both page and chunk identifiers."""
result = chunk_uri_from_page("https://example.com/doc", 3, 5)
assert "/p3/" in result
assert result.endswith("/c5")
class TestActivityAndSubgraphUris:
"""Tests for activity_uri, subgraph_uri, and agent_uri."""
def test_activity_uri_with_id(self):
result = activity_uri("my-activity-id")
assert result == f"{TRUSTGRAPH_BASE}/activity/my-activity-id"
def test_activity_uri_auto_generates_uuid(self):
result = activity_uri()
assert result.startswith(f"{TRUSTGRAPH_BASE}/activity/")
# UUID part should be non-empty
uuid_part = result.split("/activity/")[1]
assert len(uuid_part) > 0
def test_activity_uri_unique_uuids(self):
r1 = activity_uri()
r2 = activity_uri()
assert r1 != r2
def test_activity_uri_encodes_special_chars(self):
result = activity_uri("id with spaces")
assert "id%20with%20spaces" in result
def test_subgraph_uri_with_id(self):
result = subgraph_uri("sg-123")
assert result == f"{TRUSTGRAPH_BASE}/subgraph/sg-123"
def test_subgraph_uri_auto_generates_uuid(self):
result = subgraph_uri()
assert result.startswith(f"{TRUSTGRAPH_BASE}/subgraph/")
uuid_part = result.split("/subgraph/")[1]
assert len(uuid_part) > 0
def test_subgraph_uri_unique_uuids(self):
r1 = subgraph_uri()
r2 = subgraph_uri()
assert r1 != r2
def test_agent_uri_format(self):
result = agent_uri("pdf-extractor")
assert result == f"{TRUSTGRAPH_BASE}/agent/pdf-extractor"
def test_agent_uri_encodes_special_chars(self):
result = agent_uri("my component")
assert "my%20component" in result
class TestGraphRagQueryUris:
"""Tests for GraphRAG query-time provenance URIs."""
FIXED_UUID = "550e8400-e29b-41d4-a716-446655440000"
def test_question_uri_with_session_id(self):
result = question_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:question:{self.FIXED_UUID}"
def test_question_uri_auto_generates(self):
result = question_uri()
assert result.startswith("urn:trustgraph:question:")
uuid_part = result.split("urn:trustgraph:question:")[1]
assert len(uuid_part) > 0
def test_question_uri_unique(self):
r1 = question_uri()
r2 = question_uri()
assert r1 != r2
def test_exploration_uri_format(self):
result = exploration_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:prov:exploration:{self.FIXED_UUID}"
def test_focus_uri_format(self):
result = focus_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:prov:focus:{self.FIXED_UUID}"
def test_synthesis_uri_format(self):
result = synthesis_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:prov:synthesis:{self.FIXED_UUID}"
def test_edge_selection_uri_format(self):
result = edge_selection_uri(self.FIXED_UUID, 3)
assert result == f"urn:trustgraph:prov:edge:{self.FIXED_UUID}:3"
def test_edge_selection_uri_zero_index(self):
result = edge_selection_uri(self.FIXED_UUID, 0)
assert result.endswith(":0")
def test_session_uris_share_session_id(self):
"""All URIs for a session should contain the same session ID."""
sid = self.FIXED_UUID
q = question_uri(sid)
e = exploration_uri(sid)
f = focus_uri(sid)
s = synthesis_uri(sid)
for uri in [q, e, f, s]:
assert sid in uri
class TestAgentProvenanceUris:
"""Tests for agent provenance URIs."""
FIXED_UUID = "661e8400-e29b-41d4-a716-446655440000"
def test_agent_session_uri_with_id(self):
result = agent_session_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}"
def test_agent_session_uri_auto_generates(self):
result = agent_session_uri()
assert result.startswith("urn:trustgraph:agent:")
def test_agent_session_uri_unique(self):
r1 = agent_session_uri()
r2 = agent_session_uri()
assert r1 != r2
def test_agent_iteration_uri_format(self):
result = agent_iteration_uri(self.FIXED_UUID, 1)
assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}/i1"
def test_agent_iteration_uri_numbering(self):
r1 = agent_iteration_uri(self.FIXED_UUID, 1)
r2 = agent_iteration_uri(self.FIXED_UUID, 2)
assert r1 != r2
assert r1.endswith("/i1")
assert r2.endswith("/i2")
def test_agent_final_uri_format(self):
result = agent_final_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}/final"
def test_agent_uris_share_session_id(self):
sid = self.FIXED_UUID
session = agent_session_uri(sid)
iteration = agent_iteration_uri(sid, 1)
final = agent_final_uri(sid)
for uri in [session, iteration, final]:
assert sid in uri
class TestDocRagProvenanceUris:
"""Tests for Document RAG provenance URIs."""
FIXED_UUID = "772e8400-e29b-41d4-a716-446655440000"
def test_docrag_question_uri_with_id(self):
result = docrag_question_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}"
def test_docrag_question_uri_auto_generates(self):
result = docrag_question_uri()
assert result.startswith("urn:trustgraph:docrag:")
def test_docrag_question_uri_unique(self):
r1 = docrag_question_uri()
r2 = docrag_question_uri()
assert r1 != r2
def test_docrag_exploration_uri_format(self):
result = docrag_exploration_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}/exploration"
def test_docrag_synthesis_uri_format(self):
result = docrag_synthesis_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}/synthesis"
def test_docrag_uris_share_session_id(self):
sid = self.FIXED_UUID
q = docrag_question_uri(sid)
e = docrag_exploration_uri(sid)
s = docrag_synthesis_uri(sid)
for uri in [q, e, s]:
assert sid in uri
class TestUriNamespaceIsolation:
"""Verify that different provenance types use distinct URI namespaces."""
FIXED_UUID = "883e8400-e29b-41d4-a716-446655440000"
def test_graphrag_vs_agent_namespace(self):
graphrag = question_uri(self.FIXED_UUID)
agent = agent_session_uri(self.FIXED_UUID)
assert graphrag != agent
assert "question" in graphrag
assert "agent" in agent
def test_graphrag_vs_docrag_namespace(self):
graphrag = question_uri(self.FIXED_UUID)
docrag = docrag_question_uri(self.FIXED_UUID)
assert graphrag != docrag
def test_agent_vs_docrag_namespace(self):
agent = agent_session_uri(self.FIXED_UUID)
docrag = docrag_question_uri(self.FIXED_UUID)
assert agent != docrag
def test_extraction_vs_query_namespace(self):
"""Extraction URIs use https://, query URIs use urn:."""
ext = activity_uri(self.FIXED_UUID)
query = question_uri(self.FIXED_UUID)
assert ext.startswith("https://")
assert query.startswith("urn:")

View file

@ -0,0 +1,124 @@
"""
Tests for provenance vocabulary bootstrap.
"""
import pytest
from trustgraph.schema import Triple, Term, IRI, LITERAL
from trustgraph.provenance.vocabulary import (
get_vocabulary_triples,
PROV_CLASS_LABELS,
PROV_PREDICATE_LABELS,
DC_PREDICATE_LABELS,
SCHEMA_LABELS,
SKOS_LABELS,
TG_CLASS_LABELS,
TG_PREDICATE_LABELS,
)
from trustgraph.provenance.namespaces import (
RDFS_LABEL,
PROV_ENTITY, PROV_ACTIVITY, PROV_AGENT,
PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
PROV_USED, PROV_WAS_ASSOCIATED_WITH, PROV_STARTED_AT_TIME,
DC_TITLE, DC_SOURCE, DC_DATE, DC_CREATOR,
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
)
class TestVocabularyTriples:
"""Tests for the vocabulary bootstrap function."""
def test_returns_list_of_triples(self):
result = get_vocabulary_triples()
assert isinstance(result, list)
assert len(result) > 0
for t in result:
assert isinstance(t, Triple)
def test_all_triples_are_label_triples(self):
"""Every vocabulary triple should use rdfs:label as predicate."""
for t in get_vocabulary_triples():
assert t.p.type == IRI
assert t.p.iri == RDFS_LABEL
def test_all_subjects_are_iris(self):
for t in get_vocabulary_triples():
assert t.s.type == IRI
assert len(t.s.iri) > 0
def test_all_objects_are_literals(self):
for t in get_vocabulary_triples():
assert t.o.type == LITERAL
assert len(t.o.value) > 0
def test_no_duplicate_subjects(self):
subjects = [t.s.iri for t in get_vocabulary_triples()]
assert len(subjects) == len(set(subjects))
def test_includes_prov_classes(self):
subjects = {t.s.iri for t in get_vocabulary_triples()}
assert PROV_ENTITY in subjects
assert PROV_ACTIVITY in subjects
assert PROV_AGENT in subjects
def test_includes_prov_predicates(self):
subjects = {t.s.iri for t in get_vocabulary_triples()}
assert PROV_WAS_DERIVED_FROM in subjects
assert PROV_WAS_GENERATED_BY in subjects
assert PROV_USED in subjects
assert PROV_WAS_ASSOCIATED_WITH in subjects
assert PROV_STARTED_AT_TIME in subjects
def test_includes_dc_predicates(self):
subjects = {t.s.iri for t in get_vocabulary_triples()}
assert DC_TITLE in subjects
assert DC_SOURCE in subjects
assert DC_DATE in subjects
assert DC_CREATOR in subjects
def test_includes_tg_classes(self):
subjects = {t.s.iri for t in get_vocabulary_triples()}
assert TG_DOCUMENT_TYPE in subjects
assert TG_PAGE_TYPE in subjects
assert TG_CHUNK_TYPE in subjects
assert TG_SUBGRAPH_TYPE in subjects
def test_component_lists_sum_to_total(self):
total = get_vocabulary_triples()
components = (
PROV_CLASS_LABELS +
PROV_PREDICATE_LABELS +
DC_PREDICATE_LABELS +
SCHEMA_LABELS +
SKOS_LABELS +
TG_CLASS_LABELS +
TG_PREDICATE_LABELS
)
assert len(total) == len(components)
def test_idempotent(self):
"""Calling twice should return equivalent triples."""
r1 = get_vocabulary_triples()
r2 = get_vocabulary_triples()
assert len(r1) == len(r2)
for t1, t2 in zip(r1, r2):
assert t1.s.iri == t2.s.iri
assert t1.o.value == t2.o.value
class TestNamespaceConstants:
"""Verify namespace constants are well-formed IRIs."""
def test_prov_namespace_prefix(self):
assert PROV_ENTITY.startswith("http://www.w3.org/ns/prov#")
def test_dc_namespace_prefix(self):
assert DC_TITLE.startswith("http://purl.org/dc/elements/1.1/")
def test_tg_namespace_prefix(self):
assert TG_DOCUMENT_TYPE.startswith("https://trustgraph.ai/ns/")
def test_rdfs_label_iri(self):
assert RDFS_LABEL == "http://www.w3.org/2000/01/rdf-schema#label"

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,309 @@
"""
Tests for RDF 1.2 type system primitives: Term dataclass (IRI, blank node,
typed literal, language-tagged literal, quoted triple), Triple/Quad dataclass
with named graph support, and the knowledge/defs helper types.
"""
import pytest
from trustgraph.schema import Term, Triple, IRI, BLANK, LITERAL, TRIPLE
# ---------------------------------------------------------------------------
# Type constants
# ---------------------------------------------------------------------------
class TestTypeConstants:
def test_iri_constant(self):
assert IRI == "i"
def test_blank_constant(self):
assert BLANK == "b"
def test_literal_constant(self):
assert LITERAL == "l"
def test_triple_constant(self):
assert TRIPLE == "t"
def test_constants_are_distinct(self):
vals = {IRI, BLANK, LITERAL, TRIPLE}
assert len(vals) == 4
# ---------------------------------------------------------------------------
# IRI terms
# ---------------------------------------------------------------------------
class TestIriTerm:
def test_create_iri(self):
t = Term(type=IRI, iri="http://example.org/Alice")
assert t.type == IRI
assert t.iri == "http://example.org/Alice"
def test_iri_defaults_empty(self):
t = Term(type=IRI)
assert t.iri == ""
def test_iri_with_fragment(self):
t = Term(type=IRI, iri="http://example.org/ontology#Person")
assert "#Person" in t.iri
def test_iri_with_unicode(self):
t = Term(type=IRI, iri="http://example.org/概念")
assert "概念" in t.iri
def test_iri_other_fields_default(self):
t = Term(type=IRI, iri="http://example.org/x")
assert t.id == ""
assert t.value == ""
assert t.datatype == ""
assert t.language == ""
assert t.triple is None
# ---------------------------------------------------------------------------
# Blank node terms
# ---------------------------------------------------------------------------
class TestBlankNodeTerm:
def test_create_blank_node(self):
t = Term(type=BLANK, id="_:b0")
assert t.type == BLANK
assert t.id == "_:b0"
def test_blank_node_defaults_empty(self):
t = Term(type=BLANK)
assert t.id == ""
def test_blank_node_arbitrary_id(self):
t = Term(type=BLANK, id="node-abc-123")
assert t.id == "node-abc-123"
# ---------------------------------------------------------------------------
# Typed literals (XSD datatypes)
# ---------------------------------------------------------------------------
class TestTypedLiteral:
def test_plain_literal(self):
t = Term(type=LITERAL, value="hello")
assert t.type == LITERAL
assert t.value == "hello"
assert t.datatype == ""
assert t.language == ""
def test_xsd_integer(self):
t = Term(
type=LITERAL, value="42",
datatype="http://www.w3.org/2001/XMLSchema#integer",
)
assert t.value == "42"
assert "integer" in t.datatype
def test_xsd_boolean(self):
t = Term(
type=LITERAL, value="true",
datatype="http://www.w3.org/2001/XMLSchema#boolean",
)
assert t.datatype.endswith("#boolean")
def test_xsd_date(self):
t = Term(
type=LITERAL, value="2026-03-13",
datatype="http://www.w3.org/2001/XMLSchema#date",
)
assert t.value == "2026-03-13"
assert t.datatype.endswith("#date")
def test_xsd_double(self):
t = Term(
type=LITERAL, value="3.14",
datatype="http://www.w3.org/2001/XMLSchema#double",
)
assert t.datatype.endswith("#double")
def test_empty_value_literal(self):
t = Term(type=LITERAL, value="")
assert t.value == ""
# ---------------------------------------------------------------------------
# Language-tagged literals
# ---------------------------------------------------------------------------
class TestLanguageTaggedLiteral:
def test_english_tag(self):
t = Term(type=LITERAL, value="hello", language="en")
assert t.language == "en"
assert t.datatype == ""
def test_french_tag(self):
t = Term(type=LITERAL, value="bonjour", language="fr")
assert t.language == "fr"
def test_bcp47_subtag(self):
t = Term(type=LITERAL, value="colour", language="en-GB")
assert t.language == "en-GB"
def test_language_and_datatype_mutually_exclusive(self):
"""Both can be set on the dataclass, but semantically only one should be used."""
t = Term(type=LITERAL, value="x", language="en",
datatype="http://www.w3.org/2001/XMLSchema#string")
# Dataclass allows both — translators should respect mutual exclusivity
assert t.language == "en"
assert t.datatype != ""
# ---------------------------------------------------------------------------
# Quoted triples (RDF-star)
# ---------------------------------------------------------------------------
class TestQuotedTriple:
def test_term_with_nested_triple(self):
inner = Triple(
s=Term(type=IRI, iri="http://example.org/Alice"),
p=Term(type=IRI, iri="http://xmlns.com/foaf/0.1/knows"),
o=Term(type=IRI, iri="http://example.org/Bob"),
)
qt = Term(type=TRIPLE, triple=inner)
assert qt.type == TRIPLE
assert qt.triple is inner
assert qt.triple.s.iri == "http://example.org/Alice"
def test_quoted_triple_as_object(self):
"""A triple whose object is a quoted triple (RDF-star)."""
inner = Triple(
s=Term(type=IRI, iri="http://example.org/Hope"),
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
o=Term(type=LITERAL, value="A feeling of expectation"),
)
outer = Triple(
s=Term(type=IRI, iri="urn:subgraph:123"),
p=Term(type=IRI, iri="http://trustgraph.ai/tg/contains"),
o=Term(type=TRIPLE, triple=inner),
)
assert outer.o.type == TRIPLE
assert outer.o.triple.o.value == "A feeling of expectation"
def test_quoted_triple_none(self):
t = Term(type=TRIPLE, triple=None)
assert t.triple is None
# ---------------------------------------------------------------------------
# Triple / Quad (named graph)
# ---------------------------------------------------------------------------
class TestTripleQuad:
def test_default_graph_is_none(self):
t = Triple(
s=Term(type=IRI, iri="http://example.org/s"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="val"),
)
assert t.g is None
def test_named_graph(self):
t = Triple(
s=Term(type=IRI, iri="http://example.org/s"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="val"),
g="urn:graph:source",
)
assert t.g == "urn:graph:source"
def test_empty_string_graph(self):
t = Triple(g="")
assert t.g == ""
def test_triple_with_all_none_terms(self):
t = Triple()
assert t.s is None
assert t.p is None
assert t.o is None
assert t.g is None
def test_triple_equality(self):
"""Dataclass equality based on field values."""
t1 = Triple(
s=Term(type=IRI, iri="http://example.org/A"),
p=Term(type=IRI, iri="http://example.org/B"),
o=Term(type=LITERAL, value="C"),
)
t2 = Triple(
s=Term(type=IRI, iri="http://example.org/A"),
p=Term(type=IRI, iri="http://example.org/B"),
o=Term(type=LITERAL, value="C"),
)
assert t1 == t2
# ---------------------------------------------------------------------------
# knowledge/defs helper types
# ---------------------------------------------------------------------------
class TestKnowledgeDefs:
def test_uri_type(self):
from trustgraph.knowledge.defs import Uri
u = Uri("http://example.org/x")
assert u.is_uri() is True
assert u.is_literal() is False
assert u.is_triple() is False
assert str(u) == "http://example.org/x"
def test_literal_type(self):
from trustgraph.knowledge.defs import Literal
l = Literal("hello world")
assert l.is_uri() is False
assert l.is_literal() is True
assert l.is_triple() is False
assert str(l) == "hello world"
def test_quoted_triple_type(self):
from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal
qt = QuotedTriple(
s=Uri("http://example.org/s"),
p=Uri("http://example.org/p"),
o=Literal("val"),
)
assert qt.is_uri() is False
assert qt.is_literal() is False
assert qt.is_triple() is True
assert qt.s == "http://example.org/s"
assert qt.o == "val"
def test_quoted_triple_repr(self):
from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal
qt = QuotedTriple(
s=Uri("http://example.org/A"),
p=Uri("http://example.org/B"),
o=Literal("C"),
)
r = repr(qt)
assert "<<" in r
assert ">>" in r
assert "http://example.org/A" in r
def test_quoted_triple_nested(self):
"""QuotedTriple can contain another QuotedTriple as object."""
from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal
inner = QuotedTriple(
s=Uri("http://example.org/s"),
p=Uri("http://example.org/p"),
o=Literal("v"),
)
outer = QuotedTriple(
s=Uri("http://example.org/s2"),
p=Uri("http://example.org/p2"),
o=inner,
)
assert outer.o.is_triple() is True

View file

@ -0,0 +1,217 @@
"""
Tests for RDF storage helper functions used by the Cassandra triple writer:
serialize_triple, get_term_value, get_term_otype, get_term_dtype, get_term_lang.
"""
import json
import pytest
from trustgraph.schema import Term, Triple, IRI, BLANK, LITERAL, TRIPLE
from trustgraph.storage.triples.cassandra.write import (
serialize_triple,
get_term_value,
get_term_otype,
get_term_dtype,
get_term_lang,
)
# ---------------------------------------------------------------------------
# get_term_otype — maps Term.type to storage object type code
# ---------------------------------------------------------------------------
class TestGetTermOtype:
def test_iri_maps_to_u(self):
assert get_term_otype(Term(type=IRI, iri="http://x")) == "u"
def test_blank_maps_to_u(self):
assert get_term_otype(Term(type=BLANK, id="_:b0")) == "u"
def test_literal_maps_to_l(self):
assert get_term_otype(Term(type=LITERAL, value="x")) == "l"
def test_triple_maps_to_t(self):
assert get_term_otype(Term(type=TRIPLE)) == "t"
def test_none_defaults_to_u(self):
assert get_term_otype(None) == "u"
def test_unknown_type_defaults_to_u(self):
assert get_term_otype(Term(type="z")) == "u"
# ---------------------------------------------------------------------------
# get_term_dtype — extracts XSD datatype from literals
# ---------------------------------------------------------------------------
class TestGetTermDtype:
def test_literal_with_datatype(self):
t = Term(type=LITERAL, value="42",
datatype="http://www.w3.org/2001/XMLSchema#integer")
assert get_term_dtype(t) == "http://www.w3.org/2001/XMLSchema#integer"
def test_literal_without_datatype(self):
t = Term(type=LITERAL, value="hello")
assert get_term_dtype(t) == ""
def test_iri_returns_empty(self):
assert get_term_dtype(Term(type=IRI, iri="http://x")) == ""
def test_none_returns_empty(self):
assert get_term_dtype(None) == ""
# ---------------------------------------------------------------------------
# get_term_lang — extracts language tag from literals
# ---------------------------------------------------------------------------
class TestGetTermLang:
def test_literal_with_language(self):
t = Term(type=LITERAL, value="bonjour", language="fr")
assert get_term_lang(t) == "fr"
def test_literal_without_language(self):
t = Term(type=LITERAL, value="hello")
assert get_term_lang(t) == ""
def test_iri_returns_empty(self):
assert get_term_lang(Term(type=IRI, iri="http://x")) == ""
def test_none_returns_empty(self):
assert get_term_lang(None) == ""
def test_bcp47_subtag_preserved(self):
t = Term(type=LITERAL, value="colour", language="en-GB")
assert get_term_lang(t) == "en-GB"
# ---------------------------------------------------------------------------
# get_term_value — extracts string value from any Term
# ---------------------------------------------------------------------------
class TestGetTermValue:
def test_iri_returns_iri(self):
t = Term(type=IRI, iri="http://example.org/Alice")
assert get_term_value(t) == "http://example.org/Alice"
def test_literal_returns_value(self):
t = Term(type=LITERAL, value="hello")
assert get_term_value(t) == "hello"
def test_blank_returns_id(self):
t = Term(type=BLANK, id="_:b0")
assert get_term_value(t) == "_:b0"
def test_none_returns_none(self):
assert get_term_value(None) is None
def test_triple_returns_serialized_json(self):
inner = Triple(
s=Term(type=IRI, iri="http://example.org/s"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="val"),
)
t = Term(type=TRIPLE, triple=inner)
result = get_term_value(t)
parsed = json.loads(result)
assert parsed["s"]["type"] == "i"
assert parsed["s"]["iri"] == "http://example.org/s"
assert parsed["o"]["value"] == "val"
# ---------------------------------------------------------------------------
# serialize_triple — full Triple → JSON serialization
# ---------------------------------------------------------------------------
class TestSerializeTriple:
def test_serialize_iri_triple(self):
t = Triple(
s=Term(type=IRI, iri="http://example.org/A"),
p=Term(type=IRI, iri="http://example.org/rel"),
o=Term(type=IRI, iri="http://example.org/B"),
)
result = json.loads(serialize_triple(t))
assert result["s"]["type"] == "i"
assert result["s"]["iri"] == "http://example.org/A"
assert result["p"]["iri"] == "http://example.org/rel"
assert result["o"]["type"] == "i"
def test_serialize_literal_object(self):
t = Triple(
s=Term(type=IRI, iri="http://example.org/s"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="hello"),
)
result = json.loads(serialize_triple(t))
assert result["o"]["type"] == "l"
assert result["o"]["value"] == "hello"
def test_serialize_typed_literal(self):
t = Triple(
s=Term(type=IRI, iri="http://example.org/s"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="42",
datatype="http://www.w3.org/2001/XMLSchema#integer"),
)
result = json.loads(serialize_triple(t))
assert result["o"]["datatype"] == "http://www.w3.org/2001/XMLSchema#integer"
def test_serialize_language_tagged_literal(self):
t = Triple(
s=Term(type=IRI, iri="http://example.org/s"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="bonjour", language="fr"),
)
result = json.loads(serialize_triple(t))
assert result["o"]["language"] == "fr"
def test_serialize_blank_node(self):
t = Triple(
s=Term(type=BLANK, id="_:b0"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="v"),
)
result = json.loads(serialize_triple(t))
assert result["s"]["type"] == "b"
assert result["s"]["id"] == "_:b0"
def test_serialize_nested_quoted_triple(self):
inner = Triple(
s=Term(type=IRI, iri="http://example.org/inner-s"),
p=Term(type=IRI, iri="http://example.org/inner-p"),
o=Term(type=LITERAL, value="inner-val"),
)
outer = Triple(
s=Term(type=IRI, iri="http://example.org/outer-s"),
p=Term(type=IRI, iri="http://example.org/outer-p"),
o=Term(type=TRIPLE, triple=inner),
)
result = json.loads(serialize_triple(outer))
nested = json.loads(result["o"]["triple"])
assert nested["s"]["iri"] == "http://example.org/inner-s"
assert nested["o"]["value"] == "inner-val"
def test_serialize_none_returns_none(self):
assert serialize_triple(None) is None
def test_serialize_none_terms(self):
t = Triple(s=None, p=None, o=None)
result = json.loads(serialize_triple(t))
assert result["s"] is None
assert result["p"] is None
assert result["o"] is None
def test_serialize_plain_literal_omits_datatype_and_language(self):
t = Triple(
s=Term(type=IRI, iri="http://example.org/s"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="plain"),
)
result = json.loads(serialize_triple(t))
assert "datatype" not in result["o"]
assert "language" not in result["o"]

View file

@ -0,0 +1,357 @@
"""
Tests for RDF wire format translators: TermTranslator and TripleTranslator
round-trip encoding for all RDF 1.2 term types (IRI, blank node, typed literal,
language-tagged literal, quoted triple) and named graph quads.
"""
import pytest
from trustgraph.schema import Term, Triple, IRI, BLANK, LITERAL, TRIPLE
from trustgraph.messaging.translators.primitives import (
TermTranslator, TripleTranslator, SubgraphTranslator,
)
@pytest.fixture
def term_tx():
return TermTranslator()
@pytest.fixture
def triple_tx():
return TripleTranslator()
# ---------------------------------------------------------------------------
# TermTranslator — IRI
# ---------------------------------------------------------------------------
class TestTermTranslatorIri:
def test_iri_to_pulsar(self, term_tx):
data = {"t": "i", "i": "http://example.org/Alice"}
term = term_tx.to_pulsar(data)
assert term.type == IRI
assert term.iri == "http://example.org/Alice"
def test_iri_from_pulsar(self, term_tx):
term = Term(type=IRI, iri="http://example.org/Bob")
wire = term_tx.from_pulsar(term)
assert wire == {"t": "i", "i": "http://example.org/Bob"}
def test_iri_round_trip(self, term_tx):
original = Term(type=IRI, iri="http://example.org/round")
wire = term_tx.from_pulsar(original)
restored = term_tx.to_pulsar(wire)
assert restored == original
# ---------------------------------------------------------------------------
# TermTranslator — Blank node
# ---------------------------------------------------------------------------
class TestTermTranslatorBlank:
def test_blank_to_pulsar(self, term_tx):
data = {"t": "b", "d": "_:b42"}
term = term_tx.to_pulsar(data)
assert term.type == BLANK
assert term.id == "_:b42"
def test_blank_from_pulsar(self, term_tx):
term = Term(type=BLANK, id="_:node1")
wire = term_tx.from_pulsar(term)
assert wire == {"t": "b", "d": "_:node1"}
def test_blank_round_trip(self, term_tx):
original = Term(type=BLANK, id="_:x")
wire = term_tx.from_pulsar(original)
restored = term_tx.to_pulsar(wire)
assert restored == original
# ---------------------------------------------------------------------------
# TermTranslator — Typed literal (XSD)
# ---------------------------------------------------------------------------
class TestTermTranslatorTypedLiteral:
def test_plain_literal_to_pulsar(self, term_tx):
data = {"t": "l", "v": "hello"}
term = term_tx.to_pulsar(data)
assert term.type == LITERAL
assert term.value == "hello"
assert term.datatype == ""
assert term.language == ""
def test_xsd_integer_to_pulsar(self, term_tx):
data = {
"t": "l", "v": "42",
"dt": "http://www.w3.org/2001/XMLSchema#integer",
}
term = term_tx.to_pulsar(data)
assert term.value == "42"
assert term.datatype.endswith("#integer")
def test_typed_literal_from_pulsar(self, term_tx):
term = Term(
type=LITERAL, value="3.14",
datatype="http://www.w3.org/2001/XMLSchema#double",
)
wire = term_tx.from_pulsar(term)
assert wire["t"] == "l"
assert wire["v"] == "3.14"
assert wire["dt"] == "http://www.w3.org/2001/XMLSchema#double"
assert "ln" not in wire # No language tag
def test_typed_literal_round_trip(self, term_tx):
original = Term(
type=LITERAL, value="true",
datatype="http://www.w3.org/2001/XMLSchema#boolean",
)
wire = term_tx.from_pulsar(original)
restored = term_tx.to_pulsar(wire)
assert restored == original
def test_plain_literal_omits_dt_and_ln(self, term_tx):
term = Term(type=LITERAL, value="x")
wire = term_tx.from_pulsar(term)
assert "dt" not in wire
assert "ln" not in wire
# ---------------------------------------------------------------------------
# TermTranslator — Language-tagged literal
# ---------------------------------------------------------------------------
class TestTermTranslatorLangLiteral:
def test_language_tag_to_pulsar(self, term_tx):
data = {"t": "l", "v": "bonjour", "ln": "fr"}
term = term_tx.to_pulsar(data)
assert term.value == "bonjour"
assert term.language == "fr"
def test_language_tag_from_pulsar(self, term_tx):
term = Term(type=LITERAL, value="colour", language="en-GB")
wire = term_tx.from_pulsar(term)
assert wire["ln"] == "en-GB"
assert "dt" not in wire # No datatype
def test_language_tag_round_trip(self, term_tx):
original = Term(type=LITERAL, value="hola", language="es")
wire = term_tx.from_pulsar(original)
restored = term_tx.to_pulsar(wire)
assert restored == original
# ---------------------------------------------------------------------------
# TermTranslator — Quoted triple (RDF-star)
# ---------------------------------------------------------------------------
class TestTermTranslatorQuotedTriple:
def test_quoted_triple_to_pulsar(self, term_tx):
data = {
"t": "t",
"tr": {
"s": {"t": "i", "i": "http://example.org/Alice"},
"p": {"t": "i", "i": "http://xmlns.com/foaf/0.1/knows"},
"o": {"t": "i", "i": "http://example.org/Bob"},
},
}
term = term_tx.to_pulsar(data)
assert term.type == TRIPLE
assert term.triple is not None
assert term.triple.s.iri == "http://example.org/Alice"
assert term.triple.o.iri == "http://example.org/Bob"
def test_quoted_triple_from_pulsar(self, term_tx):
inner = Triple(
s=Term(type=IRI, iri="http://example.org/s"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="val"),
)
term = Term(type=TRIPLE, triple=inner)
wire = term_tx.from_pulsar(term)
assert wire["t"] == "t"
assert "tr" in wire
assert wire["tr"]["s"]["i"] == "http://example.org/s"
assert wire["tr"]["o"]["v"] == "val"
def test_quoted_triple_round_trip(self, term_tx):
inner = Triple(
s=Term(type=IRI, iri="http://example.org/A"),
p=Term(type=IRI, iri="http://example.org/B"),
o=Term(type=LITERAL, value="C", language="en"),
)
original = Term(type=TRIPLE, triple=inner)
wire = term_tx.from_pulsar(original)
restored = term_tx.to_pulsar(wire)
assert restored.type == TRIPLE
assert restored.triple.s == original.triple.s
assert restored.triple.o == original.triple.o
def test_quoted_triple_none_triple(self, term_tx):
term = Term(type=TRIPLE, triple=None)
wire = term_tx.from_pulsar(term)
assert wire == {"t": "t"}
# And back
restored = term_tx.to_pulsar(wire)
assert restored.type == TRIPLE
assert restored.triple is None
def test_quoted_triple_with_literal_object(self, term_tx):
data = {
"t": "t",
"tr": {
"s": {"t": "i", "i": "http://example.org/Hope"},
"p": {"t": "i", "i": "http://www.w3.org/2004/02/skos/core#definition"},
"o": {"t": "l", "v": "A feeling of expectation"},
},
}
term = term_tx.to_pulsar(data)
assert term.triple.o.type == LITERAL
assert term.triple.o.value == "A feeling of expectation"
# ---------------------------------------------------------------------------
# TermTranslator — Edge cases
# ---------------------------------------------------------------------------
class TestTermTranslatorEdgeCases:
def test_unknown_type(self, term_tx):
data = {"t": "z"}
term = term_tx.to_pulsar(data)
assert term.type == "z"
def test_empty_type(self, term_tx):
data = {}
term = term_tx.to_pulsar(data)
assert term.type == ""
def test_missing_iri_field(self, term_tx):
data = {"t": "i"}
term = term_tx.to_pulsar(data)
assert term.iri == ""
def test_missing_literal_fields(self, term_tx):
data = {"t": "l"}
term = term_tx.to_pulsar(data)
assert term.value == ""
assert term.datatype == ""
assert term.language == ""
# ---------------------------------------------------------------------------
# TripleTranslator
# ---------------------------------------------------------------------------
class TestTripleTranslator:
def test_triple_to_pulsar(self, triple_tx):
data = {
"s": {"t": "i", "i": "http://example.org/s"},
"p": {"t": "i", "i": "http://example.org/p"},
"o": {"t": "l", "v": "object"},
}
triple = triple_tx.to_pulsar(data)
assert triple.s.iri == "http://example.org/s"
assert triple.o.value == "object"
assert triple.g is None
def test_triple_from_pulsar(self, triple_tx):
triple = Triple(
s=Term(type=IRI, iri="http://example.org/A"),
p=Term(type=IRI, iri="http://example.org/B"),
o=Term(type=LITERAL, value="C"),
)
wire = triple_tx.from_pulsar(triple)
assert wire["s"]["t"] == "i"
assert wire["o"]["v"] == "C"
assert "g" not in wire
def test_quad_with_named_graph(self, triple_tx):
data = {
"s": {"t": "i", "i": "http://example.org/s"},
"p": {"t": "i", "i": "http://example.org/p"},
"o": {"t": "l", "v": "val"},
"g": "urn:graph:source",
}
quad = triple_tx.to_pulsar(data)
assert quad.g == "urn:graph:source"
def test_quad_from_pulsar_includes_graph(self, triple_tx):
quad = Triple(
s=Term(type=IRI, iri="http://example.org/s"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="v"),
g="urn:graph:retrieval",
)
wire = triple_tx.from_pulsar(quad)
assert wire["g"] == "urn:graph:retrieval"
def test_quad_round_trip(self, triple_tx):
original = Triple(
s=Term(type=IRI, iri="http://example.org/s"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="v"),
g="urn:graph:source",
)
wire = triple_tx.from_pulsar(original)
restored = triple_tx.to_pulsar(wire)
assert restored == original
def test_none_graph_omitted_from_wire(self, triple_tx):
triple = Triple(
s=Term(type=IRI, iri="http://example.org/s"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="v"),
g=None,
)
wire = triple_tx.from_pulsar(triple)
assert "g" not in wire
def test_missing_terms_handled(self, triple_tx):
data = {}
triple = triple_tx.to_pulsar(data)
assert triple.s is None
assert triple.p is None
assert triple.o is None
# ---------------------------------------------------------------------------
# SubgraphTranslator
# ---------------------------------------------------------------------------
class TestSubgraphTranslator:
def test_subgraph_round_trip(self):
tx = SubgraphTranslator()
triples = [
Triple(
s=Term(type=IRI, iri="http://example.org/A"),
p=Term(type=IRI, iri="http://example.org/rel"),
o=Term(type=LITERAL, value="v1"),
),
Triple(
s=Term(type=IRI, iri="http://example.org/B"),
p=Term(type=IRI, iri="http://example.org/rel"),
o=Term(type=IRI, iri="http://example.org/C"),
g="urn:graph:source",
),
]
wire_list = tx.from_pulsar(triples)
assert len(wire_list) == 2
assert wire_list[1]["g"] == "urn:graph:source"
restored = tx.to_pulsar(wire_list)
assert len(restored) == 2
assert restored[0] == triples[0]
assert restored[1] == triples[1]
def test_empty_subgraph(self):
tx = SubgraphTranslator()
assert tx.to_pulsar([]) == []
assert tx.from_pulsar([]) == []

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,144 @@
"""
Tests for pipeline metadata preservation: DocumentMetadata and
ProcessingMetadata round-trip through translators, field preservation,
and default handling.
"""
import pytest
from trustgraph.schema import DocumentMetadata, ProcessingMetadata, Triple, Term, IRI
from trustgraph.messaging.translators.metadata import (
DocumentMetadataTranslator,
ProcessingMetadataTranslator,
)
# ---------------------------------------------------------------------------
# DocumentMetadata translator
# ---------------------------------------------------------------------------
class TestDocumentMetadataTranslator:
def setup_method(self):
self.tx = DocumentMetadataTranslator()
def test_full_round_trip(self):
data = {
"id": "doc-123",
"time": 1710000000,
"kind": "application/pdf",
"title": "Test Document",
"comments": "No comments",
"metadata": [],
"user": "alice",
"tags": ["finance", "q4"],
"parent-id": "doc-100",
"document-type": "page",
}
obj = self.tx.to_pulsar(data)
assert obj.id == "doc-123"
assert obj.time == 1710000000
assert obj.kind == "application/pdf"
assert obj.title == "Test Document"
assert obj.user == "alice"
assert obj.tags == ["finance", "q4"]
assert obj.parent_id == "doc-100"
assert obj.document_type == "page"
wire = self.tx.from_pulsar(obj)
assert wire["id"] == "doc-123"
assert wire["user"] == "alice"
assert wire["parent-id"] == "doc-100"
assert wire["document-type"] == "page"
def test_defaults_for_missing_fields(self):
obj = self.tx.to_pulsar({})
assert obj.parent_id == ""
assert obj.document_type == "source"
def test_metadata_triples_preserved(self):
triple_wire = [{
"s": {"t": "i", "i": "http://example.org/s"},
"p": {"t": "i", "i": "http://example.org/p"},
"o": {"t": "i", "i": "http://example.org/o"},
}]
data = {"metadata": triple_wire}
obj = self.tx.to_pulsar(data)
assert len(obj.metadata) == 1
assert obj.metadata[0].s.iri == "http://example.org/s"
def test_none_metadata_handled(self):
data = {"metadata": None}
obj = self.tx.to_pulsar(data)
assert obj.metadata == []
def test_empty_tags_preserved(self):
data = {"tags": []}
obj = self.tx.to_pulsar(data)
wire = self.tx.from_pulsar(obj)
assert wire["tags"] == []
def test_falsy_fields_omitted_from_wire(self):
"""Empty string fields should be omitted from wire format."""
obj = DocumentMetadata(id="", time=0, user="")
wire = self.tx.from_pulsar(obj)
assert "id" not in wire
assert "user" not in wire
# ---------------------------------------------------------------------------
# ProcessingMetadata translator
# ---------------------------------------------------------------------------
class TestProcessingMetadataTranslator:
def setup_method(self):
self.tx = ProcessingMetadataTranslator()
def test_full_round_trip(self):
data = {
"id": "proc-1",
"document-id": "doc-123",
"time": 1710000000,
"flow": "default",
"user": "alice",
"collection": "my-collection",
"tags": ["tag1"],
}
obj = self.tx.to_pulsar(data)
assert obj.id == "proc-1"
assert obj.document_id == "doc-123"
assert obj.flow == "default"
assert obj.user == "alice"
assert obj.collection == "my-collection"
assert obj.tags == ["tag1"]
wire = self.tx.from_pulsar(obj)
assert wire["id"] == "proc-1"
assert wire["document-id"] == "doc-123"
assert wire["user"] == "alice"
assert wire["collection"] == "my-collection"
def test_missing_fields_use_defaults(self):
obj = self.tx.to_pulsar({})
assert obj.id is None
assert obj.user is None
assert obj.collection is None
def test_tags_none_omitted(self):
obj = ProcessingMetadata(tags=None)
wire = self.tx.from_pulsar(obj)
assert "tags" not in wire
def test_tags_empty_list_preserved(self):
obj = ProcessingMetadata(tags=[])
wire = self.tx.from_pulsar(obj)
assert wire["tags"] == []
def test_user_and_collection_preserved(self):
"""Core pipeline routing fields must survive round-trip."""
data = {"user": "bob", "collection": "research"}
obj = self.tx.to_pulsar(data)
wire = self.tx.from_pulsar(obj)
assert wire["user"] == "bob"
assert wire["collection"] == "research"

View file

@ -0,0 +1,314 @@
"""
Tests for null embedding protection: empty/None vector skipping, entity
validation, dimension-aware collection creation, and query-time empty
vector handling.
Tests the pure functions and logic without Qdrant connections.
"""
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from trustgraph.schema import Term, IRI, LITERAL, BLANK
# ---------------------------------------------------------------------------
# Graph embeddings: get_term_value
# ---------------------------------------------------------------------------
class TestGraphEmbeddingsGetTermValue:
def test_iri_returns_iri(self):
from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value
t = Term(type=IRI, iri="http://example.org/x")
assert get_term_value(t) == "http://example.org/x"
def test_literal_returns_value(self):
from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value
t = Term(type=LITERAL, value="hello")
assert get_term_value(t) == "hello"
def test_blank_returns_id(self):
from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value
t = Term(type=BLANK, id="_:b0")
assert get_term_value(t) == "_:b0"
def test_none_returns_none(self):
from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value
assert get_term_value(None) is None
def test_blank_with_value_fallback(self):
from trustgraph.storage.graph_embeddings.qdrant.write import get_term_value
t = Term(type=BLANK, id="", value="fallback")
assert get_term_value(t) == "fallback"
# ---------------------------------------------------------------------------
# Document embeddings: null vector protection
# ---------------------------------------------------------------------------
class TestDocEmbeddingsNullProtection:
@pytest.mark.asyncio
async def test_empty_vector_skipped(self):
"""Embeddings with empty vectors should be silently skipped."""
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
# Mock collection_exists for config check
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
emb = MagicMock()
emb.chunk_id = "chunk-1"
emb.vector = [] # Empty vector
msg.chunks = [emb]
await proc.store_document_embeddings(msg)
# No upsert should be called
proc.qdrant.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_none_vector_skipped(self):
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
emb = MagicMock()
emb.chunk_id = "chunk-1"
emb.vector = None # None vector
msg.chunks = [emb]
await proc.store_document_embeddings(msg)
proc.qdrant.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_empty_chunk_id_skipped(self):
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
emb = MagicMock()
emb.chunk_id = "" # Empty chunk ID
emb.vector = [0.1, 0.2, 0.3]
msg.chunks = [emb]
await proc.store_document_embeddings(msg)
proc.qdrant.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_valid_embedding_upserted(self):
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = True
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
emb = MagicMock()
emb.chunk_id = "chunk-1"
emb.vector = [0.1, 0.2, 0.3]
msg.chunks = [emb]
await proc.store_document_embeddings(msg)
proc.qdrant.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_dimension_in_collection_name(self):
"""Collection name should include vector dimension."""
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = True
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "alice"
msg.metadata.collection = "docs"
emb = MagicMock()
emb.chunk_id = "c1"
emb.vector = [0.0] * 384 # 384-dim vector
msg.chunks = [emb]
await proc.store_document_embeddings(msg)
call_args = proc.qdrant.upsert.call_args
assert "d_alice_docs_384" in call_args[1]["collection_name"]
# ---------------------------------------------------------------------------
# Graph embeddings: null entity and vector protection
# ---------------------------------------------------------------------------
class TestGraphEmbeddingsNullProtection:
@pytest.mark.asyncio
async def test_empty_entity_skipped(self):
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
entity = MagicMock()
entity.entity = Term(type=IRI, iri="") # Empty IRI
entity.vector = [0.1, 0.2, 0.3]
msg.entities = [entity]
await proc.store_graph_embeddings(msg)
proc.qdrant.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_none_entity_skipped(self):
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
entity = MagicMock()
entity.entity = None # Null entity
entity.vector = [0.1, 0.2, 0.3]
msg.entities = [entity]
await proc.store_graph_embeddings(msg)
proc.qdrant.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_empty_vector_skipped(self):
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
entity = MagicMock()
entity.entity = Term(type=IRI, iri="http://example.org/x")
entity.vector = [] # Empty vector
msg.entities = [entity]
await proc.store_graph_embeddings(msg)
proc.qdrant.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_valid_entity_and_vector_upserted(self):
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = True
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "col1"
entity = MagicMock()
entity.entity = Term(type=IRI, iri="http://example.org/Alice")
entity.vector = [0.1, 0.2, 0.3]
entity.chunk_id = "c1"
msg.entities = [entity]
await proc.store_graph_embeddings(msg)
proc.qdrant.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_lazy_collection_creation_on_new_dimension(self):
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = False
proc.collection_exists = MagicMock(return_value=True)
msg = MagicMock()
msg.metadata.user = "alice"
msg.metadata.collection = "graphs"
entity = MagicMock()
entity.entity = Term(type=IRI, iri="http://example.org/x")
entity.vector = [0.0] * 768
entity.chunk_id = ""
msg.entities = [entity]
await proc.store_graph_embeddings(msg)
# Collection should be created with correct dimension
proc.qdrant.create_collection.assert_called_once()
create_args = proc.qdrant.create_collection.call_args
assert create_args[1]["collection_name"] == "t_alice_graphs_768"
# ---------------------------------------------------------------------------
# Collection validation — deleted-while-in-flight protection
# ---------------------------------------------------------------------------
class TestCollectionValidation:
@pytest.mark.asyncio
async def test_doc_embeddings_dropped_for_deleted_collection(self):
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.collection_exists = MagicMock(return_value=False)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "deleted-col"
msg.chunks = [MagicMock()]
await proc.store_document_embeddings(msg)
proc.qdrant.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_graph_embeddings_dropped_for_deleted_collection(self):
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.collection_exists = MagicMock(return_value=False)
msg = MagicMock()
msg.metadata.user = "user1"
msg.metadata.collection = "deleted-col"
msg.entities = [MagicMock()]
await proc.store_graph_embeddings(msg)
proc.qdrant.upsert.assert_not_called()

View file

@ -0,0 +1,153 @@
"""
Tests for retry and backoff strategies: Consumer rate-limit retry loop,
timeout expiry, TooManyRequests exception propagation, and configurable
retry parameters.
"""
import asyncio
import time
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.exceptions import TooManyRequests
from trustgraph.base.consumer import Consumer
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_consumer(rate_limit_retry_time=10, rate_limit_timeout=7200):
"""Create a Consumer with minimal mocking."""
consumer = Consumer.__new__(Consumer)
consumer.rate_limit_retry_time = rate_limit_retry_time
consumer.rate_limit_timeout = rate_limit_timeout
consumer.metrics = None
consumer.consumer = MagicMock()
return consumer
# ---------------------------------------------------------------------------
# TooManyRequests exception
# ---------------------------------------------------------------------------
class TestTooManyRequestsException:
def test_is_exception(self):
assert issubclass(TooManyRequests, Exception)
def test_with_message(self):
err = TooManyRequests("rate limited")
assert "rate limited" in str(err)
def test_without_message(self):
err = TooManyRequests()
assert isinstance(err, TooManyRequests)
# ---------------------------------------------------------------------------
# Consumer retry configuration
# ---------------------------------------------------------------------------
class TestConsumerRetryConfig:
def test_default_retry_time(self):
consumer = _make_consumer()
assert consumer.rate_limit_retry_time == 10
def test_default_timeout(self):
consumer = _make_consumer()
assert consumer.rate_limit_timeout == 7200
def test_custom_retry_time(self):
consumer = _make_consumer(rate_limit_retry_time=5)
assert consumer.rate_limit_retry_time == 5
def test_custom_timeout(self):
consumer = _make_consumer(rate_limit_timeout=300)
assert consumer.rate_limit_timeout == 300
# ---------------------------------------------------------------------------
# Rate limit metrics
# ---------------------------------------------------------------------------
class TestRateLimitMetrics:
def test_metrics_rate_limit_called(self):
"""Metrics should record rate limit events when available."""
consumer = _make_consumer()
consumer.metrics = MagicMock()
# Simulate what the consumer does on rate limit
consumer.metrics.rate_limit()
consumer.metrics.rate_limit.assert_called_once()
# ---------------------------------------------------------------------------
# Message acknowledgment on error
# ---------------------------------------------------------------------------
class TestMessageAckOnError:
def test_consumer_has_negative_acknowledge(self):
"""Consumer backend should support negative acknowledgment."""
consumer = _make_consumer()
msg = MagicMock()
# Simulate negative ack (what happens on timeout expiry)
consumer.consumer.negative_acknowledge(msg)
consumer.consumer.negative_acknowledge.assert_called_once_with(msg)
# ---------------------------------------------------------------------------
# TooManyRequests propagation across services
# ---------------------------------------------------------------------------
class TestTooManyRequestsPropagation:
def test_llm_service_propagates(self):
"""LLM services should re-raise TooManyRequests for consumer retry."""
with pytest.raises(TooManyRequests):
raise TooManyRequests()
def test_embeddings_service_propagates(self):
"""Embeddings services should re-raise TooManyRequests for consumer retry."""
with pytest.raises(TooManyRequests):
try:
raise TooManyRequests("rate limited")
except TooManyRequests as e:
# Re-raise pattern used in services
assert isinstance(e, TooManyRequests)
raise
def test_too_many_requests_not_caught_by_generic(self):
"""TooManyRequests should be distinguishable from generic exceptions."""
caught_specific = False
try:
raise TooManyRequests("rate limited")
except TooManyRequests:
caught_specific = True
except Exception:
pass
assert caught_specific
# ---------------------------------------------------------------------------
# Client-side error type mapping
# ---------------------------------------------------------------------------
class TestClientErrorTypeMapping:
def test_too_many_requests_wire_type(self):
"""The wire format error type for rate limiting is 'too-many-requests'."""
from trustgraph.schema import Error
err = Error(type="too-many-requests", message="slow down")
assert err.type == "too-many-requests"
def test_generic_error_wire_type(self):
from trustgraph.schema import Error
err = Error(type="internal-error", message="something broke")
assert err.type == "internal-error"
assert err.type != "too-many-requests"

View file

@ -0,0 +1,233 @@
"""
Tests for message queue subscriber resilience: unexpected message handling,
orphaned message detection, backpressure strategies, graceful draining,
and timeout recovery.
"""
import asyncio
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.base.subscriber import Subscriber
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_subscriber(max_size=10, backpressure_strategy="block",
drain_timeout=5.0):
"""Create a Subscriber without connecting to any backend."""
backend = MagicMock()
sub = Subscriber(
backend=backend,
topic="test-topic",
subscription="test-sub",
consumer_name="test",
max_size=max_size,
backpressure_strategy=backpressure_strategy,
drain_timeout=drain_timeout,
)
sub.consumer = MagicMock()
return sub
def _make_msg(id=None, value="test-value"):
"""Create a mock message with optional properties."""
msg = MagicMock()
if id is not None:
msg.properties.return_value = {"id": id}
else:
msg.properties.side_effect = KeyError("id")
msg.value.return_value = value
return msg
# ---------------------------------------------------------------------------
# Message property extraction resilience
# ---------------------------------------------------------------------------
class TestMessagePropertyResilience:
@pytest.mark.asyncio
async def test_missing_id_property_handled(self):
"""Messages without 'id' property should not crash."""
sub = _make_subscriber()
msg = MagicMock()
msg.properties.side_effect = Exception("no properties")
msg.value.return_value = "some-value"
# Should not raise
await sub._process_message(msg)
# Message should still be acknowledged
sub.consumer.acknowledge.assert_called_once_with(msg)
@pytest.mark.asyncio
async def test_message_with_valid_id_delivered(self):
"""Messages with matching subscriber ID should be delivered."""
sub = _make_subscriber()
q = await sub.subscribe("req-1")
msg = _make_msg(id="req-1", value="response-data")
await sub._process_message(msg)
assert not q.empty()
assert q.get_nowait() == "response-data"
sub.consumer.acknowledge.assert_called_once()
# ---------------------------------------------------------------------------
# Orphaned message handling
# ---------------------------------------------------------------------------
class TestOrphanedMessages:
@pytest.mark.asyncio
async def test_orphaned_message_acknowledged(self):
"""Messages with no matching waiter should still be acknowledged."""
sub = _make_subscriber()
msg = _make_msg(id="unknown-id", value="orphan")
await sub._process_message(msg)
# Orphaned message is acknowledged (not negative-acknowledged)
sub.consumer.acknowledge.assert_called_once_with(msg)
@pytest.mark.asyncio
async def test_orphaned_message_not_queued(self):
"""Orphaned messages should not appear in any subscriber queue."""
sub = _make_subscriber()
q = await sub.subscribe("req-1")
msg = _make_msg(id="different-id", value="orphan")
await sub._process_message(msg)
assert q.empty()
# ---------------------------------------------------------------------------
# Backpressure strategies
# ---------------------------------------------------------------------------
class TestBackpressureStrategies:
@pytest.mark.asyncio
async def test_drop_new_rejects_when_full(self):
"""drop_new strategy should reject new messages when queue is full."""
sub = _make_subscriber(max_size=1, backpressure_strategy="drop_new")
q = await sub.subscribe("req-1")
# Fill the queue
msg1 = _make_msg(id="req-1", value="first")
await sub._process_message(msg1)
assert q.qsize() == 1
# Second message should be dropped
msg2 = _make_msg(id="req-1", value="second")
await sub._process_message(msg2)
# Queue still has only the first message
assert q.qsize() == 1
assert q.get_nowait() == "first"
@pytest.mark.asyncio
async def test_drop_oldest_evicts_when_full(self):
"""drop_oldest strategy should evict oldest message when full."""
sub = _make_subscriber(max_size=1, backpressure_strategy="drop_oldest")
q = await sub.subscribe("req-1")
msg1 = _make_msg(id="req-1", value="first")
await sub._process_message(msg1)
msg2 = _make_msg(id="req-1", value="second")
await sub._process_message(msg2)
# Queue should have the newer message
assert q.qsize() == 1
assert q.get_nowait() == "second"
@pytest.mark.asyncio
async def test_block_strategy_delivers(self):
"""block strategy should deliver messages normally."""
sub = _make_subscriber(max_size=10, backpressure_strategy="block")
q = await sub.subscribe("req-1")
msg = _make_msg(id="req-1", value="data")
await sub._process_message(msg)
assert q.get_nowait() == "data"
# ---------------------------------------------------------------------------
# Full subscribers (subscribe_all)
# ---------------------------------------------------------------------------
class TestFullSubscribers:
@pytest.mark.asyncio
async def test_subscribe_all_receives_all_messages(self):
sub = _make_subscriber()
q = await sub.subscribe_all("listener-1")
msg = _make_msg(id="any-id", value="broadcast")
await sub._process_message(msg)
assert q.get_nowait() == "broadcast"
@pytest.mark.asyncio
async def test_multiple_full_subscribers_all_receive(self):
sub = _make_subscriber()
q1 = await sub.subscribe_all("l1")
q2 = await sub.subscribe_all("l2")
msg = _make_msg(id="any", value="data")
await sub._process_message(msg)
assert q1.get_nowait() == "data"
assert q2.get_nowait() == "data"
# ---------------------------------------------------------------------------
# Subscribe / unsubscribe lifecycle
# ---------------------------------------------------------------------------
class TestSubscribeLifecycle:
@pytest.mark.asyncio
async def test_unsubscribe_removes_queue(self):
sub = _make_subscriber()
await sub.subscribe("req-1")
await sub.unsubscribe("req-1")
assert "req-1" not in sub.q
@pytest.mark.asyncio
async def test_unsubscribe_nonexistent_is_noop(self):
sub = _make_subscriber()
await sub.unsubscribe("nonexistent") # Should not raise
@pytest.mark.asyncio
async def test_unsubscribe_all_removes_queue(self):
sub = _make_subscriber()
await sub.subscribe_all("l1")
await sub.unsubscribe_all("l1")
assert "l1" not in sub.full
# ---------------------------------------------------------------------------
# Pending ack tracking
# ---------------------------------------------------------------------------
class TestPendingAckTracking:
@pytest.mark.asyncio
async def test_processed_message_cleared_from_pending(self):
sub = _make_subscriber()
msg = _make_msg(id="req-1", value="data")
await sub._process_message(msg)
# After processing, pending_acks should be empty
assert len(sub.pending_acks) == 0

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,296 @@
"""
Tests for row embeddings query service: collection naming, query execution,
index filtering, result conversion, and error handling.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.schema import (
RowEmbeddingsRequest, RowEmbeddingsResponse,
RowIndexMatch, Error,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_processor(qdrant_client=None):
"""Create a Processor without full FlowProcessor init."""
from trustgraph.query.row_embeddings.qdrant.service import Processor
proc = Processor.__new__(Processor)
proc.qdrant = qdrant_client or MagicMock()
return proc
def _make_request(vector=None, user="test-user", collection="test-col",
schema_name="customers", limit=10, index_name=None):
return RowEmbeddingsRequest(
vector=vector or [0.1, 0.2, 0.3],
user=user,
collection=collection,
schema_name=schema_name,
limit=limit,
index_name=index_name or "",
)
def _make_search_point(index_name, index_value, text, score):
point = MagicMock()
point.payload = {
"index_name": index_name,
"index_value": index_value,
"text": text,
}
point.score = score
return point
# ---------------------------------------------------------------------------
# sanitize_name
# ---------------------------------------------------------------------------
class TestSanitizeName:
def test_simple_name(self):
proc = _make_processor()
assert proc.sanitize_name("customers") == "customers"
def test_special_chars_replaced(self):
proc = _make_processor()
assert proc.sanitize_name("my-schema.v2") == "my_schema_v2"
def test_leading_digit_prefixed(self):
proc = _make_processor()
result = proc.sanitize_name("123schema")
assert result.startswith("r_")
assert "123schema" in result
def test_uppercase_lowercased(self):
proc = _make_processor()
assert proc.sanitize_name("MySchema") == "myschema"
def test_spaces_replaced(self):
proc = _make_processor()
assert proc.sanitize_name("my schema") == "my_schema"
# ---------------------------------------------------------------------------
# find_collection
# ---------------------------------------------------------------------------
class TestFindCollection:
def test_finds_matching_collection(self):
proc = _make_processor()
mock_coll = MagicMock()
mock_coll.name = "rows_test_user_test_col_customers_384"
mock_collections = MagicMock()
mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-user", "test-col", "customers")
# Prefix: rows_test_user_test_col_customers_
assert result == "rows_test_user_test_col_customers_384"
def test_returns_none_when_no_match(self):
proc = _make_processor()
mock_coll = MagicMock()
mock_coll.name = "rows_other_user_other_col_schema_768"
mock_collections = MagicMock()
mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-user", "test-col", "customers")
assert result is None
def test_returns_none_on_error(self):
proc = _make_processor()
proc.qdrant.get_collections.side_effect = Exception("connection error")
result = proc.find_collection("user", "col", "schema")
assert result is None
# ---------------------------------------------------------------------------
# query_row_embeddings
# ---------------------------------------------------------------------------
class TestQueryRowEmbeddings:
@pytest.mark.asyncio
async def test_empty_vector_returns_empty(self):
proc = _make_processor()
request = _make_request(vector=[])
result = await proc.query_row_embeddings(request)
assert result == []
@pytest.mark.asyncio
async def test_no_collection_returns_empty(self):
proc = _make_processor()
proc.find_collection = MagicMock(return_value=None)
request = _make_request()
result = await proc.query_row_embeddings(request)
assert result == []
@pytest.mark.asyncio
async def test_successful_query_returns_matches(self):
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
points = [
_make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
_make_search_point("address", ["123 Main St"], "123 Main St", 0.82),
]
mock_result = MagicMock()
mock_result.points = points
proc.qdrant.query_points.return_value = mock_result
request = _make_request()
result = await proc.query_row_embeddings(request)
assert len(result) == 2
assert isinstance(result[0], RowIndexMatch)
assert result[0].index_name == "name"
assert result[0].index_value == ["Alice Smith"]
assert result[0].score == 0.95
assert result[1].index_name == "address"
@pytest.mark.asyncio
async def test_index_name_filter_applied(self):
"""When index_name is specified, a Qdrant filter should be used."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
mock_result = MagicMock()
mock_result.points = []
proc.qdrant.query_points.return_value = mock_result
request = _make_request(index_name="address")
await proc.query_row_embeddings(request)
call_kwargs = proc.qdrant.query_points.call_args[1]
assert call_kwargs["query_filter"] is not None
@pytest.mark.asyncio
async def test_no_index_name_no_filter(self):
"""When index_name is empty, no filter should be applied."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
mock_result = MagicMock()
mock_result.points = []
proc.qdrant.query_points.return_value = mock_result
request = _make_request(index_name="")
await proc.query_row_embeddings(request)
call_kwargs = proc.qdrant.query_points.call_args[1]
assert call_kwargs["query_filter"] is None
@pytest.mark.asyncio
async def test_missing_payload_fields_default(self):
"""Points with missing payload fields should use defaults."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
point = MagicMock()
point.payload = {} # Empty payload
point.score = 0.5
mock_result = MagicMock()
mock_result.points = [point]
proc.qdrant.query_points.return_value = mock_result
request = _make_request()
result = await proc.query_row_embeddings(request)
assert len(result) == 1
assert result[0].index_name == ""
assert result[0].index_value == []
assert result[0].text == ""
@pytest.mark.asyncio
async def test_qdrant_error_propagates(self):
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.qdrant.query_points.side_effect = Exception("qdrant down")
request = _make_request()
with pytest.raises(Exception, match="qdrant down"):
await proc.query_row_embeddings(request)
# ---------------------------------------------------------------------------
# on_message handler
# ---------------------------------------------------------------------------
class TestOnMessage:
@pytest.mark.asyncio
async def test_successful_message_sends_response(self):
proc = _make_processor()
proc.query_row_embeddings = AsyncMock(return_value=[
RowIndexMatch(index_name="name", index_value=["Alice"],
text="Alice", score=0.9),
])
mock_pub = AsyncMock()
flow = lambda name: mock_pub
msg = MagicMock()
msg.value.return_value = _make_request()
msg.properties.return_value = {"id": "req-1"}
await proc.on_message(msg, MagicMock(), flow)
sent = mock_pub.send.call_args[0][0]
assert isinstance(sent, RowEmbeddingsResponse)
assert sent.error is None
assert len(sent.matches) == 1
@pytest.mark.asyncio
async def test_error_sends_error_response(self):
proc = _make_processor()
proc.query_row_embeddings = AsyncMock(
side_effect=Exception("query failed")
)
mock_pub = AsyncMock()
flow = lambda name: mock_pub
msg = MagicMock()
msg.value.return_value = _make_request()
msg.properties.return_value = {"id": "req-2"}
await proc.on_message(msg, MagicMock(), flow)
sent = mock_pub.send.call_args[0][0]
assert sent.error is not None
assert sent.error.type == "row-embeddings-query-error"
assert "query failed" in sent.error.message
assert sent.matches == []
@pytest.mark.asyncio
async def test_message_id_preserved(self):
proc = _make_processor()
proc.query_row_embeddings = AsyncMock(return_value=[])
mock_pub = AsyncMock()
flow = lambda name: mock_pub
msg = MagicMock()
msg.value.return_value = _make_request()
msg.properties.return_value = {"id": "unique-42"}
await proc.on_message(msg, MagicMock(), flow)
props = mock_pub.send.call_args[1]["properties"]
assert props["id"] == "unique-42"

View file

@ -0,0 +1,235 @@
"""
Tests for structured data type detection: CSV, JSON, XML format detection,
CSV option detection (delimiter, header), and helper functions.
"""
import pytest
from trustgraph.retrieval.structured_diag.type_detector import (
detect_data_type,
_check_json_format,
_check_xml_format,
_check_csv_format,
_check_csv_with_delimiter,
detect_csv_options,
_is_numeric,
)
# ---------------------------------------------------------------------------
# detect_data_type (top-level dispatcher)
# ---------------------------------------------------------------------------
class TestDetectDataType:
def test_empty_string_returns_none(self):
detected, confidence = detect_data_type("")
assert detected is None
assert confidence == 0.0
def test_whitespace_only_returns_none(self):
detected, confidence = detect_data_type(" \n \t ")
assert detected is None
assert confidence == 0.0
def test_none_returns_none(self):
detected, confidence = detect_data_type(None)
assert detected is None
assert confidence == 0.0
def test_json_object_detected(self):
detected, confidence = detect_data_type('{"name": "Alice"}')
assert detected == "json"
assert confidence > 0.5
def test_json_array_detected(self):
detected, confidence = detect_data_type('[{"id": 1}, {"id": 2}]')
assert detected == "json"
assert confidence > 0.5
def test_xml_with_declaration_detected(self):
detected, confidence = detect_data_type('<?xml version="1.0"?><root></root>')
assert detected == "xml"
assert confidence > 0.5
def test_xml_without_declaration_detected(self):
detected, confidence = detect_data_type('<root><item>val</item></root>')
assert detected == "xml"
assert confidence > 0.5
def test_csv_detected(self):
data = "name,age,city\nAlice,30,NYC\nBob,25,LA"
detected, confidence = detect_data_type(data)
assert detected == "csv"
assert confidence > 0.5
def test_plain_text_falls_through_to_csv(self):
"""Non-JSON/XML text defaults to CSV detection."""
detected, confidence = detect_data_type("just some text")
assert detected == "csv"
# ---------------------------------------------------------------------------
# _check_json_format
# ---------------------------------------------------------------------------
class TestCheckJsonFormat:
def test_valid_json_object(self):
assert _check_json_format('{"key": "value"}') > 0.9
def test_valid_json_array_of_objects(self):
assert _check_json_format('[{"id": 1}, {"id": 2}]') >= 0.9
def test_valid_json_array_of_primitives(self):
score = _check_json_format('[1, 2, 3]')
assert score > 0.5
assert score < 0.9 # Lower confidence for non-object arrays
def test_empty_json_object(self):
assert _check_json_format('{}') > 0.5
def test_invalid_json(self):
assert _check_json_format('{invalid json}') == 0.0
def test_non_json_starting_char(self):
assert _check_json_format('hello world') == 0.0
def test_empty_array(self):
score = _check_json_format('[]')
assert score > 0.0 # Parsed successfully but empty
# ---------------------------------------------------------------------------
# _check_xml_format
# ---------------------------------------------------------------------------
class TestCheckXmlFormat:
def test_valid_xml(self):
assert _check_xml_format('<root><item>val</item></root>') == 0.9
def test_xml_with_declaration(self):
xml = '<?xml version="1.0"?><root><item>test</item></root>'
assert _check_xml_format(xml) == 0.9
def test_malformed_xml(self):
score = _check_xml_format('<root><unclosed>')
# Has < and </ check fails since no closing tag
assert score < 0.9
def test_not_xml(self):
assert _check_xml_format('just text') == 0.0
def test_incomplete_xml_tag(self):
score = _check_xml_format('<root>')
# Starts with < but no closing tag
assert score <= 0.1
# ---------------------------------------------------------------------------
# _check_csv_format and _check_csv_with_delimiter
# ---------------------------------------------------------------------------
class TestCheckCsvFormat:
def test_valid_csv_comma(self):
data = "name,age,city\nAlice,30,NYC\nBob,25,LA"
assert _check_csv_format(data) > 0.7
def test_valid_csv_semicolon(self):
data = "name;age;city\nAlice;30;NYC\nBob;25;LA"
assert _check_csv_format(data) > 0.7
def test_valid_csv_tab(self):
data = "name\tage\tcity\nAlice\t30\tNYC\nBob\t25\tLA"
assert _check_csv_format(data) > 0.7
def test_valid_csv_pipe(self):
data = "name|age|city\nAlice|30|NYC\nBob|25|LA"
assert _check_csv_format(data) > 0.7
def test_single_line_not_csv(self):
assert _check_csv_format("just one line") == 0.0
def test_single_column_not_csv(self):
data = "a\nb\nc"
assert _check_csv_with_delimiter(data, ",") == 0.0
def test_inconsistent_columns_low_score(self):
data = "a,b,c\n1,2\n3,4,5,6"
score = _check_csv_with_delimiter(data, ",")
assert score < 0.7
def test_many_rows_higher_score(self):
rows = ["name,age,city"] + [f"person{i},{20+i},city{i}" for i in range(20)]
data = "\n".join(rows)
score = _check_csv_format(data)
assert score > 0.8
# ---------------------------------------------------------------------------
# detect_csv_options
# ---------------------------------------------------------------------------
class TestDetectCsvOptions:
def test_comma_delimiter_detected(self):
data = "name,age,city\nAlice,30,NYC\nBob,25,LA"
options = detect_csv_options(data)
assert options["delimiter"] == ","
def test_semicolon_delimiter_detected(self):
data = "name;age;city\nAlice;30;NYC\nBob;25;LA"
options = detect_csv_options(data)
assert options["delimiter"] == ";"
def test_tab_delimiter_detected(self):
data = "name\tage\tcity\nAlice\t30\tNYC\nBob\t25\tLA"
options = detect_csv_options(data)
assert options["delimiter"] == "\t"
def test_header_detected_when_first_row_text(self):
data = "name,age,salary\nAlice,30,50000\nBob,25,45000"
options = detect_csv_options(data)
assert options["has_header"] is True
def test_no_header_when_all_numeric(self):
data = "1,2,3\n4,5,6\n7,8,9"
options = detect_csv_options(data)
assert options["has_header"] is False
def test_single_line_returns_defaults(self):
options = detect_csv_options("just one line")
assert options["delimiter"] == ","
assert options["has_header"] is True
def test_encoding_default(self):
data = "a,b\n1,2"
options = detect_csv_options(data)
assert options["encoding"] == "utf-8"
# ---------------------------------------------------------------------------
# _is_numeric helper
# ---------------------------------------------------------------------------
class TestIsNumeric:
def test_integer(self):
assert _is_numeric("42") is True
def test_float(self):
assert _is_numeric("3.14") is True
def test_negative(self):
assert _is_numeric("-10") is True
def test_text(self):
assert _is_numeric("hello") is False
def test_empty(self):
assert _is_numeric("") is False
def test_whitespace_padded(self):
assert _is_numeric(" 42 ") is True

View file

@ -0,0 +1,182 @@
"""
Tests for Azure OpenAI streaming: model/temperature override during streaming,
RateLimitError TooManyRequests conversion, chunk iteration, and final token
count emission.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.model.text_completion.azure_openai.llm import Processor
from trustgraph.base import LlmChunk
from trustgraph.exceptions import TooManyRequests
def _make_processor(mock_azure_openai_class, model="gpt-4"):
"""Create a Processor with mocked base classes."""
with patch('trustgraph.base.async_processor.AsyncProcessor.__init__',
return_value=None), \
patch('trustgraph.base.llm_service.LlmService.__init__',
return_value=None):
proc = Processor(
endpoint="https://test.openai.azure.com/",
token="test-token",
api_version="2024-12-01-preview",
model=model,
temperature=0.0,
max_output=4192,
concurrency=1,
taskgroup=AsyncMock(),
id="test-processor",
)
return proc
def _make_stream_chunk(content=None, usage=None):
"""Create a mock streaming chunk."""
chunk = MagicMock()
if content:
chunk.choices = [MagicMock()]
chunk.choices[0].delta.content = content
else:
chunk.choices = []
chunk.usage = usage
return chunk
class TestAzureOpenAIStreaming(IsolatedAsyncioTestCase):
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
async def test_streaming_yields_chunks(self, mock_azure_class):
mock_client = MagicMock()
mock_azure_class.return_value = mock_client
proc = _make_processor(mock_azure_class)
usage = MagicMock()
usage.prompt_tokens = 10
usage.completion_tokens = 5
stream_data = [
_make_stream_chunk(content="Hello"),
_make_stream_chunk(content=" world"),
_make_stream_chunk(usage=usage),
]
mock_client.chat.completions.create.return_value = iter(stream_data)
results = []
async for chunk in proc.generate_content_stream("sys", "user"):
results.append(chunk)
assert len(results) == 3 # 2 content + 1 final
assert results[0].text == "Hello"
assert results[0].is_final is False
assert results[1].text == " world"
assert results[2].is_final is True
assert results[2].in_token == 10
assert results[2].out_token == 5
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
async def test_streaming_model_override(self, mock_azure_class):
mock_client = MagicMock()
mock_azure_class.return_value = mock_client
proc = _make_processor(mock_azure_class, model="gpt-4")
usage = MagicMock()
usage.prompt_tokens = 5
usage.completion_tokens = 2
stream_data = [
_make_stream_chunk(content="ok"),
_make_stream_chunk(usage=usage),
]
mock_client.chat.completions.create.return_value = iter(stream_data)
results = []
async for chunk in proc.generate_content_stream(
"sys", "user", model="gpt-4o"
):
results.append(chunk)
# All chunks carry overridden model
for r in results:
assert r.model == "gpt-4o"
# Verify API call used overridden model
call_kwargs = mock_client.chat.completions.create.call_args[1]
assert call_kwargs["model"] == "gpt-4o"
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
async def test_streaming_temperature_override(self, mock_azure_class):
mock_client = MagicMock()
mock_azure_class.return_value = mock_client
proc = _make_processor(mock_azure_class)
usage = MagicMock()
usage.prompt_tokens = 5
usage.completion_tokens = 2
stream_data = [_make_stream_chunk(usage=usage)]
mock_client.chat.completions.create.return_value = iter(stream_data)
async for _ in proc.generate_content_stream(
"sys", "user", temperature=0.7
):
pass
call_kwargs = mock_client.chat.completions.create.call_args[1]
assert call_kwargs["temperature"] == 0.7
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
async def test_streaming_rate_limit_raises_too_many_requests(self, mock_azure_class):
from openai import RateLimitError
mock_client = MagicMock()
mock_azure_class.return_value = mock_client
proc = _make_processor(mock_azure_class)
mock_client.chat.completions.create.side_effect = RateLimitError(
"Rate limit exceeded", response=MagicMock(), body=None
)
with pytest.raises(TooManyRequests):
async for _ in proc.generate_content_stream("sys", "user"):
pass
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
async def test_streaming_generic_exception_propagates(self, mock_azure_class):
mock_client = MagicMock()
mock_azure_class.return_value = mock_client
proc = _make_processor(mock_azure_class)
mock_client.chat.completions.create.side_effect = Exception("API down")
with pytest.raises(Exception, match="API down"):
async for _ in proc.generate_content_stream("sys", "user"):
pass
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
async def test_streaming_passes_stream_options(self, mock_azure_class):
mock_client = MagicMock()
mock_azure_class.return_value = mock_client
proc = _make_processor(mock_azure_class)
usage = MagicMock()
usage.prompt_tokens = 0
usage.completion_tokens = 0
stream_data = [_make_stream_chunk(usage=usage)]
mock_client.chat.completions.create.return_value = iter(stream_data)
async for _ in proc.generate_content_stream("sys", "user"):
pass
call_kwargs = mock_client.chat.completions.create.call_args[1]
assert call_kwargs["stream"] is True
assert call_kwargs["stream_options"] == {"include_usage": True}
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
async def test_supports_streaming(self, mock_azure_class):
mock_client = MagicMock()
mock_azure_class.return_value = mock_client
proc = _make_processor(mock_azure_class)
assert proc.supports_streaming() is True

View file

@ -0,0 +1,199 @@
"""
Tests for Azure serverless endpoint streaming: model override during streaming,
HTTP 429 during streaming, SSE chunk parsing, and final token count emission.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.model.text_completion.azure.llm import Processor
from trustgraph.base import LlmChunk
from trustgraph.exceptions import TooManyRequests
def _make_processor(mock_requests, model="AzureAI", temperature=0.0):
"""Create a Processor with mocked base classes."""
with patch('trustgraph.base.async_processor.AsyncProcessor.__init__',
return_value=None), \
patch('trustgraph.base.llm_service.LlmService.__init__',
return_value=None):
proc = Processor(
endpoint="https://test.azure.com/v1/chat/completions",
token="test-token",
temperature=temperature,
max_output=4192,
model=model,
concurrency=1,
taskgroup=AsyncMock(),
id="test-processor",
)
return proc
def _sse_lines(*data_items):
"""Build SSE byte lines from data items. '[DONE]' is appended."""
lines = []
for item in data_items:
if isinstance(item, dict):
lines.append(f"data: {json.dumps(item)}".encode())
else:
lines.append(f"data: {item}".encode())
lines.append(b"data: [DONE]")
return lines
class TestAzureServerlessStreaming(IsolatedAsyncioTestCase):
@patch('trustgraph.model.text_completion.azure.llm.requests')
async def test_streaming_yields_chunks(self, mock_requests):
proc = _make_processor(mock_requests)
chunks = [
{"choices": [{"delta": {"content": "Hello"}}]},
{"choices": [{"delta": {"content": " world"}}]},
{"usage": {"prompt_tokens": 10, "completion_tokens": 5}},
]
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.iter_lines.return_value = _sse_lines(*chunks)
mock_requests.post.return_value = mock_response
results = []
async for chunk in proc.generate_content_stream("sys", "user"):
results.append(chunk)
# Content chunks + final chunk
assert len(results) == 3
assert results[0].text == "Hello"
assert results[0].is_final is False
assert results[1].text == " world"
assert results[1].is_final is False
assert results[2].is_final is True
assert results[2].in_token == 10
assert results[2].out_token == 5
@patch('trustgraph.model.text_completion.azure.llm.requests')
async def test_streaming_model_override(self, mock_requests):
proc = _make_processor(mock_requests, model="default-model")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.iter_lines.return_value = _sse_lines(
{"choices": [{"delta": {"content": "ok"}}]},
{"usage": {"prompt_tokens": 5, "completion_tokens": 2}},
)
mock_requests.post.return_value = mock_response
results = []
async for chunk in proc.generate_content_stream(
"sys", "user", model="override-model"
):
results.append(chunk)
# All chunks should carry the overridden model name
for r in results:
assert r.model == "override-model"
# Verify the request body used the overridden model
call_args = mock_requests.post.call_args
body = json.loads(call_args[1]["data"])
assert body["model"] == "override-model"
@patch('trustgraph.model.text_completion.azure.llm.requests')
async def test_streaming_temperature_override(self, mock_requests):
proc = _make_processor(mock_requests, temperature=0.0)
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.iter_lines.return_value = _sse_lines(
{"choices": [{"delta": {"content": "ok"}}]},
{"usage": {"prompt_tokens": 5, "completion_tokens": 2}},
)
mock_requests.post.return_value = mock_response
results = []
async for chunk in proc.generate_content_stream(
"sys", "user", temperature=0.9
):
results.append(chunk)
call_args = mock_requests.post.call_args
body = json.loads(call_args[1]["data"])
assert body["temperature"] == 0.9
@patch('trustgraph.model.text_completion.azure.llm.requests')
async def test_streaming_429_raises_too_many_requests(self, mock_requests):
proc = _make_processor(mock_requests)
mock_response = MagicMock()
mock_response.status_code = 429
mock_requests.post.return_value = mock_response
with pytest.raises(TooManyRequests):
async for _ in proc.generate_content_stream("sys", "user"):
pass
@patch('trustgraph.model.text_completion.azure.llm.requests')
async def test_streaming_http_error_raises_runtime(self, mock_requests):
proc = _make_processor(mock_requests)
mock_response = MagicMock()
mock_response.status_code = 503
mock_response.text = "Service Unavailable"
mock_requests.post.return_value = mock_response
with pytest.raises(RuntimeError, match="HTTP 503"):
async for _ in proc.generate_content_stream("sys", "user"):
pass
@patch('trustgraph.model.text_completion.azure.llm.requests')
async def test_streaming_includes_stream_options(self, mock_requests):
"""Verify stream=True and stream_options in request body."""
proc = _make_processor(mock_requests)
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.iter_lines.return_value = _sse_lines(
{"usage": {"prompt_tokens": 0, "completion_tokens": 0}},
)
mock_requests.post.return_value = mock_response
async for _ in proc.generate_content_stream("sys", "user"):
pass
call_args = mock_requests.post.call_args
body = json.loads(call_args[1]["data"])
assert body["stream"] is True
assert body["stream_options"]["include_usage"] is True
@patch('trustgraph.model.text_completion.azure.llm.requests')
async def test_streaming_malformed_json_skipped(self, mock_requests):
"""Malformed JSON chunks should be skipped, not crash the stream."""
proc = _make_processor(mock_requests)
mock_response = MagicMock()
mock_response.status_code = 200
lines = [
b"data: {not valid json}",
f'data: {json.dumps({"choices": [{"delta": {"content": "ok"}}]})}'.encode(),
f'data: {json.dumps({"usage": {"prompt_tokens": 1, "completion_tokens": 1}})}'.encode(),
b"data: [DONE]",
]
mock_response.iter_lines.return_value = lines
mock_requests.post.return_value = mock_response
results = []
async for chunk in proc.generate_content_stream("sys", "user"):
results.append(chunk)
# Should get the valid content chunk + final chunk
assert any(r.text == "ok" for r in results)
assert results[-1].is_final is True
@patch('trustgraph.model.text_completion.azure.llm.requests')
async def test_streaming_supports_streaming_flag(self, mock_requests):
proc = _make_processor(mock_requests)
assert proc.supports_streaming() is True

View file

@ -0,0 +1,140 @@
"""
Cross-provider rate limit contract tests: verify that every LLM provider
that handles rate limits converts its provider-specific exception to
TooManyRequests consistently.
Also tests the client-side error translation in the base client.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.exceptions import TooManyRequests
class TestAzureServerless429(IsolatedAsyncioTestCase):
"""Azure serverless endpoint: HTTP 429 → TooManyRequests"""
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_http_429_raises_too_many_requests(self, _llm, _async, mock_requests):
from trustgraph.model.text_completion.azure.llm import Processor
proc = Processor(
endpoint="https://test.azure.com/v1/chat",
token="t", concurrency=1, taskgroup=AsyncMock(), id="t",
)
mock_response = MagicMock()
mock_response.status_code = 429
mock_requests.post.return_value = mock_response
with pytest.raises(TooManyRequests):
await proc.generate_content("sys", "prompt")
class TestAzureOpenAIRateLimit(IsolatedAsyncioTestCase):
"""Azure OpenAI: openai.RateLimitError → TooManyRequests"""
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cls):
from openai import RateLimitError
from trustgraph.model.text_completion.azure_openai.llm import Processor
mock_client = MagicMock()
mock_cls.return_value = mock_client
proc = Processor(
endpoint="https://test.openai.azure.com/", token="t",
model="gpt-4", concurrency=1, taskgroup=AsyncMock(), id="t",
)
mock_client.chat.completions.create.side_effect = RateLimitError(
"rate limited", response=MagicMock(), body=None
)
with pytest.raises(TooManyRequests):
await proc.generate_content("sys", "prompt")
class TestOpenAIRateLimit(IsolatedAsyncioTestCase):
"""OpenAI: openai.RateLimitError → TooManyRequests"""
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cls):
from openai import RateLimitError
from trustgraph.model.text_completion.openai.llm import Processor
mock_client = MagicMock()
mock_cls.return_value = mock_client
proc = Processor(
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
)
mock_client.chat.completions.create.side_effect = RateLimitError(
"rate limited", response=MagicMock(), body=None
)
with pytest.raises(TooManyRequests):
await proc.generate_content("sys", "prompt")
class TestClaudeRateLimit(IsolatedAsyncioTestCase):
"""Claude/Anthropic: anthropic.RateLimitError → TooManyRequests"""
@patch('trustgraph.model.text_completion.claude.llm.anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_anthropic):
from trustgraph.model.text_completion.claude.llm import Processor
mock_client = MagicMock()
mock_anthropic.Anthropic.return_value = mock_client
proc = Processor(
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
)
mock_anthropic.RateLimitError = type("RateLimitError", (Exception,), {})
mock_client.messages.create.side_effect = mock_anthropic.RateLimitError(
"rate limited"
)
with pytest.raises(TooManyRequests):
await proc.generate_content("sys", "prompt")
class TestCohereRateLimit(IsolatedAsyncioTestCase):
"""Cohere: cohere.TooManyRequestsError → TooManyRequests"""
@patch('trustgraph.model.text_completion.cohere.llm.cohere')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cohere):
from trustgraph.model.text_completion.cohere.llm import Processor
mock_client = MagicMock()
mock_cohere.Client.return_value = mock_client
proc = Processor(
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
)
mock_cohere.TooManyRequestsError = type(
"TooManyRequestsError", (Exception,), {}
)
mock_client.chat.side_effect = mock_cohere.TooManyRequestsError(
"rate limited"
)
with pytest.raises(TooManyRequests):
await proc.generate_content("sys", "prompt")
class TestClientSideRateLimitTranslation:
"""Client base class: error type 'too-many-requests' → TooManyRequests"""
def test_error_type_mapping(self):
"""The wire format error type string is 'too-many-requests'."""
from trustgraph.schema import Error
err = Error(type="too-many-requests", message="slow down")
assert err.type == "too-many-requests"