Merge remote-tracking branch 'origin/master' into ts-port

This commit is contained in:
elpresidank 2026-04-11 23:56:34 -05:00
commit f4d6e49217
270 changed files with 19608 additions and 4096 deletions

View file

@ -87,10 +87,11 @@ def sample_message_data():
"history": []
},
"AgentResponse": {
"answer": "Machine learning is a subset of AI.",
"chunk_type": "answer",
"content": "Machine learning is a subset of AI.",
"end_of_message": True,
"end_of_dialog": True,
"error": None,
"thought": "I need to provide information about machine learning.",
"observation": None
},
"Metadata": {
"id": "test-doc-123",

View file

@ -38,7 +38,7 @@ class TestDocumentEmbeddingsRequestContract:
assert request.user == "test_user"
assert request.collection == "test_collection"
def test_request_translator_to_pulsar(self):
def test_request_translator_decode(self):
"""Test request translator converts dict to Pulsar schema"""
translator = DocumentEmbeddingsRequestTranslator()
@ -49,7 +49,7 @@ class TestDocumentEmbeddingsRequestContract:
"collection": "custom_collection"
}
result = translator.to_pulsar(data)
result = translator.decode(data)
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vector == [0.1, 0.2, 0.3, 0.4]
@ -57,7 +57,7 @@ class TestDocumentEmbeddingsRequestContract:
assert result.user == "custom_user"
assert result.collection == "custom_collection"
def test_request_translator_to_pulsar_with_defaults(self):
def test_request_translator_decode_with_defaults(self):
"""Test request translator uses correct defaults"""
translator = DocumentEmbeddingsRequestTranslator()
@ -66,7 +66,7 @@ class TestDocumentEmbeddingsRequestContract:
# No limit, user, or collection provided
}
result = translator.to_pulsar(data)
result = translator.decode(data)
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vector == [0.1, 0.2]
@ -74,7 +74,7 @@ class TestDocumentEmbeddingsRequestContract:
assert result.user == "trustgraph" # Default
assert result.collection == "default" # Default
def test_request_translator_from_pulsar(self):
def test_request_translator_encode(self):
"""Test request translator converts Pulsar schema to dict"""
translator = DocumentEmbeddingsRequestTranslator()
@ -85,7 +85,7 @@ class TestDocumentEmbeddingsRequestContract:
collection="test_collection"
)
result = translator.from_pulsar(request)
result = translator.encode(request)
assert isinstance(result, dict)
assert result["vector"] == [0.5, 0.6]
@ -134,7 +134,7 @@ class TestDocumentEmbeddingsResponseContract:
assert response.error == error
assert response.chunks == []
def test_response_translator_from_pulsar_with_chunks(self):
def test_response_translator_encode_with_chunks(self):
"""Test response translator converts Pulsar schema with chunks to dict"""
translator = DocumentEmbeddingsResponseTranslator()
@ -147,7 +147,7 @@ class TestDocumentEmbeddingsResponseContract:
]
)
result = translator.from_pulsar(response)
result = translator.encode(response)
assert isinstance(result, dict)
assert "chunks" in result
@ -155,7 +155,7 @@ class TestDocumentEmbeddingsResponseContract:
assert result["chunks"][0]["chunk_id"] == "doc1/c1"
assert result["chunks"][0]["score"] == 0.95
def test_response_translator_from_pulsar_with_empty_chunks(self):
def test_response_translator_encode_with_empty_chunks(self):
"""Test response translator handles empty chunks list"""
translator = DocumentEmbeddingsResponseTranslator()
@ -164,25 +164,25 @@ class TestDocumentEmbeddingsResponseContract:
chunks=[]
)
result = translator.from_pulsar(response)
result = translator.encode(response)
assert isinstance(result, dict)
assert "chunks" in result
assert result["chunks"] == []
def test_response_translator_from_pulsar_with_none_chunks(self):
def test_response_translator_encode_with_none_chunks(self):
"""Test response translator handles None chunks"""
translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock()
response.chunks = None
result = translator.from_pulsar(response)
result = translator.encode(response)
assert isinstance(result, dict)
assert "chunks" not in result or result.get("chunks") is None
def test_response_translator_from_response_with_completion(self):
def test_response_translator_encode_with_completion(self):
"""Test response translator with completion flag"""
translator = DocumentEmbeddingsResponseTranslator()
@ -194,7 +194,7 @@ class TestDocumentEmbeddingsResponseContract:
]
)
result, is_final = translator.from_response_with_completion(response)
result, is_final = translator.encode_with_completion(response)
assert isinstance(result, dict)
assert "chunks" in result
@ -202,12 +202,12 @@ class TestDocumentEmbeddingsResponseContract:
assert result["chunks"][0]["chunk_id"] == "chunk1"
assert is_final is True # Document embeddings responses are always final
def test_response_translator_to_pulsar_not_implemented(self):
"""Test that to_pulsar raises NotImplementedError for responses"""
def test_response_translator_decode_not_implemented(self):
"""Test that decode raises NotImplementedError for responses"""
translator = DocumentEmbeddingsResponseTranslator()
with pytest.raises(NotImplementedError):
translator.to_pulsar({"chunks": [{"chunk_id": "test", "score": 0.9}]})
translator.decode({"chunks": [{"chunk_id": "test", "score": 0.9}]})
class TestDocumentEmbeddingsMessageCompatibility:
@ -225,7 +225,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Convert to Pulsar request
req_translator = DocumentEmbeddingsRequestTranslator()
pulsar_request = req_translator.to_pulsar(request_data)
pulsar_request = req_translator.decode(request_data)
# Simulate service processing and creating response
response = DocumentEmbeddingsResponse(
@ -238,7 +238,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Convert response back to dict
resp_translator = DocumentEmbeddingsResponseTranslator()
response_data = resp_translator.from_pulsar(response)
response_data = resp_translator.encode(response)
# Verify data integrity
assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
@ -261,7 +261,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Convert response to dict
translator = DocumentEmbeddingsResponseTranslator()
response_data = translator.from_pulsar(response)
response_data = translator.encode(response)
# Verify error handling
assert isinstance(response_data, dict)

View file

@ -212,10 +212,11 @@ class TestAgentMessageContracts:
# Test required fields
response = AgentResponse(**response_data)
assert hasattr(response, 'answer')
assert hasattr(response, 'chunk_type')
assert hasattr(response, 'content')
assert hasattr(response, 'end_of_message')
assert hasattr(response, 'end_of_dialog')
assert hasattr(response, 'error')
assert hasattr(response, 'thought')
assert hasattr(response, 'observation')
def test_agent_step_schema_contract(self):
"""Test AgentStep schema contract"""

View file

@ -0,0 +1,177 @@
"""
Contract tests for orchestrator message schemas.
Verifies that AgentRequest/AgentStep with orchestration fields
serialise and deserialise correctly through the Pulsar schema layer.
"""
import pytest
import json
from trustgraph.schema import AgentRequest, AgentStep, PlanStep
@pytest.mark.contract
class TestOrchestrationFieldContracts:
    """Contract checks for the orchestration fields carried by AgentRequest."""

    def test_agent_request_orchestration_fields_roundtrip(self):
        """Orchestration values passed to the constructor read back unchanged."""
        orchestration = {
            "correlation_id": "corr-123",
            "parent_session_id": "parent-sess",
            "subagent_goal": "What is X?",
            "expected_siblings": 4,
            "pattern": "react",
            "task_type": "research",
            "framing": "Focus on accuracy",
            "conversation_id": "conv-456",
        }

        request = AgentRequest(
            question="Test question",
            user="testuser",
            collection="default",
            **orchestration,
        )

        for field_name, expected in orchestration.items():
            assert getattr(request, field_name) == expected

    def test_agent_request_orchestration_fields_default_empty(self):
        """Orchestration fields left unset are expected to hold empty defaults."""
        request = AgentRequest(question="Test question", user="testuser")

        # expected_siblings is the one numeric field checked here; the
        # remaining fields are compared against the empty string.
        expected_defaults = {
            "correlation_id": "",
            "parent_session_id": "",
            "subagent_goal": "",
            "expected_siblings": 0,
            "pattern": "",
            "task_type": "",
            "framing": "",
        }

        for field_name, default in expected_defaults.items():
            assert getattr(request, field_name) == default
@pytest.mark.contract
class TestSubagentCompletionStepContract:
    """Contract checks for AgentStep objects of type "subagent-completion"."""

    def test_subagent_completion_step_fields(self):
        """A subagent-completion step exposes the values it was built with."""
        completion = AgentStep(
            thought="Subagent completed",
            action="complete",
            arguments={},
            observation="The answer text",
            step_type="subagent-completion",
        )

        assert completion.step_type == "subagent-completion"
        assert completion.observation == "The answer text"
        assert completion.thought == "Subagent completed"
        assert completion.action == "complete"

    def test_subagent_completion_in_request_history(self):
        """A subagent-completion step placed in a request's history is kept intact."""
        completion = AgentStep(
            thought="Subagent completed",
            action="complete",
            arguments={},
            observation="answer",
            step_type="subagent-completion",
        )

        request = AgentRequest(
            question="goal",
            user="testuser",
            correlation_id="corr-123",
            history=[completion],
        )

        assert len(request.history) == 1
        recorded = request.history[0]
        assert recorded.step_type == "subagent-completion"
        assert recorded.observation == "answer"
@pytest.mark.contract
class TestSynthesisStepContract:
    """Contract checks for "synthesise" steps that carry subagent_results."""

    @staticmethod
    def _make_synthesis_step(results):
        """Build a synthesise step from a goal -> answer mapping.

        The mapping is stored both as the JSON-encoded observation and as
        the step's subagent_results.
        """
        return AgentStep(
            thought="All subagents completed",
            action="aggregate",
            arguments={},
            observation=json.dumps(results),
            step_type="synthesise",
            subagent_results=results,
        )

    def test_synthesis_step_with_results(self):
        """The step keeps its results and a matching JSON observation."""
        results = {"goal-a": "answer-a", "goal-b": "answer-b"}

        synthesis = self._make_synthesis_step(results)

        assert synthesis.step_type == "synthesise"
        assert synthesis.subagent_results == results
        assert json.loads(synthesis.observation) == results

    def test_synthesis_request_matches_supervisor_expectations(self):
        """The synthesis request built by the aggregator must be
        recognisable by SupervisorPattern._synthesise()."""
        results = {"goal-a": "answer-a", "goal-b": "answer-b"}

        request = AgentRequest(
            question="Original question",
            user="testuser",
            pattern="supervisor",
            correlation_id="",
            session_id="parent-sess",
            history=[self._make_synthesis_step(results)],
        )

        # SupervisorPattern checks for step_type='synthesise' with
        # subagent_results
        synthesis_steps = [
            entry
            for entry in (request.history or [])
            if getattr(entry, 'step_type', '') == 'synthesise'
            and getattr(entry, 'subagent_results', None)
        ]
        assert synthesis_steps

        # Pattern must be supervisor
        assert request.pattern == "supervisor"

        # Correlation ID must be empty (not re-intercepted)
        assert request.correlation_id == ""
@pytest.mark.contract
class TestPlanStepContract:
    """Contract checks for AgentStep objects that carry a plan."""

    def test_plan_step_in_history(self):
        """A plan step keeps its PlanStep entries, in order, with their fields."""
        finished = PlanStep(
            goal="Step 1",
            tool_hint="knowledge-query",
            depends_on=[],
            status="completed",
            result="done",
        )
        waiting = PlanStep(
            goal="Step 2",
            tool_hint="",
            depends_on=[0],
            status="pending",
            result="",
        )

        planning = AgentStep(
            thought="Created plan",
            action="plan",
            step_type="plan",
            plan=[finished, waiting],
        )

        assert planning.step_type == "plan"
        assert len(planning.plan) == 2

        first, second = planning.plan[0], planning.plan[1]
        assert first.goal == "Step 1"
        assert first.status == "completed"
        assert second.depends_on == [0]

View file

@ -0,0 +1,129 @@
"""
Contract tests for provenance triple wire format — verifies that triples
built by the provenance library can be parsed by the explainability API
through the wire format conversion.
"""
import pytest
from trustgraph.schema import IRI, LITERAL
from trustgraph.provenance import (
agent_decomposition_triples,
agent_finding_triples,
agent_plan_triples,
agent_step_result_triples,
agent_synthesis_triples,
)
from trustgraph.api.explainability import (
ExplainEntity,
Decomposition,
Finding,
Plan,
StepResult,
Synthesis,
wire_triples_to_tuples,
)
def _triples_to_wire(triples):
    """Map provenance Triple objects onto wire-format dicts.

    Each triple becomes a dict with "s", "p" and "o" keys holding the
    wire form of the corresponding term, i.e. the shape the
    gateway/socket client would produce.
    """
    return [
        {
            "s": _term_to_wire(triple.s),
            "p": _term_to_wire(triple.p),
            "o": _term_to_wire(triple.o),
        }
        for triple in triples
    ]
def _term_to_wire(term):
    """Return the wire-format dict for a single Term.

    IRI terms are emitted as {"t": "i", "i": <iri>}; LITERAL terms as
    {"t": "l", "v": <value>}.
    """
    kind = term.type

    if kind == IRI:
        return {"t": "i", "i": term.iri}

    if kind == LITERAL:
        return {"t": "l", "v": term.value}

    # Any other term type is sent as a literal holding str(term).
    return {"t": "l", "v": str(term)}
def _roundtrip(triples, uri):
    """Push triples through the wire format and parse the entity at *uri*.

    The triples are converted to wire dicts, turned back into tuples by
    wire_triples_to_tuples, and handed to ExplainEntity.from_triples.
    """
    as_tuples = wire_triples_to_tuples(_triples_to_wire(triples))
    return ExplainEntity.from_triples(uri, as_tuples)
@pytest.mark.contract
class TestDecompositionWireFormat:
    """Wire-format round trip for decomposition triples."""

    def test_roundtrip(self):
        """Decomposition triples parse back to a Decomposition with the same goals."""
        uri = "urn:decompose"
        goals = ["What is X?", "What is Y?"]
        expected_goals = set(goals)

        entity = _roundtrip(
            agent_decomposition_triples(uri, "urn:session", goals),
            uri,
        )

        assert isinstance(entity, Decomposition)
        # Goal order is not part of the check, so compare as sets.
        assert set(entity.goals) == expected_goals
@pytest.mark.contract
class TestFindingWireFormat:
    """Wire-format round trip for finding triples."""

    def test_roundtrip(self):
        """Finding triples parse back to a Finding with its goal and document."""
        uri = "urn:finding"
        document = "urn:doc/finding"

        produced = agent_finding_triples(
            uri, "urn:decompose", "What is X?",
            document_id=document,
        )
        entity = _roundtrip(produced, uri)

        assert isinstance(entity, Finding)
        assert entity.goal == "What is X?"
        assert entity.document == document
@pytest.mark.contract
class TestPlanWireFormat:
    """Wire-format round trip for plan triples."""

    def test_roundtrip(self):
        """Plan triples parse back to a Plan containing the same steps."""
        uri = "urn:plan"
        steps = ["Step 1", "Step 2", "Step 3"]
        expected_steps = set(steps)

        entity = _roundtrip(
            agent_plan_triples(uri, "urn:session", steps),
            uri,
        )

        assert isinstance(entity, Plan)
        # Step order is not part of the check, so compare as sets.
        assert set(entity.steps) == expected_steps
@pytest.mark.contract
class TestStepResultWireFormat:
    """Wire-format round trip for step-result triples."""

    def test_roundtrip(self):
        """Step-result triples parse back to a StepResult with its step and document."""
        uri = "urn:step"
        document = "urn:doc/step"

        produced = agent_step_result_triples(
            uri, "urn:plan", "Define X",
            document_id=document,
        )
        entity = _roundtrip(produced, uri)

        assert isinstance(entity, StepResult)
        assert entity.step == "Define X"
        assert entity.document == document
@pytest.mark.contract
class TestSynthesisWireFormat:
    """Wire-format round trip for synthesis triples."""

    def test_roundtrip(self):
        """Synthesis triples parse back to a Synthesis pointing at its document."""
        uri = "urn:synthesis"
        document = "urn:doc/synthesis"

        produced = agent_synthesis_triples(
            uri, "urn:previous",
            document_id=document,
        )
        entity = _roundtrip(produced, uri)

        assert isinstance(entity, Synthesis)
        assert entity.document == document

View file

@ -33,7 +33,7 @@ class TestRAGTranslatorCompletionFlags:
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is True, "is_final must be True when end_of_session=True"
@ -57,7 +57,7 @@ class TestRAGTranslatorCompletionFlags:
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is False, "is_final must be False when end_of_session=False"
@ -80,7 +80,7 @@ class TestRAGTranslatorCompletionFlags:
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is False
@ -103,7 +103,7 @@ class TestRAGTranslatorCompletionFlags:
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
@ -125,7 +125,7 @@ class TestRAGTranslatorCompletionFlags:
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is True, "is_final must be True when end_of_session=True"
@ -147,7 +147,7 @@ class TestRAGTranslatorCompletionFlags:
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
@ -168,7 +168,7 @@ class TestRAGTranslatorCompletionFlags:
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is False, "is_final must be False when end_of_stream=False"
@ -188,20 +188,18 @@ class TestAgentTranslatorCompletionFlags:
# Arrange
translator = TranslatorRegistry.get_response_translator("agent")
response = AgentResponse(
answer="4",
error=None,
thought=None,
observation=None,
chunk_type="answer",
content="4",
end_of_message=True,
end_of_dialog=True
end_of_dialog=True,
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is True, "is_final must be True when end_of_dialog=True"
assert response_dict["answer"] == "4"
assert response_dict["content"] == "4"
assert response_dict["end_of_dialog"] is True
def test_agent_translator_is_final_with_end_of_dialog_false(self):
@ -212,44 +210,20 @@ class TestAgentTranslatorCompletionFlags:
# Arrange
translator = TranslatorRegistry.get_response_translator("agent")
response = AgentResponse(
answer=None,
error=None,
thought="I need to solve this.",
observation=None,
chunk_type="thought",
content="I need to solve this.",
end_of_message=True,
end_of_dialog=False
end_of_dialog=False,
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is False, "is_final must be False when end_of_dialog=False"
assert response_dict["thought"] == "I need to solve this."
assert response_dict["content"] == "I need to solve this."
assert response_dict["end_of_dialog"] is False
def test_agent_translator_is_final_fallback_with_answer(self):
"""
Test that AgentResponseTranslator returns is_final=True
when answer is present (fallback for legacy responses).
"""
# Arrange
translator = TranslatorRegistry.get_response_translator("agent")
# Legacy response without end_of_dialog flag
response = AgentResponse(
answer="4",
error=None,
thought=None,
observation=None
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is True, "is_final must be True when answer is present (legacy fallback)"
assert response_dict["answer"] == "4"
def test_agent_translator_intermediate_message_is_not_final(self):
"""
Test that intermediate messages (thought/observation) return is_final=False.
@ -259,32 +233,28 @@ class TestAgentTranslatorCompletionFlags:
# Test thought message
thought_response = AgentResponse(
answer=None,
error=None,
thought="Processing...",
observation=None,
chunk_type="thought",
content="Processing...",
end_of_message=True,
end_of_dialog=False
end_of_dialog=False,
)
# Act
thought_dict, thought_is_final = translator.from_response_with_completion(thought_response)
thought_dict, thought_is_final = translator.encode_with_completion(thought_response)
# Assert
assert thought_is_final is False, "Thought message must not be final"
# Test observation message
observation_response = AgentResponse(
answer=None,
error=None,
thought=None,
observation="Result found",
chunk_type="observation",
content="Result found",
end_of_message=True,
end_of_dialog=False
end_of_dialog=False,
)
# Act
obs_dict, obs_is_final = translator.from_response_with_completion(observation_response)
obs_dict, obs_is_final = translator.encode_with_completion(observation_response)
# Assert
assert obs_is_final is False, "Observation message must not be final"
@ -302,14 +272,10 @@ class TestAgentTranslatorCompletionFlags:
content="",
end_of_message=True,
end_of_dialog=True,
answer=None,
error=None,
thought=None,
observation=None
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is True, "Streaming format must use end_of_dialog for is_final"

View file

@ -9,7 +9,7 @@ Following the TEST_STRATEGY.md approach for integration testing.
import pytest
import json
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, ANY, patch
from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
@ -187,7 +187,7 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
# Verify tool was executed
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="default")
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="default", explain_callback=ANY, parent_uri=ANY)
@pytest.mark.asyncio
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
@ -272,7 +272,7 @@ Args: {{
# Verify correct service was called
if tool_name == "knowledge_query":
mock_flow_context("graph-rag-request").rag.assert_called_with("test question", collection="default")
mock_flow_context("graph-rag-request").rag.assert_called_with("test question", collection="default", explain_callback=ANY, parent_uri=ANY)
elif tool_name == "text_completion":
mock_flow_context("prompt-request").question.assert_called()
@ -726,7 +726,7 @@ Final Answer: {
# Assert
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with("What is AI?", collection="default")
graph_rag_client.rag.assert_called_once_with("What is AI?", collection="default", explain_callback=ANY, parent_uri=ANY)
@pytest.mark.asyncio
async def test_knowledge_query_with_custom_collection(self, mock_flow_context):
@ -739,7 +739,7 @@ Final Answer: {
# Assert
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="custom_collection")
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="custom_collection", explain_callback=ANY, parent_uri=ANY)
@pytest.mark.asyncio
async def test_knowledge_query_with_none_collection(self, mock_flow_context):
@ -752,7 +752,7 @@ Final Answer: {
# Assert
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with("Explain neural networks", collection="default")
graph_rag_client.rag.assert_called_once_with("Explain neural networks", collection="default", explain_callback=ANY, parent_uri=ANY)
@pytest.mark.asyncio
async def test_agent_manager_knowledge_query_collection_integration(self, mock_flow_context):
@ -810,7 +810,7 @@ Args: {
# Verify the custom collection was used
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with("Latest AI research?", collection="research_papers")
graph_rag_client.rag.assert_called_once_with("Latest AI research?", collection="research_papers", explain_callback=ANY, parent_uri=ANY)
@pytest.mark.asyncio
async def test_knowledge_query_multiple_collections(self, mock_flow_context):
@ -840,4 +840,4 @@ Args: {
# Verify correct collection was used
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with(question, collection=expected_collection)
graph_rag_client.rag.assert_called_once_with(question, collection=expected_collection, explain_callback=ANY, parent_uri=ANY)

View file

@ -39,7 +39,7 @@ class TestAgentServiceNonStreaming:
mock_agent_manager_class.return_value = mock_agent_instance
# Mock react to call think and observe callbacks
async def mock_react(question, history, think, observe, answer, context, streaming):
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
await think("I need to solve this.", is_final=True)
await observe("The answer is 4.", is_final=True)
return Final(thought="Final answer", final="4")
@ -76,22 +76,33 @@ class TestAgentServiceNonStreaming:
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: should have 3 responses (thought, observation, answer)
assert len(sent_responses) == 3, f"Expected 3 responses, got {len(sent_responses)}"
# Filter out explain events — those are always sent now
content_responses = [
r for r in sent_responses if r.chunk_type != "explain"
]
explain_responses = [
r for r in sent_responses if r.chunk_type == "explain"
]
# Should have explain events for session, iteration, observation, and final
assert len(explain_responses) >= 1, "Expected at least 1 explain event"
# Should have 3 content responses (thought, observation, answer)
assert len(content_responses) == 3, f"Expected 3 content responses, got {len(content_responses)}"
# Check thought message
thought_response = sent_responses[0]
thought_response = content_responses[0]
assert isinstance(thought_response, AgentResponse)
assert thought_response.thought == "I need to solve this."
assert thought_response.answer is None
assert thought_response.chunk_type == "thought"
assert thought_response.content == "I need to solve this."
assert thought_response.end_of_message is True, "Thought message must have end_of_message=True"
assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False"
# Check observation message
observation_response = sent_responses[1]
observation_response = content_responses[1]
assert isinstance(observation_response, AgentResponse)
assert observation_response.observation == "The answer is 4."
assert observation_response.answer is None
assert observation_response.chunk_type == "observation"
assert observation_response.content == "The answer is 4."
assert observation_response.end_of_message is True, "Observation message must have end_of_message=True"
assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False"
@ -120,7 +131,7 @@ class TestAgentServiceNonStreaming:
mock_agent_manager_class.return_value = mock_agent_instance
# Mock react to return Final directly
async def mock_react(question, history, think, observe, answer, context, streaming):
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
return Final(thought="Final answer", final="4")
mock_agent_instance.react = mock_react
@ -155,15 +166,25 @@ class TestAgentServiceNonStreaming:
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: should have 1 response (final answer)
assert len(sent_responses) == 1, f"Expected 1 response, got {len(sent_responses)}"
# Filter out explain events — those are always sent now
content_responses = [
r for r in sent_responses if r.chunk_type != "explain"
]
explain_responses = [
r for r in sent_responses if r.chunk_type == "explain"
]
# Should have explain events for session and final
assert len(explain_responses) >= 1, "Expected at least 1 explain event"
# Should have 1 content response (final answer)
assert len(content_responses) == 1, f"Expected 1 content response, got {len(content_responses)}"
# Check final answer message
answer_response = sent_responses[0]
answer_response = content_responses[0]
assert isinstance(answer_response, AgentResponse)
assert answer_response.answer == "4"
assert answer_response.thought is None
assert answer_response.observation is None
assert answer_response.chunk_type == "answer"
assert answer_response.content == "4"
assert answer_response.end_of_message is True, "Final answer must have end_of_message=True"
assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True"

View file

@ -0,0 +1,216 @@
"""
Unit tests for the Aggregator — tracks fan-out correlations and triggers
synthesis when all subagents complete.
"""
import time
import pytest
from trustgraph.schema import AgentRequest, AgentStep
from trustgraph.agent.orchestrator.aggregator import Aggregator
def _make_request(question="Test question", user="testuser",
                  collection="default", streaming=False,
                  session_id="parent-session", task_type="research",
                  framing="test framing", conversation_id="conv-1"):
    """Build an AgentRequest for tests.

    Every argument has a test default and is forwarded to AgentRequest
    under the same keyword, so callers override only what they need.
    """
    fields = {
        "question": question,
        "user": user,
        "collection": collection,
        "streaming": streaming,
        "session_id": session_id,
        "task_type": task_type,
        "framing": framing,
        "conversation_id": conversation_id,
    }
    return AgentRequest(**fields)
class TestRegisterFanout:
    """Checks on what register_fanout() records in Aggregator.correlations."""

    def test_stores_correlation_entry(self):
        """Registering a fan-out creates an entry with parent, count and empty results."""
        aggregator = Aggregator()

        aggregator.register_fanout("corr-1", "parent-1", 3)

        assert "corr-1" in aggregator.correlations
        recorded = aggregator.correlations["corr-1"]
        assert recorded["parent_session_id"] == "parent-1"
        assert recorded["expected"] == 3
        assert recorded["results"] == {}

    def test_stores_request_template(self):
        """The request template is stored by identity, not copied."""
        aggregator = Aggregator()
        original = _make_request()

        aggregator.register_fanout(
            "corr-1", "parent-1", 2,
            request_template=original,
        )

        assert aggregator.correlations["corr-1"]["request_template"] is original

    def test_records_creation_time(self):
        """created_at falls between timestamps taken around the call."""
        aggregator = Aggregator()

        earliest = time.time()
        aggregator.register_fanout("corr-1", "parent-1", 2)
        latest = time.time()

        stamped = aggregator.correlations["corr-1"]["created_at"]
        assert earliest <= stamped <= latest
class TestRecordCompletion:
    """Checks on record_completion() return values and stored results."""

    def test_returns_false_until_all_done(self):
        """Only the completion that reaches the expected count returns True."""
        aggregator = Aggregator()
        aggregator.register_fanout("corr-1", "parent-1", 3)

        completions = (
            ("goal-a", "answer-a", False),
            ("goal-b", "answer-b", False),
            ("goal-c", "answer-c", True),
        )

        for goal, answer, all_done in completions:
            # Identity check: the return value must be the bool itself.
            assert aggregator.record_completion("corr-1", goal, answer) is all_done

    def test_returns_none_for_unknown_correlation(self):
        """An unregistered correlation id yields None."""
        aggregator = Aggregator()

        outcome = aggregator.record_completion("unknown", "goal", "answer")

        assert outcome is None

    def test_stores_results_by_goal(self):
        """Answers are stored in the entry's results under their goal."""
        aggregator = Aggregator()
        aggregator.register_fanout("corr-1", "parent-1", 2)

        aggregator.record_completion("corr-1", "goal-a", "answer-a")
        aggregator.record_completion("corr-1", "goal-b", "answer-b")

        stored = aggregator.correlations["corr-1"]["results"]
        assert stored["goal-a"] == "answer-a"
        assert stored["goal-b"] == "answer-b"

    def test_single_subagent(self):
        """With one expected subagent the first completion returns True."""
        aggregator = Aggregator()
        aggregator.register_fanout("corr-1", "parent-1", 1)

        assert aggregator.record_completion("corr-1", "goal-a", "answer") is True
class TestGetOriginalRequest:
    """Checks on get_original_request() lookups."""

    def test_peeks_without_consuming(self):
        """The stored template is returned and the entry is left in place."""
        aggregator = Aggregator()
        original = _make_request()
        aggregator.register_fanout(
            "corr-1", "parent-1", 2,
            request_template=original,
        )

        fetched = aggregator.get_original_request("corr-1")

        assert fetched is original
        # Entry still exists
        assert "corr-1" in aggregator.correlations

    def test_returns_none_for_unknown(self):
        """An unregistered correlation id yields None."""
        assert Aggregator().get_original_request("unknown") is None
class TestBuildSynthesisRequest:
    """Checks on the request produced by build_synthesis_request()."""

    @staticmethod
    def _completed_fanout(template, results):
        """Return an Aggregator whose "corr-1" fan-out has every result recorded.

        The fan-out is registered under parent "parent-1" expecting
        len(results) subagents, then each goal/answer pair is recorded in
        mapping order.
        """
        aggregator = Aggregator()
        aggregator.register_fanout(
            "corr-1", "parent-1", len(results),
            request_template=template,
        )
        for goal, answer in results.items():
            aggregator.record_completion("corr-1", goal, answer)
        return aggregator

    def test_builds_correct_request(self):
        """The synthesis request carries the question, session and template fields."""
        template = _make_request(
            question="Original question",
            streaming=True,
            task_type="risk-assessment",
            framing="Assess risks",
        )
        aggregator = self._completed_fanout(
            template,
            {"goal-a": "answer-a", "goal-b": "answer-b"},
        )

        synthesis = aggregator.build_synthesis_request(
            "corr-1",
            original_question="Original question",
            user="testuser",
            collection="default",
        )

        assert synthesis.question == "Original question"
        assert synthesis.pattern == "supervisor"
        assert synthesis.session_id == "parent-1"
        assert synthesis.correlation_id == ""  # Must be empty
        assert synthesis.streaming == True
        assert synthesis.task_type == "risk-assessment"
        assert synthesis.framing == "Assess risks"

    def test_synthesis_step_in_history(self):
        """The final history step is a synthesise step holding all results."""
        aggregator = self._completed_fanout(
            _make_request(),
            {"goal-a": "answer-a", "goal-b": "answer-b"},
        )

        synthesis = aggregator.build_synthesis_request(
            "corr-1", "question", "user", "default",
        )

        # Last history step should be the synthesis step
        assert len(synthesis.history) >= 1
        final_step = synthesis.history[-1]
        assert final_step.step_type == "synthesise"
        assert final_step.subagent_results == {
            "goal-a": "answer-a",
            "goal-b": "answer-b",
        }

    def test_consumes_correlation_entry(self):
        """Building the synthesis request removes the correlation entry."""
        aggregator = self._completed_fanout(
            _make_request(),
            {"goal-a": "answer-a"},
        )

        aggregator.build_synthesis_request(
            "corr-1", "question", "user", "default",
        )

        # Entry should be removed
        assert "corr-1" not in aggregator.correlations

    def test_raises_for_unknown_correlation(self):
        """An unregistered correlation id raises RuntimeError matching "No results"."""
        aggregator = Aggregator()

        with pytest.raises(RuntimeError, match="No results"):
            aggregator.build_synthesis_request(
                "unknown", "question", "user", "default",
            )
class TestCleanupStale:
    """Checks for Aggregator.cleanup_stale()."""

    def test_removes_entries_older_than_timeout(self):
        """A backdated entry is reported as stale and dropped."""
        aggregator = Aggregator(timeout=1)
        aggregator.register_fanout("corr-1", "parent-1", 2)
        # Backdate the creation time
        aggregator.correlations["corr-1"]["created_at"] = time.time() - 2

        expired = aggregator.cleanup_stale()

        assert "corr-1" in expired
        assert "corr-1" not in aggregator.correlations

    def test_keeps_recent_entries(self):
        """A freshly registered entry survives cleanup."""
        aggregator = Aggregator(timeout=300)
        aggregator.register_fanout("corr-1", "parent-1", 2)

        expired = aggregator.cleanup_stale()

        assert expired == []
        assert "corr-1" in aggregator.correlations

    def test_mixed_stale_and_fresh(self):
        """Only the backdated entry is removed when both kinds are present."""
        aggregator = Aggregator(timeout=1)
        for corr_id, parent in (("stale", "parent-1"), ("fresh", "parent-2")):
            aggregator.register_fanout(corr_id, parent, 2)
        aggregator.correlations["stale"]["created_at"] = time.time() - 2

        expired = aggregator.cleanup_stale()

        assert "stale" in expired
        assert "stale" not in aggregator.correlations
        assert "fresh" in aggregator.correlations

View file

@ -0,0 +1,122 @@
"""
Tests that streaming callbacks set message_id on AgentResponse.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.orchestrator.pattern_base import PatternBase
from trustgraph.schema import AgentResponse
@pytest.fixture
def pattern():
    """Provide a PatternBase built around a MagicMock processor."""
    return PatternBase(MagicMock())
class TestThinkCallbackMessageId:
    """The think callback stamps message_id on the responses it sends."""

    @pytest.mark.asyncio
    async def test_streaming_think_has_message_id(self, pattern):
        """A streamed thought chunk carries the id and the thought type."""
        captured = []

        async def sink(response):
            captured.append(response)

        expected_id = "urn:trustgraph:agent:sess/i1/thought"
        think = pattern.make_think_callback(
            sink, streaming=True, message_id=expected_id,
        )
        await think("hello", is_final=False)

        assert len(captured) == 1
        only = captured[0]
        assert only.message_id == expected_id
        assert only.chunk_type == "thought"

    @pytest.mark.asyncio
    async def test_non_streaming_think_has_message_id(self, pattern):
        """A non-streamed thought carries the id and ends the message."""
        captured = []

        async def sink(response):
            captured.append(response)

        expected_id = "urn:trustgraph:agent:sess/i1/thought"
        think = pattern.make_think_callback(
            sink, streaming=False, message_id=expected_id,
        )
        await think("hello")

        first = captured[0]
        assert first.message_id == expected_id
        assert first.end_of_message is True
class TestObserveCallbackMessageId:
    """The observe callback stamps message_id on the responses it sends."""

    @pytest.mark.asyncio
    async def test_streaming_observe_has_message_id(self, pattern):
        """A streamed observation carries the id and the observation type."""
        captured = []

        async def sink(response):
            captured.append(response)

        expected_id = "urn:trustgraph:agent:sess/i1/observation"
        observe = pattern.make_observe_callback(
            sink, streaming=True, message_id=expected_id,
        )
        await observe("result", is_final=True)

        first = captured[0]
        assert first.message_id == expected_id
        assert first.chunk_type == "observation"
class TestAnswerCallbackMessageId:
    """The answer callback stamps message_id on the responses it sends."""

    @pytest.mark.asyncio
    async def test_streaming_answer_has_message_id(self, pattern):
        """A streamed answer chunk carries the id and the answer type."""
        captured = []

        async def sink(response):
            captured.append(response)

        expected_id = "urn:trustgraph:agent:sess/final"
        answer = pattern.make_answer_callback(
            sink, streaming=True, message_id=expected_id,
        )
        await answer("the answer")

        first = captured[0]
        assert first.message_id == expected_id
        assert first.chunk_type == "answer"

    @pytest.mark.asyncio
    async def test_no_message_id_default(self, pattern):
        """Without an explicit message_id the response carries an empty one."""
        captured = []

        async def sink(response):
            captured.append(response)

        answer = pattern.make_answer_callback(sink, streaming=True)
        await answer("the answer")

        assert captured[0].message_id == ""
class TestSendFinalResponseMessageId:
    """send_final_response() stamps message_id on everything it sends."""

    @pytest.mark.asyncio
    async def test_streaming_final_has_message_id(self, pattern):
        """Every response sent for a streamed final answer carries the id."""
        responses = []

        async def capture(r):
            responses.append(r)

        msg_id = "urn:trustgraph:agent:sess/final"
        await pattern.send_final_response(
            capture, streaming=True, answer_text="answer",
            message_id=msg_id,
        )

        # Should get content chunk + end-of-dialog marker.  Guard against an
        # empty list first: all() over nothing would pass without checking
        # a single response.
        assert responses
        assert all(r.message_id == msg_id for r in responses)

    @pytest.mark.asyncio
    async def test_non_streaming_final_has_message_id(self, pattern):
        """A non-streamed final answer is one response that ends the dialog."""
        responses = []

        async def capture(r):
            responses.append(r)

        msg_id = "urn:trustgraph:agent:sess/final"
        await pattern.send_final_response(
            capture, streaming=False, answer_text="answer",
            message_id=msg_id,
        )

        assert len(responses) == 1
        assert responses[0].message_id == msg_id
        assert responses[0].end_of_dialog is True

View file

@ -0,0 +1,174 @@
"""
Unit tests for completion dispatch — verifies that agent_request() in the
orchestrator service correctly intercepts subagent completion messages and
routes them to _handle_subagent_completion.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.schema import AgentRequest, AgentStep
from trustgraph.agent.orchestrator.aggregator import Aggregator
def _make_request(**kwargs):
    """Build an AgentRequest from test defaults, overridden by ``kwargs``."""
    fields = {
        "question": "Test question",
        "user": "testuser",
        "collection": "default",
        **kwargs,
    }
    return AgentRequest(**fields)
def _make_completion_request(correlation_id, goal, answer):
    """Build a completion request as emit_subagent_completion would."""
    # The subagent's answer travels in the observation of a single
    # "subagent-completion" history step.
    completion_step = AgentStep(
        thought="Subagent completed",
        action="complete",
        arguments={},
        observation=answer,
        step_type="subagent-completion",
    )
    request_fields = {
        "correlation_id": correlation_id,
        "parent_session_id": "parent-sess",
        "subagent_goal": goal,
        "expected_siblings": 2,
        "history": [completion_step],
    }
    return _make_request(**request_fields)
class TestCompletionDetection:
    """Test that completion messages are correctly identified."""

    def test_is_completion_when_correlation_id_and_step_type(self):
        """A correlation id plus a completion step marks a completion."""
        request = _make_completion_request("corr-1", "goal-a", "answer-a")

        correlated = bool(getattr(request, 'correlation_id', ''))
        step_types = [getattr(h, 'step_type', '') for h in request.history]

        assert correlated
        assert 'subagent-completion' in step_types

    def test_not_completion_without_correlation_id(self):
        """An empty correlation id counts as no correlation."""
        completion_step = AgentStep(
            step_type="subagent-completion",
            observation="answer",
        )
        request = _make_request(
            correlation_id="",
            history=[completion_step],
        )

        assert not bool(getattr(request, 'correlation_id', ''))

    def test_not_completion_without_step_type(self):
        """A history holding only a "react" step is not a completion."""
        react_step = AgentStep(
            step_type="react",
            observation="answer",
        )
        request = _make_request(
            correlation_id="corr-1",
            history=[react_step],
        )

        step_types = [getattr(h, 'step_type', '') for h in request.history]
        assert 'subagent-completion' not in step_types

    def test_not_completion_with_empty_history(self):
        """A request built with an empty history has a falsy history."""
        request = _make_request(
            correlation_id="corr-1",
            history=[],
        )

        assert not request.history
class TestAggregatorIntegration:
    """Test the aggregator flow as used by _handle_subagent_completion."""

    def test_full_completion_flow(self):
        """Simulates the flow: register, record completions, build synthesis."""
        aggregator = Aggregator()
        original = _make_request(
            question="Original question",
            streaming=True,
            task_type="risk-assessment",
            framing="Assess risks",
            session_id="parent-sess",
        )

        # Register fan-out
        aggregator.register_fanout(
            "corr-1", "parent-sess", 2, request_template=original,
        )

        # First completion — not all done
        assert aggregator.record_completion(
            "corr-1", "goal-a", "answer-a",
        ) is False

        # Second completion — all done
        assert aggregator.record_completion(
            "corr-1", "goal-b", "answer-b",
        ) is True

        # Peek at template
        assert (
            aggregator.get_original_request("corr-1").question
            == "Original question"
        )

        # Build synthesis request
        synthesis = aggregator.build_synthesis_request(
            "corr-1",
            original_question="Original question",
            user="testuser",
            collection="default",
        )

        # Verify synthesis request
        assert synthesis.pattern == "supervisor"
        assert synthesis.correlation_id == ""
        assert synthesis.session_id == "parent-sess"
        assert synthesis.streaming is True

        # Verify synthesis history has results
        synthesise_steps = []
        for step in synthesis.history:
            if getattr(step, 'step_type', '') == 'synthesise':
                synthesise_steps.append(step)
        assert len(synthesise_steps) == 1
        assert synthesise_steps[0].subagent_results == {
            "goal-a": "answer-a",
            "goal-b": "answer-b",
        }

    def test_synthesis_request_not_detected_as_completion(self):
        """The synthesis request must not be intercepted as a completion."""
        aggregator = Aggregator()
        original = _make_request(session_id="parent-sess")
        aggregator.register_fanout(
            "corr-1", "parent-sess", 1, request_template=original,
        )
        aggregator.record_completion("corr-1", "goal", "answer")

        synthesis = aggregator.build_synthesis_request(
            "corr-1", "question", "user", "default",
        )

        # correlation_id must be empty so it's not intercepted
        assert synthesis.correlation_id == ""

        # Even if we check for completion step, shouldn't match
        step_types = [getattr(h, 'step_type', '') for h in synthesis.history]
        assert 'subagent-completion' not in step_types

View file

@ -0,0 +1,177 @@
"""
Unit tests for explainability API parsing — verifies that from_triples()
correctly dispatches and parses the new orchestrator entity types.
"""
import pytest
from trustgraph.api.explainability import (
ExplainEntity,
Decomposition,
Finding,
Plan,
StepResult,
Synthesis,
Analysis,
Observation,
Conclusion,
TG_DECOMPOSITION,
TG_FINDING,
TG_PLAN_TYPE,
TG_STEP_RESULT,
TG_SYNTHESIS,
TG_ANSWER_TYPE,
TG_OBSERVATION_TYPE,
TG_TOOL_USE,
TG_ANALYSIS,
TG_CONCLUSION,
TG_DOCUMENT,
TG_SUBAGENT_GOAL,
TG_PLAN_STEP,
RDF_TYPE,
)
# IRI of the W3C PROV "Entity" class; the dispatch fixtures in this module
# list it as an rdf:type next to the TrustGraph-specific type.
PROV_ENTITY = "http://www.w3.org/ns/prov#Entity"
def _make_triples(uri, types, extras=None):
    """Build a list of (s, p, o) tuples for testing.

    Each entry of ``types`` becomes an RDF_TYPE triple on ``uri``; each
    ``(predicate, object)`` pair in ``extras`` becomes one more triple on
    the same subject.
    """
    type_triples = [(uri, RDF_TYPE, rdf_type) for rdf_type in types]
    extra_triples = (
        [(uri, pred, obj) for pred, obj in extras] if extras else []
    )
    return type_triples + extra_triples
class TestFromTriplesDispatch:
    """ExplainEntity.from_triples() picks the entity class from rdf:types."""

    def test_dispatches_decomposition(self):
        entity = ExplainEntity.from_triples(
            "urn:d",
            _make_triples("urn:d", [PROV_ENTITY, TG_DECOMPOSITION]),
        )
        assert isinstance(entity, Decomposition)

    def test_dispatches_finding(self):
        type_iris = [PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE]
        entity = ExplainEntity.from_triples(
            "urn:f", _make_triples("urn:f", type_iris),
        )
        assert isinstance(entity, Finding)

    def test_dispatches_plan(self):
        entity = ExplainEntity.from_triples(
            "urn:p",
            _make_triples("urn:p", [PROV_ENTITY, TG_PLAN_TYPE]),
        )
        assert isinstance(entity, Plan)

    def test_dispatches_step_result(self):
        type_iris = [PROV_ENTITY, TG_STEP_RESULT, TG_ANSWER_TYPE]
        entity = ExplainEntity.from_triples(
            "urn:sr", _make_triples("urn:sr", type_iris),
        )
        assert isinstance(entity, StepResult)

    def test_dispatches_synthesis(self):
        type_iris = [PROV_ENTITY, TG_SYNTHESIS, TG_ANSWER_TYPE]
        entity = ExplainEntity.from_triples(
            "urn:s", _make_triples("urn:s", type_iris),
        )
        assert isinstance(entity, Synthesis)

    def test_dispatches_analysis_unchanged(self):
        entity = ExplainEntity.from_triples(
            "urn:a",
            _make_triples("urn:a", [PROV_ENTITY, TG_ANALYSIS]),
        )
        assert isinstance(entity, Analysis)

    def test_dispatches_analysis_with_tooluse(self):
        """Analysis+ToolUse mixin still dispatches to Analysis."""
        type_iris = [PROV_ENTITY, TG_ANALYSIS, TG_TOOL_USE]
        entity = ExplainEntity.from_triples(
            "urn:a", _make_triples("urn:a", type_iris),
        )
        assert isinstance(entity, Analysis)

    def test_dispatches_observation(self):
        entity = ExplainEntity.from_triples(
            "urn:o",
            _make_triples("urn:o", [PROV_ENTITY, TG_OBSERVATION_TYPE]),
        )
        assert isinstance(entity, Observation)

    def test_dispatches_conclusion_unchanged(self):
        type_iris = [PROV_ENTITY, TG_CONCLUSION, TG_ANSWER_TYPE]
        entity = ExplainEntity.from_triples(
            "urn:c", _make_triples("urn:c", type_iris),
        )
        assert isinstance(entity, Conclusion)

    def test_finding_takes_precedence_over_synthesis(self):
        """Finding has Answer mixin but should dispatch to Finding, not
        Synthesis, because Finding is checked first."""
        type_iris = [PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE]
        entity = ExplainEntity.from_triples(
            "urn:f", _make_triples("urn:f", type_iris),
        )
        assert isinstance(entity, Finding)
        assert not isinstance(entity, Synthesis)
class TestDecompositionParsing:
    """Decomposition.from_triples() field parsing."""

    def test_parses_goals(self):
        """Every TG_SUBAGENT_GOAL object ends up in ``goals``."""
        goal_pairs = [
            (TG_SUBAGENT_GOAL, "What is X?"),
            (TG_SUBAGENT_GOAL, "What is Y?"),
        ]
        entity = Decomposition.from_triples(
            "urn:d", _make_triples("urn:d", [TG_DECOMPOSITION], goal_pairs),
        )
        # Compared as a set: the test does not pin goal ordering.
        assert set(entity.goals) == {"What is X?", "What is Y?"}

    def test_entity_type_field(self):
        entity = Decomposition.from_triples(
            "urn:d", _make_triples("urn:d", [TG_DECOMPOSITION]),
        )
        assert entity.entity_type == "decomposition"

    def test_empty_goals(self):
        """With no goal triples, ``goals`` is an empty list."""
        entity = Decomposition.from_triples(
            "urn:d", _make_triples("urn:d", [TG_DECOMPOSITION]),
        )
        assert entity.goals == []
class TestFindingParsing:
    """Finding.from_triples() field parsing."""

    def test_parses_goal_and_document(self):
        """The goal and document objects are exposed as attributes."""
        detail_pairs = [
            (TG_SUBAGENT_GOAL, "What is X?"),
            (TG_DOCUMENT, "urn:doc/finding"),
        ]
        entity = Finding.from_triples(
            "urn:f",
            _make_triples("urn:f", [TG_FINDING, TG_ANSWER_TYPE], detail_pairs),
        )
        assert entity.goal == "What is X?"
        assert entity.document == "urn:doc/finding"

    def test_entity_type_field(self):
        entity = Finding.from_triples(
            "urn:f", _make_triples("urn:f", [TG_FINDING]),
        )
        assert entity.entity_type == "finding"
class TestPlanParsing:
    """Plan.from_triples() field parsing."""

    def test_parses_steps(self):
        """Every TG_PLAN_STEP object ends up in ``steps``."""
        step_pairs = [
            (TG_PLAN_STEP, "Define X"),
            (TG_PLAN_STEP, "Research Y"),
            (TG_PLAN_STEP, "Analyse Z"),
        ]
        entity = Plan.from_triples(
            "urn:p", _make_triples("urn:p", [TG_PLAN_TYPE], step_pairs),
        )
        # Compared as a set: the test does not pin step ordering.
        assert set(entity.steps) == {"Define X", "Research Y", "Analyse Z"}

    def test_entity_type_field(self):
        entity = Plan.from_triples(
            "urn:p", _make_triples("urn:p", [TG_PLAN_TYPE]),
        )
        assert entity.entity_type == "plan"
class TestStepResultParsing:
    """StepResult.from_triples() field parsing."""

    def test_parses_step_and_document(self):
        """The step and document objects are exposed as attributes."""
        detail_pairs = [
            (TG_PLAN_STEP, "Define X"),
            (TG_DOCUMENT, "urn:doc/step"),
        ]
        entity = StepResult.from_triples(
            "urn:sr",
            _make_triples(
                "urn:sr", [TG_STEP_RESULT, TG_ANSWER_TYPE], detail_pairs,
            ),
        )
        assert entity.step == "Define X"
        assert entity.document == "urn:doc/step"

    def test_entity_type_field(self):
        entity = StepResult.from_triples(
            "urn:sr", _make_triples("urn:sr", [TG_STEP_RESULT]),
        )
        assert entity.entity_type == "step-result"

View file

@ -0,0 +1,289 @@
"""
Unit tests for the MetaRouter task type identification and pattern selection.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.orchestrator.meta_router import (
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
)
def _make_config(patterns=None, task_types=None):
"""Build a config dict as the config service would provide."""
config = {}
if patterns:
config["agent-pattern"] = {
pid: json.dumps(pdata) for pid, pdata in patterns.items()
}
if task_types:
config["agent-task-type"] = {
tid: json.dumps(tdata) for tid, tdata in task_types.items()
}
return config
def _make_context(prompt_response):
"""Build a mock context that returns a mock prompt client."""
client = AsyncMock()
client.prompt = AsyncMock(return_value=prompt_response)
def context(service_name):
return client
return context
# Pattern definitions keyed by pattern id, in the shape passed to
# _make_config(patterns=...).
SAMPLE_PATTERNS = {
    "react": {
        "name": "react",
        "description": "ReAct pattern",
    },
    "plan-then-execute": {
        "name": "plan-then-execute",
        "description": "Plan pattern",
    },
    "supervisor": {
        "name": "supervisor",
        "description": "Supervisor pattern",
    },
}

# Task types keyed by id.  Each lists the pattern ids it allows, so the
# three entries cover three, two and one valid pattern respectively.
SAMPLE_TASK_TYPES = {
    "general": {
        "name": "general",
        "description": "General queries",
        "valid_patterns": ["react", "plan-then-execute", "supervisor"],
        "framing": "",
    },
    "research": {
        "name": "research",
        "description": "Research queries",
        "valid_patterns": ["react", "plan-then-execute"],
        "framing": "Focus on gathering information.",
    },
    "summarisation": {
        "name": "summarisation",
        "description": "Summarisation queries",
        "valid_patterns": ["react"],
        "framing": "Focus on concise synthesis.",
    },
}
class TestMetaRouterInit:
    """Construction of MetaRouter with and without config."""

    def test_defaults_when_no_config(self):
        """With no config, "react" and "general" are still available."""
        router = MetaRouter()
        assert "react" in router.patterns
        assert "general" in router.task_types

    def test_loads_patterns_from_config(self):
        router = MetaRouter(config=_make_config(patterns=SAMPLE_PATTERNS))
        expected_ids = {"react", "plan-then-execute", "supervisor"}
        assert set(router.patterns.keys()) == expected_ids

    def test_loads_task_types_from_config(self):
        router = MetaRouter(config=_make_config(task_types=SAMPLE_TASK_TYPES))
        expected_ids = {"general", "research", "summarisation"}
        assert set(router.task_types.keys()) == expected_ids

    def test_handles_invalid_json_in_config(self):
        """An undecodable entry still yields a pattern named after its key."""
        broken_config = {
            "agent-pattern": {"react": "not valid json"},
        }
        router = MetaRouter(config=broken_config)
        assert "react" in router.patterns
        assert router.patterns["react"]["name"] == "react"
class TestIdentifyTaskType:
    """Checks for MetaRouter.identify_task_type()."""

    @pytest.mark.asyncio
    async def test_skips_llm_when_single_task_type(self):
        """With one task type the router answers without prompting the LLM."""
        router = MetaRouter()  # Only "general"
        context = _make_context("should not be called")

        task_type, framing = await router.identify_task_type(
            "test question", context,
        )

        assert task_type == "general"
        # The returned task type alone cannot show that the LLM was
        # bypassed, so inspect the prompt mock directly.  _make_context
        # hands back the same client for any service name.
        context("prompt-request").prompt.assert_not_called()

    @pytest.mark.asyncio
    async def test_uses_llm_when_multiple_task_types(self):
        """The LLM's reply selects the task type and its framing."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        context = _make_context("research")

        task_type, framing = await router.identify_task_type(
            "Research the topic", context,
        )

        assert task_type == "research"
        assert framing == "Focus on gathering information."

    @pytest.mark.asyncio
    async def test_handles_llm_returning_quoted_type(self):
        """A reply wrapped in double quotes is still recognised."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        context = _make_context('"summarisation"')

        task_type, _ = await router.identify_task_type(
            "Summarise this", context,
        )

        assert task_type == "summarisation"

    @pytest.mark.asyncio
    async def test_falls_back_on_unknown_type(self):
        """A reply naming no configured type yields DEFAULT_TASK_TYPE."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        context = _make_context("nonexistent-type")

        task_type, _ = await router.identify_task_type(
            "test question", context,
        )

        assert task_type == DEFAULT_TASK_TYPE

    @pytest.mark.asyncio
    async def test_falls_back_on_llm_error(self):
        """An exception from the prompt client yields DEFAULT_TASK_TYPE."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        client = AsyncMock()
        client.prompt = AsyncMock(side_effect=RuntimeError("LLM down"))
        context = lambda name: client

        task_type, _ = await router.identify_task_type(
            "test question", context,
        )

        assert task_type == DEFAULT_TASK_TYPE
class TestSelectPattern:
    """Checks for MetaRouter.select_pattern()."""

    @pytest.mark.asyncio
    async def test_skips_llm_when_single_valid_pattern(self):
        """With one valid pattern the router answers without prompting."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        context = _make_context("should not be called")

        # summarisation only has ["react"]
        pattern = await router.select_pattern(
            "Summarise this", "summarisation", context,
        )

        assert pattern == "react"
        # "react" is also the first valid pattern, which is what the
        # fallback tests in this class expect, so the result alone cannot
        # show the LLM was bypassed.  Inspect the prompt mock directly;
        # _make_context hands back the same client for any service name.
        context("prompt-request").prompt.assert_not_called()

    @pytest.mark.asyncio
    async def test_uses_llm_when_multiple_valid_patterns(self):
        """The LLM's reply selects among several valid patterns."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        context = _make_context("plan-then-execute")

        # research has ["react", "plan-then-execute"]
        pattern = await router.select_pattern(
            "Research this", "research", context,
        )

        assert pattern == "plan-then-execute"

    @pytest.mark.asyncio
    async def test_respects_valid_patterns_constraint(self):
        """A reply outside the task type's valid patterns is rejected."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        # LLM returns supervisor, but research doesn't allow it
        context = _make_context("supervisor")

        pattern = await router.select_pattern(
            "Research this", "research", context,
        )

        # Should fall back to first valid pattern
        assert pattern == "react"

    @pytest.mark.asyncio
    async def test_falls_back_on_llm_error(self):
        """An exception from the prompt client yields the first valid pattern."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        client = AsyncMock()
        client.prompt = AsyncMock(side_effect=RuntimeError("LLM down"))
        context = lambda name: client

        # general has ["react", "plan-then-execute", "supervisor"]
        pattern = await router.select_pattern(
            "test", "general", context,
        )

        # Falls back to first valid pattern
        assert pattern == "react"

    @pytest.mark.asyncio
    async def test_falls_back_to_default_for_unknown_task_type(self):
        """An unconfigured task type still resolves to a pattern."""
        config = _make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        )
        router = MetaRouter(config=config)
        context = _make_context("react")

        # Unknown task type — valid_patterns falls back to all patterns
        pattern = await router.select_pattern(
            "test", "unknown-type", context,
        )

        assert pattern == "react"
class TestRoute:
    """End-to-end check of MetaRouter.route()."""

    @pytest.mark.asyncio
    async def test_full_routing_pipeline(self):
        """route() returns the pattern, task type and framing together."""
        router = MetaRouter(config=_make_config(
            patterns=SAMPLE_PATTERNS,
            task_types=SAMPLE_TASK_TYPES,
        ))

        # Mock context where prompt returns different values per call
        prompt_calls = []

        async def mock_prompt(**kwargs):
            prompt_calls.append(kwargs)
            if len(prompt_calls) == 1:
                return "research"  # task type
            return "plan-then-execute"  # pattern

        client = AsyncMock()
        client.prompt = mock_prompt
        context = lambda name: client

        pattern, task_type, framing = await router.route(
            "Research the relationships", context,
        )

        assert task_type == "research"
        assert pattern == "plan-then-execute"
        assert framing == "Focus on gathering information."

View file

@ -0,0 +1,132 @@
"""
Tests for the on_action callback in react() — verifies that it fires
after action selection but before tool execution.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.types import Action, Final, Tool, Argument
class TestOnActionCallback:
    """Behaviour of the optional on_action hook of AgentManager.react()."""

    @pytest.mark.asyncio
    async def test_on_action_called_for_tool_use(self):
        """on_action fires when react() selects a tool (not Final)."""
        events = []

        async def fake_on_action(act):
            events.append(("on_action", act.name))

        # Tool that records when it's invoked
        async def tool_invoke(**kwargs):
            events.append(("tool_invoke",))
            return "tool result"

        tool_impl = MagicMock()
        tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke)

        search_tool = Tool(
            name="search",
            description="Search",
            implementation=tool_impl,
            arguments=[
                Argument(name="query", type="string", description="q"),
            ],
            config={},
        )
        agent = AgentManager(tools={"search": search_tool})

        # Mock reason() to return an Action
        agent.reason = AsyncMock(return_value=Action(
            thought="thinking",
            name="search",
            arguments={"query": "test"},
            observation="",
        ))

        await agent.react(
            question="test",
            history=[],
            think=AsyncMock(),
            observe=AsyncMock(),
            context=MagicMock(),
            on_action=fake_on_action,
        )

        # on_action should fire before tool_invoke
        assert len(events) == 2
        assert events[0] == ("on_action", "search")
        assert events[1] == ("tool_invoke",)

    @pytest.mark.asyncio
    async def test_on_action_not_called_for_final(self):
        """on_action does not fire when react() returns Final."""
        seen_actions = []

        async def fake_on_action(act):
            seen_actions.append(act)

        agent = AgentManager(tools={})
        agent.reason = AsyncMock(
            return_value=Final(thought="done", final="answer")
        )

        result = await agent.react(
            question="test",
            history=[],
            think=AsyncMock(),
            observe=AsyncMock(),
            context=MagicMock(),
            on_action=fake_on_action,
        )

        assert isinstance(result, Final)
        assert len(seen_actions) == 0

    @pytest.mark.asyncio
    async def test_on_action_none_accepted(self):
        """react() works fine when on_action is None (default)."""
        async def tool_invoke(**kwargs):
            return "result"

        tool_impl = MagicMock()
        tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke)

        search_tool = Tool(
            name="search",
            description="Search",
            implementation=tool_impl,
            arguments=[],
            config={},
        )
        agent = AgentManager(tools={"search": search_tool})
        agent.reason = AsyncMock(
            return_value=Action(
                thought="t", name="search", arguments={}, observation="",
            )
        )

        result = await agent.react(
            question="test",
            history=[],
            think=AsyncMock(),
            observe=AsyncMock(),
            context=MagicMock(),
            # on_action not passed — defaults to None
        )

        assert isinstance(result, Action)
        assert result.observation == "result"

View file

@ -0,0 +1,655 @@
"""
Integration tests for agent-orchestrator provenance chains.
Tests all three patterns by calling iterate() with mocked dependencies
and verifying the explain events emitted via respond().
Provenance chains:
React: session → iteration → (observation or final)
Plan: session → plan → step-result(s) → synthesis
Supervisor: session → decomposition → finding(s) → synthesis
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from dataclasses import dataclass, field
from trustgraph.schema import (
AgentRequest, AgentResponse, AgentStep, PlanStep,
)
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
GRAPH_RETRIEVAL,
)
# Agent provenance type constants
from trustgraph.provenance.namespaces import (
TG_AGENT_QUESTION,
TG_ANALYSIS,
TG_TOOL_USE,
TG_OBSERVATION_TYPE,
TG_CONCLUSION,
TG_DECOMPOSITION,
TG_FINDING,
TG_PLAN_TYPE,
TG_STEP_RESULT,
TG_SYNTHESIS as TG_AGENT_SYNTHESIS,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def find_triple(triples, predicate, subject=None):
    """Return the first triple with the given predicate IRI, or None.

    When ``subject`` is given, the triple's subject IRI must match as well.
    """
    matches = (
        triple for triple in triples
        if triple.p.iri == predicate
        and (subject is None or triple.s.iri == subject)
    )
    return next(matches, None)
def has_type(triples, subject, rdf_type):
    """Report whether ``triples`` gives ``subject`` the type ``rdf_type``."""
    for triple in triples:
        if triple.s.iri != subject:
            continue
        if triple.p.iri == RDF_TYPE and triple.o.iri == rdf_type:
            return True
    return False
def derived_from(triples, subject):
    """Return the IRI ``subject`` was derived from, or None if unrecorded."""
    link = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    if not link:
        return None
    return link.o.iri
def collect_explain_events(respond_mock):
    """Extract explain events from a respond mock's call history."""
    # Only the first positional argument of each recorded call is examined.
    sent = (call.args[0] for call in respond_mock.call_args_list)
    return [
        {
            "explain_id": resp.explain_id,
            "explain_graph": resp.explain_graph,
            "triples": resp.explain_triples,
        }
        for resp in sent
        if isinstance(resp, AgentResponse) and resp.chunk_type == "explain"
    ]
# ---------------------------------------------------------------------------
# Mock processor
# ---------------------------------------------------------------------------
def make_mock_processor(tools=None):
    """Build a mock processor with the minimal interface patterns need."""
    # Agent with tools
    agent = MagicMock()
    agent.tools = tools or {}
    agent.additional_context = ""

    processor = MagicMock()
    processor.max_iterations = 10
    processor.save_answer_content = AsyncMock()
    # provenance_session_uri must return a real URI
    processor.provenance_session_uri.side_effect = (
        lambda session_id: f"urn:trustgraph:agent:session:{session_id}"
    )
    processor.agent = agent
    # Aggregator for supervisor
    processor.aggregator = MagicMock()
    return processor
def make_mock_flow():
    """Build a mock flow that returns async mock producers.

    Calling the flow with a name returns an AsyncMock that is created on
    first use and reused afterwards; the cache is exposed as
    ``flow._producers``.
    """
    cache = {}

    def flow_factory(name):
        try:
            return cache[name]
        except KeyError:
            created = cache[name] = AsyncMock()
            return created

    flow = MagicMock(side_effect=flow_factory)
    flow._producers = cache
    return flow
def make_base_request(**kwargs):
    """Build a minimal AgentRequest."""
    # Defaults for every field set here; caller kwargs override them.
    fields = {
        "question": "What is quantum computing?",
        "state": "",
        "group": [],
        "history": [],
        "user": "testuser",
        "collection": "default",
        "streaming": False,
        "session_id": "test-session-123",
        "conversation_id": "",
        "pattern": "react",
        "task_type": "",
        "framing": "",
        "correlation_id": "",
        "parent_session_id": "",
        "subagent_goal": "",
        "expected_siblings": 0,
        **kwargs,
    }
    return AgentRequest(**fields)
# ---------------------------------------------------------------------------
# React pattern tests
# ---------------------------------------------------------------------------
class TestReactPatternProvenance:
    """
    React pattern chain: session -> iteration -> final
    (single iteration ending in Final answer)
    """

    @pytest.mark.asyncio
    async def test_single_iteration_final_answer(self):
        """
        A single react iteration that produces a Final answer should emit:
        session, iteration, final, in that order.
        """
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.react.types import Action, Final

        processor = make_mock_processor()
        pattern = ReactPattern(processor)

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()
        request = make_base_request()

        # Mock AgentManager.react to call on_action then return Final
        with patch(
            'trustgraph.agent.orchestrator.react_pattern.AgentManager'
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            final = Final(
                thought="I know the answer",
                final="Quantum computing uses qubits.",
            )

            async def mock_react(question, history, think, observe, answer,
                                 context, streaming, on_action):
                # Simulate the on_action callback before returning Final
                if on_action:
                    await on_action(Action(
                        thought="I know the answer",
                        name="final",
                        arguments={},
                        observation="",
                    ))
                return final

            mock_am.react.side_effect = mock_react

            await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have 3 events: session, iteration, final
        assert len(events) == 3, (
            f"Expected 3 explain events (session, iteration, final), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )

        # Check types
        assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
        assert has_type(events[1]["triples"], events[1]["explain_id"], TG_ANALYSIS)
        assert has_type(events[2]["triples"], events[2]["explain_id"], TG_CONCLUSION)

        # Check derivation chain
        all_triples = []
        for e in events:
            all_triples.extend(e["triples"])

        uris = [e["explain_id"] for e in events]

        # iteration derives from session
        assert derived_from(all_triples, uris[1]) == uris[0]

        # final derives from session (first iteration, no prior observation)
        assert derived_from(all_triples, uris[2]) == uris[0]

    @pytest.mark.asyncio
    async def test_iteration_with_tool_call(self):
        """
        A react iteration that calls a tool (not Final) should emit:
        session, iteration, observation, then call next() for continuation.
        """
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.react.types import Action

        # Create a mock tool
        mock_tool = MagicMock()
        mock_tool.name = "knowledge-query"
        mock_tool.description = "Query the knowledge base"
        mock_tool.arguments = []
        mock_tool.groups = []
        mock_tool.states = {}
        mock_tool_impl = AsyncMock(return_value="The answer is 42")
        mock_tool.implementation = MagicMock(return_value=mock_tool_impl)

        processor = make_mock_processor(
            tools={"knowledge-query": mock_tool}
        )
        pattern = ReactPattern(processor)

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = make_mock_flow()
        request = make_base_request()

        action = Action(
            thought="I need to look this up",
            name="knowledge-query",
            arguments={"question": "What is quantum computing?"},
            observation="Quantum computing uses qubits.",
        )

        with patch(
            'trustgraph.agent.orchestrator.react_pattern.AgentManager'
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            async def mock_react(question, history, think, observe, answer,
                                 context, streaming, on_action):
                if on_action:
                    await on_action(action)
                return action

            mock_am.react.side_effect = mock_react

            await pattern.iterate(request, respond, next_fn, flow)

        events = collect_explain_events(respond)

        # Should have 3 events: session, iteration, observation
        assert len(events) == 3, (
            f"Expected 3 explain events (session, iteration, observation), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )

        assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
        assert has_type(events[1]["triples"], events[1]["explain_id"], TG_ANALYSIS)
        assert has_type(events[2]["triples"], events[2]["explain_id"], TG_OBSERVATION_TYPE)

        # next() should have been called to continue the loop
        assert next_fn.called

    @pytest.mark.asyncio
    async def test_all_triples_in_retrieval_graph(self):
        """All explain triples should be in urn:graph:retrieval."""
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.react.types import Action, Final

        processor = make_mock_processor()
        pattern = ReactPattern(processor)

        respond = AsyncMock()
        flow = make_mock_flow()

        with patch(
            'trustgraph.agent.orchestrator.react_pattern.AgentManager'
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            async def mock_react(question, history, think, observe, answer,
                                 context, streaming, on_action):
                if on_action:
                    await on_action(Action(
                        thought="done", name="final",
                        arguments={}, observation="",
                    ))
                return Final(thought="done", final="answer")

            mock_am.react.side_effect = mock_react

            await pattern.iterate(
                make_base_request(), respond, AsyncMock(), flow,
            )

        events = collect_explain_events(respond)

        # Without this guard the loop below would pass vacuously if no
        # explain events were emitted at all.
        assert events, "Expected at least one explain event"

        for event in events:
            for t in event["triples"]:
                assert t.g == GRAPH_RETRIEVAL
# ---------------------------------------------------------------------------
# Plan-then-execute pattern tests
# ---------------------------------------------------------------------------
class TestPlanPatternProvenance:
    """
    Explain events emitted by the plan-then-execute pattern.

    Planning iteration: a session event followed by a plan event.
    Execution iterations: step-result event(s), then a synthesis event.
    """

    @staticmethod
    def _serve_prompt(flow, prompt_result):
        """Route flow("prompt-request") to a client whose prompt() resolves
        to prompt_result; any other flow name gets a fresh AsyncMock."""
        prompt_client = AsyncMock()
        prompt_client.prompt.return_value = prompt_result
        flow.side_effect = lambda name: (
            prompt_client if name == "prompt-request" else AsyncMock()
        )

    @pytest.mark.asyncio
    async def test_planning_iteration_emits_session_and_plan(self):
        """
        With no history, iterate() should emit a session event and a plan
        event (the plan derived from the session) and then call next().
        """
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern

        plan_pattern = PlanThenExecutePattern(make_mock_processor())
        respond, next_fn = AsyncMock(), AsyncMock()
        flow = make_mock_flow()

        # The prompt client answers with a two-step plan.
        self._serve_prompt(flow, [
            {"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
            {"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
        ])

        await plan_pattern.iterate(
            make_base_request(pattern="plan"), respond, next_fn, flow,
        )

        events = collect_explain_events(respond)
        assert len(events) == 2, (
            f"Expected 2 explain events (session, plan), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )

        session_event, plan_event = events
        assert has_type(
            session_event["triples"], session_event["explain_id"],
            TG_AGENT_QUESTION,
        )
        assert has_type(
            plan_event["triples"], plan_event["explain_id"], TG_PLAN_TYPE,
        )

        # The plan entity must point back at the session entity.
        combined = [t for ev in events for t in ev["triples"]]
        assert (
            derived_from(combined, plan_event["explain_id"])
            == session_event["explain_id"]
        )

        # The loop continues via next().
        assert next_fn.called

    @pytest.mark.asyncio
    async def test_execution_iteration_emits_step_result(self):
        """
        With a pending plan already in history, iterate() should emit
        exactly one step-result event.
        """
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern

        # configure_mock is used so that "name" becomes a plain attribute.
        tool = MagicMock()
        tool.configure_mock(
            name="knowledge-query",
            description="Query KB",
            arguments=[],
            groups=[],
            states={},
            implementation=MagicMock(
                return_value=AsyncMock(return_value="Found the answer"),
            ),
        )

        plan_pattern = PlanThenExecutePattern(
            make_mock_processor(tools={"knowledge-query": tool})
        )
        respond, next_fn = AsyncMock(), AsyncMock()
        flow = make_mock_flow()

        # The prompt client picks the tool and arguments for the step.
        self._serve_prompt(flow, {
            "tool": "knowledge-query",
            "arguments": {"question": "quantum computing"},
        })

        # History holds the plan, i.e. this is a later iteration.
        pending = PlanStep(
            goal="Find info", tool_hint="knowledge-query",
            depends_on=[], status="pending", result="",
        )
        history = [
            AgentStep(
                thought="Created plan",
                action="plan",
                arguments={},
                observation="[]",
                step_type="plan",
                plan=[pending],
            ),
        ]

        await plan_pattern.iterate(
            make_base_request(pattern="plan", history=history),
            respond, next_fn, flow,
        )

        # Only step-result events are counted here.
        step_events = [
            ev for ev in collect_explain_events(respond)
            if has_type(ev["triples"], ev["explain_id"], TG_STEP_RESULT)
        ]
        assert len(step_events) == 1, (
            f"Expected 1 step-result event, got {len(step_events)}"
        )

    @pytest.mark.asyncio
    async def test_synthesis_after_all_steps_complete(self):
        """
        Once every plan step in history is completed, iterate() should
        emit exactly one synthesis event.
        """
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern

        plan_pattern = PlanThenExecutePattern(make_mock_processor())
        respond, next_fn = AsyncMock(), AsyncMock()
        flow = make_mock_flow()

        # The prompt client supplies the synthesised text.
        self._serve_prompt(flow, "The synthesised answer.")

        # History whose single plan step is already completed.
        finished = PlanStep(
            goal="Find info", tool_hint="knowledge-query",
            depends_on=[], status="completed", result="Found it",
        )
        history = [
            AgentStep(
                thought="Executing step",
                action="knowledge-query",
                arguments={},
                observation="Result",
                step_type="execute",
                plan=[finished],
            ),
        ]

        await plan_pattern.iterate(
            make_base_request(pattern="plan", history=history),
            respond, next_fn, flow,
        )

        synth_events = [
            ev for ev in collect_explain_events(respond)
            if has_type(ev["triples"], ev["explain_id"], TG_AGENT_SYNTHESIS)
        ]
        assert len(synth_events) == 1, (
            f"Expected 1 synthesis event, got {len(synth_events)}"
        )
# ---------------------------------------------------------------------------
# Supervisor pattern tests
# ---------------------------------------------------------------------------
class TestSupervisorPatternProvenance:
    """
    Explain events emitted by the supervisor pattern.

    Decompose: a session event followed by a decomposition event.
    (Fan-out to subagents happens externally.)
    Synthesise: a synthesis event once subagent results are in history.
    """

    @staticmethod
    def _serve_prompt(flow, prompt_result):
        """Route flow("prompt-request") to a client whose prompt() resolves
        to prompt_result; any other flow name gets a fresh AsyncMock."""
        prompt_client = AsyncMock()
        prompt_client.prompt.return_value = prompt_result
        flow.side_effect = lambda name: (
            prompt_client if name == "prompt-request" else AsyncMock()
        )

    @pytest.mark.asyncio
    async def test_decompose_emits_session_and_decomposition(self):
        """
        With no history, iterate() should emit a session event and a
        decomposition event derived from it.
        """
        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern

        supervisor = SupervisorPattern(make_mock_processor())
        respond, next_fn = AsyncMock(), AsyncMock()
        flow = make_mock_flow()

        # The prompt client returns the list of sub-goals.
        self._serve_prompt(flow, [
            "What is quantum computing?",
            "What are qubits?",
        ])

        await supervisor.iterate(
            make_base_request(pattern="supervisor"), respond, next_fn, flow,
        )

        events = collect_explain_events(respond)
        assert len(events) == 2, (
            f"Expected 2 explain events (session, decomposition), "
            f"got {len(events)}: {[e['explain_id'] for e in events]}"
        )

        session_event, decomposition_event = events
        assert has_type(
            session_event["triples"], session_event["explain_id"],
            TG_AGENT_QUESTION,
        )
        assert has_type(
            decomposition_event["triples"], decomposition_event["explain_id"],
            TG_DECOMPOSITION,
        )

        # The decomposition entity must point back at the session entity.
        combined = [t for ev in events for t in ev["triples"]]
        assert (
            derived_from(combined, decomposition_event["explain_id"])
            == session_event["explain_id"]
        )

    @pytest.mark.asyncio
    async def test_synthesis_emits_after_subagent_results(self):
        """
        With subagent results in history, iterate() should emit exactly
        one synthesis event.
        """
        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern

        supervisor = SupervisorPattern(make_mock_processor())
        respond, next_fn = AsyncMock(), AsyncMock()
        flow = make_mock_flow()

        # The prompt client supplies the combined answer text.
        self._serve_prompt(flow, "The combined answer.")

        # A "synthesise" step carrying the per-goal subagent answers.
        history = [
            AgentStep(
                thought="",
                action="synthesise",
                arguments={},
                observation="",
                step_type="synthesise",
                subagent_results={
                    "What is quantum computing?": "It uses qubits",
                    "What are qubits?": "Quantum bits",
                },
            ),
        ]

        await supervisor.iterate(
            make_base_request(pattern="supervisor", history=history),
            respond, next_fn, flow,
        )

        synth_events = [
            ev for ev in collect_explain_events(respond)
            if has_type(ev["triples"], ev["explain_id"], TG_AGENT_SYNTHESIS)
        ]
        assert len(synth_events) == 1

    @pytest.mark.asyncio
    async def test_decompose_fans_out_subagents(self):
        """Decomposing into three goals should call next() three times."""
        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern

        supervisor = SupervisorPattern(make_mock_processor())
        respond, next_fn = AsyncMock(), AsyncMock()
        flow = make_mock_flow()
        self._serve_prompt(flow, ["Goal A", "Goal B", "Goal C"])

        await supervisor.iterate(
            make_base_request(pattern="supervisor"), respond, next_fn, flow,
        )

        # One next() call per goal returned by the prompt.
        assert next_fn.call_count == 3

View file

@ -0,0 +1,74 @@
"""
Tests that _parse_chunk propagates message_id from wire format
to AgentThought, AgentObservation, and AgentAnswer.
"""
import pytest
from trustgraph.api.socket_client import SocketClient
from trustgraph.api.types import AgentThought, AgentObservation, AgentAnswer
@pytest.fixture
def client():
    """Provide a SocketClient allocated without running __init__.

    The tests only call _parse_chunk, so no connection is set up.
    """
    return object.__new__(SocketClient)
class TestParseChunkMessageId:
    """_parse_chunk should copy message_id from the wire dict onto the
    parsed chunk, and use an empty string when the key is absent."""

    @staticmethod
    def _parse(client, **wire):
        """Feed the keyword arguments to _parse_chunk as a wire-format dict."""
        return client._parse_chunk(wire)

    def test_thought_message_id(self, client):
        parsed = self._parse(
            client,
            chunk_type="thought",
            content="thinking...",
            end_of_message=False,
            message_id="urn:trustgraph:agent:sess/i1/thought",
        )
        assert isinstance(parsed, AgentThought)
        assert parsed.message_id == "urn:trustgraph:agent:sess/i1/thought"

    def test_observation_message_id(self, client):
        parsed = self._parse(
            client,
            chunk_type="observation",
            content="result",
            end_of_message=True,
            message_id="urn:trustgraph:agent:sess/i1/observation",
        )
        assert isinstance(parsed, AgentObservation)
        assert parsed.message_id == "urn:trustgraph:agent:sess/i1/observation"

    def test_answer_message_id(self, client):
        parsed = self._parse(
            client,
            chunk_type="answer",
            content="the answer",
            end_of_message=False,
            end_of_dialog=False,
            message_id="urn:trustgraph:agent:sess/final",
        )
        assert isinstance(parsed, AgentAnswer)
        assert parsed.message_id == "urn:trustgraph:agent:sess/final"

    def test_thought_missing_message_id(self, client):
        # No message_id key on the wire.
        parsed = self._parse(
            client,
            chunk_type="thought",
            content="thinking...",
            end_of_message=False,
        )
        assert isinstance(parsed, AgentThought)
        assert parsed.message_id == ""

    def test_answer_missing_message_id(self, client):
        # No message_id key on the wire.
        parsed = self._parse(
            client,
            chunk_type="answer",
            content="answer",
            end_of_message=True,
            end_of_dialog=True,
        )
        assert isinstance(parsed, AgentAnswer)
        assert parsed.message_id == ""

View file

@ -0,0 +1,144 @@
"""
Unit tests for PatternBase subagent helpers is_subagent() and
emit_subagent_completion().
"""
import pytest
from unittest.mock import MagicMock, AsyncMock
from dataclasses import dataclass
from trustgraph.schema import AgentRequest
from trustgraph.agent.orchestrator.pattern_base import PatternBase
@dataclass
class MockProcessor:
    """Field-less stand-in passed to PatternBase as its processor."""
def _make_request(**kwargs):
    """Build an AgentRequest from baseline test values.

    Keyword arguments override or extend the baseline question, user and
    collection fields.
    """
    baseline = {
        "question": "Test question",
        "user": "testuser",
        "collection": "default",
    }
    return AgentRequest(**{**baseline, **kwargs})
def _make_pattern():
    """Return a PatternBase wired to a fresh MockProcessor."""
    stub_processor = MockProcessor()
    return PatternBase(stub_processor)
class TestIsSubagent:
    """is_subagent() is True only for a non-empty correlation_id."""

    def test_returns_true_when_correlation_id_set(self):
        subagent_request = _make_request(correlation_id="corr-123")
        assert _make_pattern().is_subagent(subagent_request) is True

    def test_returns_false_when_correlation_id_empty(self):
        blank_request = _make_request(correlation_id="")
        assert _make_pattern().is_subagent(blank_request) is False

    def test_returns_false_when_correlation_id_missing(self):
        # correlation_id is left at whatever AgentRequest defaults it to.
        plain_request = _make_request()
        assert _make_pattern().is_subagent(plain_request) is False
class TestEmitSubagentCompletion:
    """emit_subagent_completion() should pass next() one AgentRequest that
    describes the finished subagent."""

    @staticmethod
    async def _emit(answer, **request_fields):
        """Run emit_subagent_completion for a request built from
        request_fields and return the mock that stood in for next()."""
        next_fn = AsyncMock()
        await _make_pattern().emit_subagent_completion(
            _make_request(**request_fields), next_fn, answer,
        )
        return next_fn

    @pytest.mark.asyncio
    async def test_calls_next_with_completion_request(self):
        next_fn = await self._emit(
            "The answer is Y",
            correlation_id="corr-123",
            parent_session_id="parent-sess",
            subagent_goal="What is X?",
            expected_siblings=4,
        )

        next_fn.assert_called_once()
        # First positional argument of the single next() call.
        sent = next_fn.call_args[0][0]
        assert isinstance(sent, AgentRequest)

    @pytest.mark.asyncio
    async def test_completion_has_correct_step_type(self):
        next_fn = await self._emit(
            "answer text",
            correlation_id="corr-123",
            subagent_goal="What is X?",
        )

        sent = next_fn.call_args[0][0]
        assert len(sent.history) == 1
        only_step = sent.history[0]
        assert only_step.step_type == "subagent-completion"

    @pytest.mark.asyncio
    async def test_completion_carries_answer_in_observation(self):
        next_fn = await self._emit(
            "The answer is Y",
            correlation_id="corr-123",
            subagent_goal="What is X?",
        )

        sent = next_fn.call_args[0][0]
        only_step = sent.history[0]
        assert only_step.observation == "The answer is Y"

    @pytest.mark.asyncio
    async def test_completion_preserves_correlation_fields(self):
        next_fn = await self._emit(
            "answer",
            correlation_id="corr-123",
            parent_session_id="parent-sess",
            subagent_goal="What is X?",
            expected_siblings=4,
        )

        # The correlation fields are copied across unchanged.
        sent = next_fn.call_args[0][0]
        assert sent.correlation_id == "corr-123"
        assert sent.parent_session_id == "parent-sess"
        assert sent.subagent_goal == "What is X?"
        assert sent.expected_siblings == 4

    @pytest.mark.asyncio
    async def test_completion_has_empty_pattern(self):
        next_fn = await self._emit(
            "answer",
            correlation_id="corr-123",
            subagent_goal="goal",
        )

        sent = next_fn.call_args[0][0]
        assert sent.pattern == ""

View file

@ -0,0 +1,226 @@
"""
Unit tests for orchestrator provenance triple builders.
"""
import pytest
from trustgraph.provenance import (
agent_decomposition_triples,
agent_finding_triples,
agent_plan_triples,
agent_step_result_triples,
agent_synthesis_triples,
)
from trustgraph.provenance.namespaces import (
RDF_TYPE, RDFS_LABEL,
PROV_ENTITY, PROV_WAS_DERIVED_FROM,
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
TG_SYNTHESIS, TG_ANSWER_TYPE, TG_DOCUMENT,
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
)
def _triple_set(triples):
"""Convert triples to a set of (s_iri, p_iri, o_value) for easy assertion."""
result = set()
for t in triples:
s = t.s.iri
p = t.p.iri
o = t.o.iri if t.o.iri else t.o.value
result.add((s, p, o))
return result
def _has_type(triples, uri, rdf_type):
    """Return True when triples contain (uri, RDF_TYPE, rdf_type)."""
    wanted = (uri, RDF_TYPE, rdf_type)
    return wanted in _triple_set(triples)
def _get_values(triples, uri, predicate):
    """Collect the object of every triple matching uri and predicate.

    Results come from the de-duplicated triple set, so order is not
    meaningful.
    """
    matches = []
    for subject, pred, obj in _triple_set(triples):
        if subject == uri and pred == predicate:
            matches.append(obj)
    return matches
class TestDecompositionTriples:
    """agent_decomposition_triples: types, session link, goals, label."""

    @staticmethod
    def _build(goals):
        """Build decomposition triples for the fixed test URIs."""
        return agent_decomposition_triples(
            "urn:decompose", "urn:session", goals,
        )

    def test_has_correct_types(self):
        built = self._build(["goal-a", "goal-b"])
        for expected_type in (PROV_ENTITY, TG_DECOMPOSITION):
            assert _has_type(built, "urn:decompose", expected_type)

    def test_not_answer_type(self):
        built = self._build(["goal-a"])
        assert not _has_type(built, "urn:decompose", TG_ANSWER_TYPE)

    def test_links_to_session(self):
        link = ("urn:decompose", PROV_WAS_DERIVED_FROM, "urn:session")
        assert link in _triple_set(self._build(["goal-a"]))

    def test_includes_goals(self):
        goals = ["What is X?", "What is Y?", "What is Z?"]
        recorded = _get_values(
            self._build(goals), "urn:decompose", TG_SUBAGENT_GOAL,
        )
        assert set(recorded) == set(goals)

    def test_label_includes_count(self):
        # Three goals, so some label should mention "3".
        labels = _get_values(
            self._build(["a", "b", "c"]), "urn:decompose", RDFS_LABEL,
        )
        assert any("3" in label for label in labels)
class TestFindingTriples:
    """agent_finding_triples: types, decomposition link, goal, document."""

    @staticmethod
    def _build(goal, **extra):
        """Build finding triples for the fixed test URIs."""
        return agent_finding_triples(
            "urn:finding", "urn:decompose", goal, **extra
        )

    def test_has_correct_types(self):
        built = self._build("What is X?")
        for expected_type in (PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE):
            assert _has_type(built, "urn:finding", expected_type)

    def test_links_to_decomposition(self):
        link = ("urn:finding", PROV_WAS_DERIVED_FROM, "urn:decompose")
        assert link in _triple_set(self._build("What is X?"))

    def test_includes_goal(self):
        recorded = _get_values(
            self._build("What is X?"), "urn:finding", TG_SUBAGENT_GOAL,
        )
        assert "What is X?" in recorded

    def test_includes_document_when_provided(self):
        built = self._build("goal", document_id="urn:doc/1")
        documents = _get_values(built, "urn:finding", TG_DOCUMENT)
        assert "urn:doc/1" in documents

    def test_no_document_when_none(self):
        # document_id is not passed at all here.
        documents = _get_values(
            self._build("goal"), "urn:finding", TG_DOCUMENT,
        )
        assert documents == []
class TestPlanTriples:
    """agent_plan_triples: types, session link, steps, label."""

    @staticmethod
    def _build(steps):
        """Build plan triples for the fixed test URIs."""
        return agent_plan_triples("urn:plan", "urn:session", steps)

    def test_has_correct_types(self):
        built = self._build(["step-a"])
        for expected_type in (PROV_ENTITY, TG_PLAN_TYPE):
            assert _has_type(built, "urn:plan", expected_type)

    def test_not_answer_type(self):
        built = self._build(["step-a"])
        assert not _has_type(built, "urn:plan", TG_ANSWER_TYPE)

    def test_links_to_session(self):
        link = ("urn:plan", PROV_WAS_DERIVED_FROM, "urn:session")
        assert link in _triple_set(self._build(["step-a"]))

    def test_includes_steps(self):
        steps = ["Define X", "Research Y", "Analyse Z"]
        recorded = _get_values(self._build(steps), "urn:plan", TG_PLAN_STEP)
        assert set(recorded) == set(steps)

    def test_label_includes_count(self):
        # Two steps, so some label should mention "2".
        labels = _get_values(self._build(["a", "b"]), "urn:plan", RDFS_LABEL)
        assert any("2" in label for label in labels)
class TestStepResultTriples:
    """agent_step_result_triples: types, plan link, goal, document."""

    @staticmethod
    def _build(goal, **extra):
        """Build step-result triples for the fixed test URIs."""
        return agent_step_result_triples(
            "urn:step", "urn:plan", goal, **extra
        )

    def test_has_correct_types(self):
        built = self._build("Define X")
        for expected_type in (PROV_ENTITY, TG_STEP_RESULT, TG_ANSWER_TYPE):
            assert _has_type(built, "urn:step", expected_type)

    def test_links_to_plan(self):
        link = ("urn:step", PROV_WAS_DERIVED_FROM, "urn:plan")
        assert link in _triple_set(self._build("Define X"))

    def test_includes_goal(self):
        recorded = _get_values(
            self._build("Define X"), "urn:step", TG_PLAN_STEP,
        )
        assert "Define X" in recorded

    def test_includes_document_when_provided(self):
        built = self._build("goal", document_id="urn:doc/step")
        documents = _get_values(built, "urn:step", TG_DOCUMENT)
        assert "urn:doc/step" in documents
class TestSynthesisTriples:
    """agent_synthesis_triples: types, derivation link, document, label."""

    def test_has_correct_types(self):
        built = agent_synthesis_triples("urn:synthesis", "urn:previous")
        for expected_type in (PROV_ENTITY, TG_SYNTHESIS, TG_ANSWER_TYPE):
            assert _has_type(built, "urn:synthesis", expected_type)

    def test_links_to_previous(self):
        built = agent_synthesis_triples("urn:synthesis", "urn:last-finding")
        link = ("urn:synthesis", PROV_WAS_DERIVED_FROM, "urn:last-finding")
        assert link in _triple_set(built)

    def test_includes_document_when_provided(self):
        built = agent_synthesis_triples(
            "urn:synthesis", "urn:previous",
            document_id="urn:doc/synthesis",
        )
        documents = _get_values(built, "urn:synthesis", TG_DOCUMENT)
        assert "urn:doc/synthesis" in documents

    def test_label_is_synthesis(self):
        built = agent_synthesis_triples("urn:synthesis", "urn:previous")
        labels = _get_values(built, "urn:synthesis", RDFS_LABEL)
        assert "Synthesis" in labels

View file

@ -0,0 +1,323 @@
"""
Tests for AsyncProcessor config notify pattern:
- register_config_handler with types filtering
- on_config_notify version comparison and type matching
- fetch_config with short-lived client
- fetch_and_apply_config retry logic
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, Mock
from trustgraph.schema import Term, IRI, LITERAL
# Patch heavy dependencies before importing AsyncProcessor
@pytest.fixture
def processor():
    """Build an AsyncProcessor whose pub/sub factory, Consumer and metrics
    classes are patched to return MagicMocks during construction."""
    with patch('trustgraph.base.async_processor.get_pubsub',
               return_value=MagicMock()), \
         patch('trustgraph.base.async_processor.Consumer',
               return_value=MagicMock()), \
         patch('trustgraph.base.async_processor.ProcessorMetrics',
               return_value=MagicMock()), \
         patch('trustgraph.base.async_processor.ConsumerMetrics',
               return_value=MagicMock()):

        # Imported only once the patches are active.
        from trustgraph.base.async_processor import AsyncProcessor

        return AsyncProcessor(
            id="test-processor",
            taskgroup=AsyncMock(),
        )
class TestRegisterConfigHandler:
    """register_config_handler() appends {"handler", "types"} entries to
    config_handlers, storing types as a set (or None when omitted)."""

    def test_register_without_types(self, processor):
        callback = AsyncMock()
        processor.register_config_handler(callback)

        assert len(processor.config_handlers) == 1
        entry = processor.config_handlers[0]
        assert entry["handler"] is callback
        assert entry["types"] is None

    def test_register_with_types(self, processor):
        processor.register_config_handler(AsyncMock(), types=["prompt"])

        entry = processor.config_handlers[0]
        assert entry["types"] == {"prompt"}

    def test_register_multiple_types(self, processor):
        processor.register_config_handler(
            AsyncMock(), types=["schema", "collection"]
        )

        entry = processor.config_handlers[0]
        assert entry["types"] == {"schema", "collection"}

    def test_register_multiple_handlers(self, processor):
        registrations = (
            (AsyncMock(), ["prompt"]),
            (AsyncMock(), ["schema"]),
        )
        for callback, types in registrations:
            processor.register_config_handler(callback, types=types)

        assert len(processor.config_handlers) == 2
class TestOnConfigNotify:
    """on_config_notify(): version gating, type filtering and fetch errors."""

    @staticmethod
    def _notification(version, types):
        """Mock message whose value() carries the given version and types."""
        msg = Mock()
        msg.value.return_value = Mock(version=version, types=types)
        return msg

    @staticmethod
    def _stub_fetch(processor, **behaviour):
        """Patch processor.fetch_config with an AsyncMock configured by
        behaviour (return_value=... or side_effect=...)."""
        return patch.object(
            processor, 'fetch_config',
            new_callable=AsyncMock,
            **behaviour
        )

    @pytest.mark.asyncio
    async def test_skip_old_version(self, processor):
        processor.config_version = 5
        callback = AsyncMock()
        processor.register_config_handler(callback, types=["prompt"])

        # Notified version is below the current one.
        await processor.on_config_notify(
            self._notification(3, ["prompt"]), None, None,
        )

        callback.assert_not_called()

    @pytest.mark.asyncio
    async def test_skip_same_version(self, processor):
        processor.config_version = 5
        callback = AsyncMock()
        processor.register_config_handler(callback, types=["prompt"])

        # Notified version equals the current one.
        await processor.on_config_notify(
            self._notification(5, ["prompt"]), None, None,
        )

        callback.assert_not_called()

    @pytest.mark.asyncio
    async def test_skip_irrelevant_types(self, processor):
        processor.config_version = 1
        callback = AsyncMock()
        processor.register_config_handler(callback, types=["prompt"])

        await processor.on_config_notify(
            self._notification(2, ["schema"]), None, None,
        )

        callback.assert_not_called()
        # The version is still recorded when no handler matches.
        assert processor.config_version == 2

    @pytest.mark.asyncio
    async def test_fetch_on_relevant_type(self, processor):
        processor.config_version = 1
        callback = AsyncMock()
        processor.register_config_handler(callback, types=["prompt"])

        fetched = {"prompt": {"key": "value"}}
        with self._stub_fetch(processor, return_value=(fetched, 2)):
            await processor.on_config_notify(
                self._notification(2, ["prompt"]), None, None,
            )

        callback.assert_called_once_with(fetched, 2)
        assert processor.config_version == 2

    @pytest.mark.asyncio
    async def test_handler_without_types_always_called(self, processor):
        processor.config_version = 1
        callback = AsyncMock()
        # Registered with no types filter.
        processor.register_config_handler(callback)

        fetched = {"anything": {}}
        with self._stub_fetch(processor, return_value=(fetched, 2)):
            await processor.on_config_notify(
                self._notification(2, ["whatever"]), None, None,
            )

        callback.assert_called_once_with(fetched, 2)

    @pytest.mark.asyncio
    async def test_mixed_handlers_type_filtering(self, processor):
        processor.config_version = 1
        on_prompt, on_schema, on_any = AsyncMock(), AsyncMock(), AsyncMock()
        processor.register_config_handler(on_prompt, types=["prompt"])
        processor.register_config_handler(on_schema, types=["schema"])
        processor.register_config_handler(on_any)

        with self._stub_fetch(processor, return_value=({"prompt": {}}, 2)):
            await processor.on_config_notify(
                self._notification(2, ["prompt"]), None, None,
            )

        on_prompt.assert_called_once()
        on_schema.assert_not_called()
        on_any.assert_called_once()

    @pytest.mark.asyncio
    async def test_empty_types_invokes_all(self, processor):
        """An empty types list (the startup signal) reaches every handler."""
        processor.config_version = 1
        on_prompt, on_schema = AsyncMock(), AsyncMock()
        processor.register_config_handler(on_prompt, types=["prompt"])
        processor.register_config_handler(on_schema, types=["schema"])

        with self._stub_fetch(processor, return_value=({}, 2)):
            await processor.on_config_notify(
                self._notification(2, []), None, None,
            )

        on_prompt.assert_called_once()
        on_schema.assert_called_once()

    @pytest.mark.asyncio
    async def test_fetch_failure_handled(self, processor):
        processor.config_version = 1
        callback = AsyncMock()
        processor.register_config_handler(callback)

        failure = RuntimeError("Connection failed")
        with self._stub_fetch(processor, side_effect=failure):
            # The fetch error must not escape on_config_notify.
            await processor.on_config_notify(
                self._notification(2, ["prompt"]), None, None,
            )

        callback.assert_not_called()
class TestFetchConfig:
    """fetch_config(): result unpacking, error responses, client stop."""

    @staticmethod
    def _use_client(processor, config_client):
        """Patch _create_config_client so it hands back config_client."""
        return patch.object(
            processor, '_create_config_client', return_value=config_client
        )

    @pytest.mark.asyncio
    async def test_fetch_returns_config_and_version(self, processor):
        config_client = AsyncMock()
        config_client.request.return_value = Mock(
            error=None,
            config={"prompt": {"key": "val"}},
            version=42,
        )

        with self._use_client(processor, config_client):
            config, version = await processor.fetch_config()

        assert config == {"prompt": {"key": "val"}}
        assert version == 42
        config_client.stop.assert_called_once()

    @pytest.mark.asyncio
    async def test_fetch_raises_on_error_response(self, processor):
        config_client = AsyncMock()
        config_client.request.return_value = Mock(
            error=Mock(message="not found"),
            config={},
            version=0,
        )

        with self._use_client(processor, config_client):
            with pytest.raises(RuntimeError, match="Config error"):
                await processor.fetch_config()

        # The client is stopped even though the response was an error.
        config_client.stop.assert_called_once()

    @pytest.mark.asyncio
    async def test_fetch_stops_client_on_exception(self, processor):
        config_client = AsyncMock()
        config_client.request.side_effect = TimeoutError("timeout")

        with self._use_client(processor, config_client):
            with pytest.raises(TimeoutError):
                await processor.fetch_config()

        # The client is stopped even though request() raised.
        config_client.stop.assert_called_once()
class TestFetchAndApplyConfig:
    """fetch_and_apply_config(): handler fan-out and retry after failure."""

    @pytest.mark.asyncio
    async def test_applies_config_to_all_handlers(self, processor):
        on_prompt, on_schema = AsyncMock(), AsyncMock()
        processor.register_config_handler(on_prompt, types=["prompt"])
        processor.register_config_handler(on_schema, types=["schema"])

        fetched = {"prompt": {}, "schema": {}}
        with patch.object(
            processor, 'fetch_config',
            new_callable=AsyncMock,
            return_value=(fetched, 10)
        ):
            await processor.fetch_and_apply_config()

        # Both handlers run here, whatever types they registered for.
        for callback in (on_prompt, on_schema):
            callback.assert_called_once_with(fetched, 10)
        assert processor.config_version == 10

    @pytest.mark.asyncio
    async def test_retries_on_failure(self, processor):
        attempts = []
        fetched = {"prompt": {}}

        async def flaky_fetch():
            # Fail on the first two attempts, succeed on the third.
            attempts.append(None)
            if len(attempts) < 3:
                raise RuntimeError("not ready")
            return fetched, 5

        # asyncio.sleep is stubbed so any retry delay costs no time.
        with patch.object(processor, 'fetch_config', side_effect=flaky_fetch), \
             patch('asyncio.sleep', new_callable=AsyncMock):
            await processor.fetch_and_apply_config()

        assert len(attempts) == 3
        assert processor.config_version == 5

View file

@ -35,7 +35,9 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
mock_async_init.assert_called_once()
# Verify register_config_handler was called with the correct handler
mock_register_config.assert_called_once_with(processor.on_configure_flows)
mock_register_config.assert_called_once_with(
processor.on_configure_flows, types=["active-flow"]
)
# Verify FlowProcessor-specific initialization
assert hasattr(processor, 'flows')

View file

@ -61,23 +61,21 @@ async def test_subscriber_deferred_acknowledgment_success():
max_size=10,
backpressure_strategy="block"
)
# Start subscriber to initialize consumer
await subscriber.start()
subscriber.consumer = mock_consumer
# Create queue for subscription
queue = await subscriber.subscribe("test-queue")
# Create mock message with matching queue name
msg = create_mock_message("test-queue", {"data": "test"})
# Process message
await subscriber._process_message(msg)
# Should acknowledge successful delivery
mock_consumer.acknowledge.assert_called_once_with(msg)
mock_consumer.negative_acknowledge.assert_not_called()
# Message should be in queue
assert not queue.empty()
received_msg = await queue.get()
@ -108,9 +106,7 @@ async def test_subscriber_dropped_message_still_acks():
max_size=1, # Very small queue
backpressure_strategy="drop_new"
)
# Start subscriber to initialize consumer
await subscriber.start()
subscriber.consumer = mock_consumer
# Create queue and fill it
queue = await subscriber.subscribe("test-queue")
@ -151,9 +147,7 @@ async def test_subscriber_orphaned_message_acks():
max_size=10,
backpressure_strategy="block"
)
# Start subscriber to initialize consumer
await subscriber.start()
subscriber.consumer = mock_consumer
# Don't create any queues - message will be orphaned
# This simulates a response arriving after the waiter has unsubscribed
@ -189,9 +183,7 @@ async def test_subscriber_backpressure_strategies():
max_size=2,
backpressure_strategy="drop_oldest"
)
# Start subscriber to initialize consumer
await subscriber.start()
subscriber.consumer = mock_consumer
queue = await subscriber.subscribe("test-queue")

View file

@ -24,8 +24,8 @@ class MockAsyncProcessor:
class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
"""Test Recursive chunker functionality"""
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_processor_initialization_basic(self, mock_producer, mock_consumer):
"""Test basic processor initialization"""
@ -51,8 +51,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
assert len(param_specs) == 2
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
"""Test chunk_document with chunk-size parameter override"""
@ -71,7 +71,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": 2000, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
}.get(param)
@ -85,8 +85,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 2000 # Should use overridden value
assert chunk_overlap == 100 # Should use default value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
"""Test chunk_document with chunk-overlap parameter override"""
@ -105,7 +105,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 200 # Override chunk overlap
}.get(param)
@ -119,8 +119,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 1000 # Should use default value
assert chunk_overlap == 200 # Should use overridden value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
@ -139,7 +139,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": 1500, # Override chunk size
"chunk-overlap": 150 # Override chunk overlap
}.get(param)
@ -153,8 +153,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 1500 # Should use overridden value
assert chunk_overlap == 150 # Should use overridden value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
@ -177,7 +177,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Mock save_child_document to avoid waiting for librarian response
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
# Mock message with TextDocument
mock_message = MagicMock()
@ -196,12 +196,14 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_producer = AsyncMock()
mock_triples_producer = AsyncMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": 1500,
"chunk-overlap": 150,
}.get(param)
mock_flow.side_effect = lambda name: {
"output": mock_producer,
"triples": mock_triples_producer,
}.get(param)
}.get(name)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
@ -219,8 +221,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
sent_chunk = mock_producer.send.call_args[0][0]
assert isinstance(sent_chunk, Chunk)
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
"""Test chunk_document when no parameters are overridden (flow returns None)"""
@ -239,7 +241,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.return_value = None # No overrides
mock_flow.parameters.get.return_value = None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(

View file

@ -24,8 +24,8 @@ class MockAsyncProcessor:
class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
"""Test Token chunker functionality"""
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_processor_initialization_basic(self, mock_producer, mock_consumer):
"""Test basic processor initialization"""
@ -51,8 +51,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
assert len(param_specs) == 2
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
"""Test chunk_document with chunk-size parameter override"""
@ -71,7 +71,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": 400, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
}.get(param)
@ -85,8 +85,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 400 # Should use overridden value
assert chunk_overlap == 15 # Should use default value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
"""Test chunk_document with chunk-overlap parameter override"""
@ -105,7 +105,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 25 # Override chunk overlap
}.get(param)
@ -119,8 +119,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 250 # Should use default value
assert chunk_overlap == 25 # Should use overridden value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
@ -139,7 +139,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": 350, # Override chunk size
"chunk-overlap": 30 # Override chunk overlap
}.get(param)
@ -153,8 +153,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 350 # Should use overridden value
assert chunk_overlap == 30 # Should use overridden value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.chunking.token.chunker.TokenTextSplitter')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
@ -177,7 +177,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Mock save_child_document to avoid librarian producer interactions
processor.save_child_document = AsyncMock(return_value="chunk-id")
processor.librarian.save_child_document = AsyncMock(return_value="chunk-id")
# Mock message with TextDocument
mock_message = MagicMock()
@ -196,12 +196,14 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_producer = AsyncMock()
mock_triples_producer = AsyncMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": 400,
"chunk-overlap": 40,
}.get(param)
mock_flow.side_effect = lambda name: {
"output": mock_producer,
"triples": mock_triples_producer,
}.get(param)
}.get(name)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
@ -223,8 +225,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
sent_chunk = mock_producer.send.call_args[0][0]
assert isinstance(sent_chunk, Chunk)
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
"""Test chunk_document when no parameters are overridden (flow returns None)"""
@ -243,7 +245,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.return_value = None # No overrides
mock_flow.parameters.get.return_value = None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -254,8 +256,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 250 # Should use default value
assert chunk_overlap == 15 # Should use default value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_token_chunker_uses_different_defaults(self, mock_producer, mock_consumer):
"""Test that token chunker has different defaults than recursive chunker"""

View file

@ -21,17 +21,15 @@ class TestSyncDocumentEmbeddingsClient:
# Act
client = DocumentEmbeddingsClient(
log_level=1,
subscriber="test-subscriber",
input_queue="test-input",
output_queue="test-output",
pulsar_host="pulsar://test:6650",
pulsar_api_key="test-key"
)
# Assert
mock_base_init.assert_called_once_with(
log_level=1,
subscriber="test-subscriber",
input_queue="test-input",
output_queue="test-output",

View file

@ -81,9 +81,8 @@ class TestTaskGroupConcurrency:
# Track how many consume_from_queue calls are made
call_count = 0
original_running = True
async def mock_consume():
async def mock_consume(backend_consumer, executor=None):
nonlocal call_count
call_count += 1
# Wait a bit to let all tasks start, then signal stop
@ -107,7 +106,7 @@ class TestTaskGroupConcurrency:
consumer = _make_consumer(concurrency=1)
call_count = 0
async def mock_consume():
async def mock_consume(backend_consumer, executor=None):
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
@ -147,7 +146,7 @@ class TestRateLimitRetry:
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
assert call_count == 2
consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
@ -166,7 +165,7 @@ class TestRateLimitRetry:
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
consumer.consumer.negative_acknowledge.assert_called_with(mock_msg)
consumer.consumer.acknowledge.assert_not_called()
@ -185,7 +184,7 @@ class TestRateLimitRetry:
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
assert call_count == 1
consumer.consumer.negative_acknowledge.assert_called_once_with(mock_msg)
@ -197,7 +196,7 @@ class TestRateLimitRetry:
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
@ -219,7 +218,7 @@ class TestMetricsIntegration:
mock_metrics.record_time.return_value.__exit__ = MagicMock()
consumer.metrics = mock_metrics
await consumer.handle_one_from_queue(mock_msg)
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
mock_metrics.process.assert_called_once_with("success")
@ -235,7 +234,7 @@ class TestMetricsIntegration:
mock_metrics = MagicMock()
consumer.metrics = mock_metrics
await consumer.handle_one_from_queue(mock_msg)
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
mock_metrics.process.assert_called_once_with("error")
@ -261,7 +260,7 @@ class TestMetricsIntegration:
mock_metrics.record_time.return_value.__exit__ = MagicMock(return_value=False)
consumer.metrics = mock_metrics
await consumer.handle_one_from_queue(mock_msg)
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
mock_metrics.rate_limit.assert_called_once()
@ -294,9 +293,8 @@ class TestPollTimeout:
raise type('Timeout', (Exception,), {})("timeout")
mock_pulsar_consumer.receive = capture_receive
consumer.consumer = mock_pulsar_consumer
await consumer.consume_from_queue()
await consumer.consume_from_queue(mock_pulsar_consumer)
assert received_kwargs.get("timeout_millis") == 100

View file

@ -25,8 +25,8 @@ class MockAsyncProcessor:
class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
"""Test Mistral OCR processor functionality"""
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_initialization_with_api_key(
@ -51,8 +51,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
assert consumer_specs[0].name == "input"
assert consumer_specs[0].schema == Document
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_initialization_without_api_key(
self, mock_producer, mock_consumer
@ -66,8 +66,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
with pytest.raises(RuntimeError, match="Mistral API key not specified"):
Processor(**config)
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_ocr_single_chunk(
@ -131,8 +131,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
)
mock_mistral.ocr.process.assert_called_once()
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_success(
@ -172,7 +172,7 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
]
# Mock save_child_document
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
with patch.object(processor, 'ocr', return_value=ocr_result):
await processor.on_message(mock_msg, None, mock_flow)

View file

@ -24,12 +24,10 @@ class MockAsyncProcessor:
class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
"""Test PDF decoder processor functionality"""
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_initialization(self, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
async def test_processor_initialization(self, mock_producer, mock_consumer):
"""Test PDF decoder processor initialization"""
config = {
'id': 'test-pdf-decoder',
@ -44,13 +42,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
assert consumer_specs[0].name == "input"
assert consumer_specs[0].schema == Document
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test successful PDF processing"""
# Mock PDF content
pdf_content = b"fake pdf content"
@ -85,7 +81,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Mock save_child_document to avoid waiting for librarian response
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
await processor.on_message(mock_msg, None, mock_flow)
@ -94,13 +90,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
# Verify triples were sent for each page (provenance)
assert mock_triples_flow.send.call_count == 2
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of empty PDF"""
pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
@ -128,13 +122,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
mock_output_flow.send.assert_not_called()
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of unicode content in PDF"""
pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
@ -165,7 +157,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Mock save_child_document to avoid waiting for librarian response
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
await processor.on_message(mock_msg, None, mock_flow)

View file

@ -142,8 +142,8 @@ class TestPageBasedFormats:
class TestUniversalProcessor(IsolatedAsyncioTestCase):
"""Test universal decoder processor."""
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_initialization(
self, mock_producer, mock_consumer
@ -169,8 +169,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert consumer_specs[0].name == "input"
assert consumer_specs[0].schema == Document
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_custom_strategy(
self, mock_producer, mock_consumer
@ -188,8 +188,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert processor.partition_strategy == "hi_res"
assert processor.section_strategy_name == "heading"
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_group_by_page(self, mock_producer, mock_consumer):
"""Test page grouping of elements."""
@ -214,8 +214,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert result[1][0] == 2
assert len(result[1][1]) == 1
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.universal.processor.partition')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_inline_non_page(
@ -255,7 +255,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
}.get(name))
# Mock save_child_document and magic
processor.save_child_document = AsyncMock(return_value="mock-id")
processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
mock_magic.from_buffer.return_value = "text/markdown"
@ -271,8 +271,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert call_args.document_id.startswith("urn:section:")
assert call_args.text == b""
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.universal.processor.partition')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_page_based(
@ -310,7 +310,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
"triples": mock_triples_flow,
}.get(name))
processor.save_child_document = AsyncMock(return_value="mock-id")
processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
mock_magic.from_buffer.return_value = "application/pdf"
@ -323,8 +323,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
call_args = mock_output_flow.send.call_args_list[0][0][0]
assert call_args.document_id.startswith("urn:page:")
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.universal.processor.partition')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_images_stored_not_emitted(
@ -361,7 +361,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
"triples": mock_triples_flow,
}.get(name))
processor.save_child_document = AsyncMock(return_value="mock-id")
processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
mock_magic.from_buffer.return_value = "application/pdf"
@ -374,7 +374,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert mock_triples_flow.send.call_count == 2
# save_child_document called twice (page + image)
assert processor.save_child_document.call_count == 2
assert processor.librarian.save_child_document.call_count == 2
@patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
def test_add_args(self, mock_parent_add_args):

View file

@ -5,7 +5,7 @@ Tests for Gateway Config Receiver
import pytest
import asyncio
import json
from unittest.mock import Mock, patch, Mock, MagicMock
from unittest.mock import Mock, patch, MagicMock, AsyncMock
import uuid
from trustgraph.gateway.config.receiver import ConfigReceiver
@ -23,174 +23,237 @@ class TestConfigReceiver:
def test_config_receiver_initialization(self):
"""Test ConfigReceiver initialization"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
assert config_receiver.backend == mock_backend
assert config_receiver.flow_handlers == []
assert config_receiver.flows == {}
assert config_receiver.config_version == 0
def test_add_handler(self):
"""Test adding flow handlers"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler1 = Mock()
handler2 = Mock()
config_receiver.add_handler(handler1)
config_receiver.add_handler(handler2)
assert len(config_receiver.flow_handlers) == 2
assert handler1 in config_receiver.flow_handlers
assert handler2 in config_receiver.flow_handlers
@pytest.mark.asyncio
async def test_on_config_with_new_flows(self):
"""Test on_config method with new flows"""
async def test_on_config_notify_new_version(self):
"""Test on_config_notify triggers fetch for newer version"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Track calls manually instead of using AsyncMock
start_flow_calls = []
async def mock_start_flow(*args):
start_flow_calls.append(args)
config_receiver.start_flow = mock_start_flow
# Create mock message with flows
config_receiver.config_version = 1
# Mock fetch_and_apply
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with newer version
mock_msg = Mock()
mock_msg.value.return_value = Mock(
version="1.0",
config={
"flow": {
"flow1": '{"name": "test_flow_1", "steps": []}',
"flow2": '{"name": "test_flow_2", "steps": []}'
}
}
)
await config_receiver.on_config(mock_msg, None, None)
# Verify flows were added
assert "flow1" in config_receiver.flows
assert "flow2" in config_receiver.flows
assert config_receiver.flows["flow1"] == {"name": "test_flow_1", "steps": []}
assert config_receiver.flows["flow2"] == {"name": "test_flow_2", "steps": []}
# Verify start_flow was called for each new flow
assert len(start_flow_calls) == 2
assert ("flow1", {"name": "test_flow_1", "steps": []}) in start_flow_calls
assert ("flow2", {"name": "test_flow_2", "steps": []}) in start_flow_calls
mock_msg.value.return_value = Mock(version=2, types=["flow"])
await config_receiver.on_config_notify(mock_msg, None, None)
assert len(fetch_calls) == 1
@pytest.mark.asyncio
async def test_on_config_with_removed_flows(self):
"""Test on_config method with removed flows"""
async def test_on_config_notify_old_version_ignored(self):
"""Test on_config_notify ignores older versions"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Pre-populate with existing flows
config_receiver.flows = {
"flow1": {"name": "test_flow_1", "steps": []},
"flow2": {"name": "test_flow_2", "steps": []}
}
# Track calls manually instead of using AsyncMock
stop_flow_calls = []
async def mock_stop_flow(*args):
stop_flow_calls.append(args)
config_receiver.stop_flow = mock_stop_flow
# Create mock message with only flow1 (flow2 removed)
config_receiver.config_version = 5
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with older version
mock_msg = Mock()
mock_msg.value.return_value = Mock(
version="1.0",
config={
"flow": {
"flow1": '{"name": "test_flow_1", "steps": []}'
}
}
)
await config_receiver.on_config(mock_msg, None, None)
# Verify flow2 was removed
assert "flow1" in config_receiver.flows
assert "flow2" not in config_receiver.flows
# Verify stop_flow was called for removed flow
assert len(stop_flow_calls) == 1
assert stop_flow_calls[0] == ("flow2", {"name": "test_flow_2", "steps": []})
mock_msg.value.return_value = Mock(version=3, types=["flow"])
await config_receiver.on_config_notify(mock_msg, None, None)
assert len(fetch_calls) == 0
@pytest.mark.asyncio
async def test_on_config_with_no_flows(self):
"""Test on_config method with no flows in config"""
async def test_on_config_notify_irrelevant_types_ignored(self):
"""Test on_config_notify ignores types the gateway doesn't care about"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Mock the start_flow and stop_flow methods with async functions
async def mock_start_flow(*args):
pass
async def mock_stop_flow(*args):
pass
config_receiver.start_flow = mock_start_flow
config_receiver.stop_flow = mock_stop_flow
# Create mock message without flows
config_receiver.config_version = 1
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with non-flow type
mock_msg = Mock()
mock_msg.value.return_value = Mock(
version="1.0",
config={}
)
await config_receiver.on_config(mock_msg, None, None)
# Verify no flows were added
assert config_receiver.flows == {}
# Since no flows were in the config, the flow methods shouldn't be called
# (We can't easily assert this with simple async functions, but the test
# passes if no exceptions are thrown)
mock_msg.value.return_value = Mock(version=2, types=["prompt"])
await config_receiver.on_config_notify(mock_msg, None, None)
# Version should be updated but no fetch
assert len(fetch_calls) == 0
assert config_receiver.config_version == 2
@pytest.mark.asyncio
async def test_on_config_exception_handling(self):
"""Test on_config method handles exceptions gracefully"""
async def test_on_config_notify_flow_type_triggers_fetch(self):
"""Test on_config_notify fetches for flow-related types"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Create mock message that will cause an exception
config_receiver.config_version = 1
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
for type_name in ["flow", "active-flow"]:
fetch_calls.clear()
config_receiver.config_version = 1
mock_msg = Mock()
mock_msg.value.return_value = Mock(version=2, types=[type_name])
await config_receiver.on_config_notify(mock_msg, None, None)
assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}"
@pytest.mark.asyncio
async def test_on_config_notify_exception_handling(self):
"""Test on_config_notify handles exceptions gracefully"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Create notify message that causes an exception
mock_msg = Mock()
mock_msg.value.side_effect = Exception("Test exception")
# This should not raise an exception
await config_receiver.on_config(mock_msg, None, None)
# Verify flows remain empty
# Should not raise
await config_receiver.on_config_notify(mock_msg, None, None)
@pytest.mark.asyncio
async def test_fetch_and_apply_with_new_flows(self):
    """fetch_and_apply adopts the fetched version and starts each fetched flow."""
    receiver = ConfigReceiver(Mock())

    # Reply handed back by the mocked config client's request(): version 5
    # with two JSON-encoded flow definitions.
    reply = Mock(
        error=None,
        version=5,
        config={
            "flow": {
                "flow1": '{"name": "test_flow_1"}',
                "flow2": '{"name": "test_flow_2"}'
            }
        },
    )
    client = AsyncMock()
    client.request.return_value = reply
    receiver._create_config_client = Mock(return_value=client)

    # Record every start_flow invocation instead of running real handlers.
    started = []

    async def record_start(id, flow):
        started.append((id, flow))

    receiver.start_flow = record_start

    await receiver.fetch_and_apply()

    assert receiver.config_version == 5
    assert "flow1" in receiver.flows
    assert "flow2" in receiver.flows
    assert len(started) == 2
@pytest.mark.asyncio
async def test_fetch_and_apply_with_removed_flows(self):
    """fetch_and_apply stops and forgets a flow that is absent from the fetched config."""
    receiver = ConfigReceiver(Mock())

    # The receiver starts out knowing two flows...
    receiver.flows = {
        "flow1": {"name": "test_flow_1"},
        "flow2": {"name": "test_flow_2"}
    }

    # ...while the fetched config only lists flow1.
    reply = Mock(
        error=None,
        version=5,
        config={
            "flow": {
                "flow1": '{"name": "test_flow_1"}'
            }
        },
    )
    client = AsyncMock()
    client.request.return_value = reply
    receiver._create_config_client = Mock(return_value=client)

    # Record every stop_flow invocation instead of running real handlers.
    stopped = []

    async def record_stop(id, flow):
        stopped.append((id, flow))

    receiver.stop_flow = record_stop

    await receiver.fetch_and_apply()

    assert "flow1" in receiver.flows
    assert "flow2" not in receiver.flows
    assert len(stopped) == 1
    assert stopped[0][0] == "flow2"
@pytest.mark.asyncio
async def test_fetch_and_apply_with_no_flows(self):
    """fetch_and_apply with an empty config leaves flows empty and records the version."""
    receiver = ConfigReceiver(Mock())

    # Reply with no "flow" section at all.
    reply = Mock(error=None, version=1, config={})
    client = AsyncMock()
    client.request.return_value = reply
    receiver._create_config_client = Mock(return_value=client)

    await receiver.fetch_and_apply()

    assert receiver.flows == {}
    assert receiver.config_version == 1
@pytest.mark.asyncio
async def test_start_flow_with_handlers(self):
    """start_flow forwards the flow id and data to every registered handler."""
    receiver = ConfigReceiver(Mock())

    # Two handlers, each with its own start_flow mock, registered in order.
    handlers = [Mock(), Mock()]
    for handler in handlers:
        handler.start_flow = Mock()
    for handler in handlers:
        receiver.add_handler(handler)

    flow_data = {"name": "test_flow", "steps": []}

    await receiver.start_flow("flow1", flow_data)

    # Each handler must have been called exactly once with the same arguments.
    for handler in handlers:
        handler.start_flow.assert_called_once_with("flow1", flow_data)
@ -199,19 +262,17 @@ class TestConfigReceiver:
"""Test start_flow method handles handler exceptions"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Add mock handler that raises exception
handler = Mock()
handler.start_flow = Mock(side_effect=Exception("Handler error"))
config_receiver.add_handler(handler)
flow_data = {"name": "test_flow", "steps": []}
# This should not raise an exception
# Should not raise
await config_receiver.start_flow("flow1", flow_data)
# Verify handler was called
handler.start_flow.assert_called_once_with("flow1", flow_data)
@pytest.mark.asyncio
@ -219,21 +280,19 @@ class TestConfigReceiver:
"""Test stop_flow method with multiple handlers"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Add mock handlers
handler1 = Mock()
handler1.stop_flow = Mock()
handler2 = Mock()
handler2.stop_flow = Mock()
config_receiver.add_handler(handler1)
config_receiver.add_handler(handler2)
flow_data = {"name": "test_flow", "steps": []}
await config_receiver.stop_flow("flow1", flow_data)
# Verify all handlers were called
handler1.stop_flow.assert_called_once_with("flow1", flow_data)
handler2.stop_flow.assert_called_once_with("flow1", flow_data)
@ -242,167 +301,77 @@ class TestConfigReceiver:
"""Test stop_flow method handles handler exceptions"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Add mock handler that raises exception
handler = Mock()
handler.stop_flow = Mock(side_effect=Exception("Handler error"))
config_receiver.add_handler(handler)
flow_data = {"name": "test_flow", "steps": []}
# This should not raise an exception
# Should not raise
await config_receiver.stop_flow("flow1", flow_data)
# Verify handler was called
handler.stop_flow.assert_called_once_with("flow1", flow_data)
@pytest.mark.asyncio
async def test_config_loader_creates_consumer(self):
    """config_loader builds a Consumer wired to this receiver's on_config handler."""
    mock_backend = Mock()
    receiver = ConfigReceiver(mock_backend)

    # Bind the real config_loader back onto this instance for the test.
    receiver.config_loader = _real_config_loader.__get__(receiver)

    # Replace the Consumer class and pin uuid4 so the subscriber name is predictable.
    with patch('trustgraph.gateway.config.receiver.Consumer') as consumer_cls, \
         patch('uuid.uuid4', return_value="test-uuid"):

        async def noop_start():
            pass

        consumer_cls.return_value = Mock(start=noop_start)

        # Bound the call with a short timeout; a timeout is tolerated here.
        try:
            await asyncio.wait_for(receiver.config_loader(), timeout=0.1)
        except asyncio.TimeoutError:
            pass

        # Exactly one Consumer, constructed with the expected keyword arguments.
        consumer_cls.assert_called_once()
        kwargs = consumer_cls.call_args.kwargs
        assert kwargs['backend'] == mock_backend
        assert kwargs['subscriber'] == "gateway-test-uuid"
        assert kwargs['handler'] == receiver.on_config
        assert kwargs['start_of_messages'] is True
@patch('asyncio.create_task')
@pytest.mark.asyncio
async def test_start_creates_config_loader_task(self, mock_create_task):
    """start() schedules exactly one task via asyncio.create_task."""
    receiver = ConfigReceiver(Mock())

    # create_task is patched, so nothing is actually scheduled on the loop.
    mock_create_task.return_value = Mock()

    await receiver.start()

    mock_create_task.assert_called_once()

    # The single call must carry one positional argument.
    positional = mock_create_task.call_args.args
    assert len(positional) == 1  # Should have one argument (the coroutine)
@pytest.mark.asyncio
async def test_on_config_mixed_flow_operations(self):
"""Test on_config with mixed add/remove operations"""
async def test_fetch_and_apply_mixed_flow_operations(self):
"""Test fetch_and_apply with mixed add/remove operations"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Pre-populate with existing flows
# Pre-populate
config_receiver.flows = {
"flow1": {"name": "test_flow_1", "steps": []},
"flow2": {"name": "test_flow_2", "steps": []}
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"}
}
# Track calls manually instead of using Mock
start_flow_calls = []
stop_flow_calls = []
async def mock_start_flow(*args):
start_flow_calls.append(args)
async def mock_stop_flow(*args):
stop_flow_calls.append(args)
# Directly assign to avoid patch.object detecting async methods
original_start_flow = config_receiver.start_flow
original_stop_flow = config_receiver.stop_flow
# Config removes flow1, keeps flow2, adds flow3
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow2": '{"name": "test_flow_2"}',
"flow3": '{"name": "test_flow_3"}'
}
}
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
config_receiver._create_config_client = Mock(return_value=mock_client)
start_calls = []
stop_calls = []
async def mock_start_flow(id, flow):
start_calls.append((id, flow))
async def mock_stop_flow(id, flow):
stop_calls.append((id, flow))
config_receiver.start_flow = mock_start_flow
config_receiver.stop_flow = mock_stop_flow
try:
# Create mock message with flow1 removed and flow3 added
mock_msg = Mock()
mock_msg.value.return_value = Mock(
version="1.0",
config={
"flow": {
"flow2": '{"name": "test_flow_2", "steps": []}',
"flow3": '{"name": "test_flow_3", "steps": []}'
}
}
)
await config_receiver.on_config(mock_msg, None, None)
# Verify final state
assert "flow1" not in config_receiver.flows
assert "flow2" in config_receiver.flows
assert "flow3" in config_receiver.flows
# Verify operations
assert len(start_flow_calls) == 1
assert start_flow_calls[0] == ("flow3", {"name": "test_flow_3", "steps": []})
assert len(stop_flow_calls) == 1
assert stop_flow_calls[0] == ("flow1", {"name": "test_flow_1", "steps": []})
finally:
# Restore original methods
config_receiver.start_flow = original_start_flow
config_receiver.stop_flow = original_stop_flow
@pytest.mark.asyncio
async def test_on_config_invalid_json_flow_data(self):
"""Test on_config handles invalid JSON in flow data"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Mock the start_flow method with an async function
async def mock_start_flow(*args):
pass
config_receiver.start_flow = mock_start_flow
# Create mock message with invalid JSON
mock_msg = Mock()
mock_msg.value.return_value = Mock(
version="1.0",
config={
"flow": {
"flow1": '{"invalid": json}', # Invalid JSON
"flow2": '{"name": "valid_flow", "steps": []}' # Valid JSON
}
}
)
# This should handle the exception gracefully
await config_receiver.on_config(mock_msg, None, None)
# The entire operation should fail due to JSON parsing error
# So no flows should be added
assert config_receiver.flows == {}
await config_receiver.fetch_and_apply()
assert "flow1" not in config_receiver.flows
assert "flow2" in config_receiver.flows
assert "flow3" in config_receiver.flows
assert len(start_calls) == 1
assert start_calls[0][0] == "flow3"
assert len(stop_calls) == 1
assert stop_calls[0][0] == "flow1"

View file

@ -49,7 +49,7 @@ class TestConfigRequestor:
mock_translator_registry.get_response_translator.return_value = Mock()
# Setup translator response
mock_request_translator.to_pulsar.return_value = "translated_request"
mock_request_translator.decode.return_value = "translated_request"
# Patch ServiceRequestor async methods with regular mocks (not AsyncMock)
with patch.object(ServiceRequestor, 'start', return_value=None), \
@ -64,7 +64,7 @@ class TestConfigRequestor:
result = requestor.to_request({"test": "body"})
# Verify translator was called correctly
mock_request_translator.to_pulsar.assert_called_once_with({"test": "body"})
mock_request_translator.decode.assert_called_once_with({"test": "body"})
assert result == "translated_request"
@patch('trustgraph.gateway.dispatch.config.TranslatorRegistry')
@ -76,7 +76,7 @@ class TestConfigRequestor:
mock_translator_registry.get_response_translator.return_value = mock_response_translator
# Setup translator response
mock_response_translator.from_response_with_completion.return_value = "translated_response"
mock_response_translator.encode_with_completion.return_value = "translated_response"
requestor = ConfigRequestor(
backend=Mock(),
@ -89,5 +89,5 @@ class TestConfigRequestor:
result = requestor.from_response(mock_message)
# Verify translator was called correctly
mock_response_translator.from_response_with_completion.assert_called_once_with(mock_message)
mock_response_translator.encode_with_completion.assert_called_once_with(mock_message)
assert result == "translated_response"

View file

@ -0,0 +1,359 @@
"""
Tests for inline explainability triples in response translators
and ProvenanceEvent parsing.
"""
import pytest
from trustgraph.schema import (
GraphRagResponse, DocumentRagResponse, AgentResponse,
Term, Triple, IRI, LITERAL, Error,
)
from trustgraph.messaging.translators.retrieval import (
GraphRagResponseTranslator,
DocumentRagResponseTranslator,
)
from trustgraph.messaging.translators.agent import (
AgentResponseTranslator,
)
from trustgraph.api.types import ProvenanceEvent
# --- Helpers ---
def make_triple(s_iri, p_iri, o_value, o_type=LITERAL):
    """Build a Triple whose subject and predicate are IRI terms.

    The object is an IRI term when ``o_type == IRI``; for any other
    ``o_type`` it is a literal term holding ``o_value``.
    """
    if o_type == IRI:
        obj = Term(type=IRI, iri=o_value)
    else:
        obj = Term(type=LITERAL, value=o_value)

    subject = Term(type=IRI, iri=s_iri)
    predicate = Term(type=IRI, iri=p_iri)
    return Triple(s=subject, p=predicate, o=obj)
def sample_triples():
    """Return three provenance triples that share one question subject."""
    subject = "urn:trustgraph:question:abc123"

    # (predicate, object value, object term type) for each triple, in order.
    rows = [
        (
            "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
            "https://trustgraph.ai/ns/GraphRagQuestion",
            IRI,
        ),
        (
            "https://trustgraph.ai/ns/query",
            "What is the internet?",
            LITERAL,
        ),
        (
            "http://www.w3.org/ns/prov#startedAtTime",
            "2026-04-07T09:00:00Z",
            LITERAL,
        ),
    ]
    return [
        make_triple(subject, predicate, value, o_type=term_type)
        for predicate, value, term_type in rows
    ]
# --- GraphRag Translator ---
class TestGraphRagExplainTriples:
    """Inline explain triples as encoded by GraphRagResponseTranslator."""

    def test_explain_triples_encoded(self):
        """An explain message encodes all three sample triples in wire form."""
        message = GraphRagResponse(
            message_type="explain",
            explain_id="urn:trustgraph:question:abc123",
            explain_graph="urn:graph:retrieval",
            explain_triples=sample_triples(),
        )

        encoded = GraphRagResponseTranslator().encode(message)

        assert "explain_triples" in encoded
        assert len(encoded["explain_triples"]) == 3

        # Spot-check subject and predicate terms of the first triple.
        first = encoded["explain_triples"][0]
        assert first["s"]["t"] == "i"
        assert first["s"]["i"] == "urn:trustgraph:question:abc123"
        assert first["p"]["t"] == "i"

    def test_explain_triples_empty_not_included(self):
        """A chunk message without triples has no explain_triples key."""
        message = GraphRagResponse(
            message_type="chunk",
            response="Some answer text",
        )

        encoded = GraphRagResponseTranslator().encode(message)

        assert "explain_triples" not in encoded

    def test_explain_with_completion_returns_not_final(self):
        """An explain message with end_of_session=False is reported as not final."""
        message = GraphRagResponse(
            message_type="explain",
            explain_id="urn:trustgraph:question:abc123",
            explain_triples=sample_triples(),
            end_of_session=False,
        )

        _, is_final = GraphRagResponseTranslator().encode_with_completion(message)

        assert is_final is False

    def test_explain_id_and_graph_included(self):
        """explain_id and explain_graph are passed through to the encoded dict."""
        message = GraphRagResponse(
            message_type="explain",
            explain_id="urn:trustgraph:question:abc123",
            explain_graph="urn:graph:retrieval",
            explain_triples=sample_triples(),
        )

        encoded = GraphRagResponseTranslator().encode(message)

        assert encoded["explain_id"] == "urn:trustgraph:question:abc123"
        assert encoded["explain_graph"] == "urn:graph:retrieval"
# --- DocumentRag Translator ---
class TestDocumentRagExplainTriples:
    """Inline explain triples as encoded by DocumentRagResponseTranslator."""

    def test_explain_triples_encoded(self):
        """An explain message encodes all three sample triples."""
        message = DocumentRagResponse(
            response=None,
            message_type="explain",
            explain_id="urn:trustgraph:docrag:abc123",
            explain_graph="urn:graph:retrieval",
            explain_triples=sample_triples(),
        )

        encoded = DocumentRagResponseTranslator().encode(message)

        assert "explain_triples" in encoded
        assert len(encoded["explain_triples"]) == 3

    def test_explain_triples_empty_not_included(self):
        """A chunk message without triples has no explain_triples key."""
        message = DocumentRagResponse(
            response="Answer text",
            message_type="chunk",
        )

        encoded = DocumentRagResponseTranslator().encode(message)

        assert "explain_triples" not in encoded
# --- Agent Translator ---
class TestAgentExplainTriples:
    """Inline explain triples as encoded by AgentResponseTranslator."""

    def test_explain_triples_encoded(self):
        """An explain chunk encodes all sample triples, including literal objects."""
        message = AgentResponse(
            chunk_type="explain",
            content="",
            explain_id="urn:trustgraph:agent:session:abc123",
            explain_graph="urn:graph:retrieval",
            explain_triples=sample_triples(),
        )

        encoded = AgentResponseTranslator().encode(message)

        assert "explain_triples" in encoded
        assert len(encoded["explain_triples"]) == 3

        # The second sample triple carries the query text as a literal object.
        query_triple = encoded["explain_triples"][1]
        assert query_triple["p"]["i"] == "https://trustgraph.ai/ns/query"
        assert query_triple["o"]["t"] == "l"
        assert query_triple["o"]["v"] == "What is the internet?"

    def test_explain_triples_empty_not_included(self):
        """A thought chunk without triples has no explain_triples key."""
        message = AgentResponse(
            chunk_type="thought",
            content="I need to think...",
        )

        encoded = AgentResponseTranslator().encode(message)

        assert "explain_triples" not in encoded

    def test_explain_with_completion_not_final(self):
        """An explain chunk with end_of_dialog=False is reported as not final."""
        message = AgentResponse(
            chunk_type="explain",
            explain_id="urn:trustgraph:agent:session:abc123",
            explain_triples=sample_triples(),
            end_of_dialog=False,
        )

        _, is_final = AgentResponseTranslator().encode_with_completion(message)

        assert is_final is False

    def test_explain_with_completion_final(self):
        """An answer chunk with end_of_dialog=True is reported as final."""
        message = AgentResponse(
            chunk_type="answer",
            content="The answer is...",
            end_of_dialog=True,
        )

        _, is_final = AgentResponseTranslator().encode_with_completion(message)

        assert is_final is True
# --- ProvenanceEvent ---
class TestProvenanceEvent:
    """event_type values and field defaults of ProvenanceEvent for given explain_id URNs."""

    def test_question_event_type(self):
        """A question URN yields event_type "question"."""
        evt = ProvenanceEvent(explain_id="urn:trustgraph:question:abc123")
        assert evt.event_type == "question"

    def test_exploration_event_type(self):
        """An exploration URN yields event_type "exploration"."""
        evt = ProvenanceEvent(explain_id="urn:trustgraph:exploration:abc123")
        assert evt.event_type == "exploration"

    def test_focus_event_type(self):
        """A focus URN yields event_type "focus"."""
        evt = ProvenanceEvent(explain_id="urn:trustgraph:focus:abc123")
        assert evt.event_type == "focus"

    def test_synthesis_event_type(self):
        """A synthesis URN yields event_type "synthesis"."""
        evt = ProvenanceEvent(explain_id="urn:trustgraph:synthesis:abc123")
        assert evt.event_type == "synthesis"

    def test_grounding_event_type(self):
        """A grounding URN yields event_type "grounding"."""
        evt = ProvenanceEvent(explain_id="urn:trustgraph:grounding:abc123")
        assert evt.event_type == "grounding"

    def test_session_event_type(self):
        """An agent session URN yields event_type "session"."""
        evt = ProvenanceEvent(explain_id="urn:trustgraph:agent:session:abc123")
        assert evt.event_type == "session"

    def test_iteration_event_type(self):
        """An agent iteration URN yields event_type "iteration"."""
        evt = ProvenanceEvent(explain_id="urn:trustgraph:agent:iteration:abc123:1")
        assert evt.event_type == "iteration"

    def test_observation_event_type(self):
        """An agent observation URN yields event_type "observation"."""
        evt = ProvenanceEvent(explain_id="urn:trustgraph:agent:observation:abc123:1")
        assert evt.event_type == "observation"

    def test_conclusion_event_type(self):
        """An agent conclusion URN yields event_type "conclusion"."""
        evt = ProvenanceEvent(explain_id="urn:trustgraph:agent:conclusion:abc123")
        assert evt.event_type == "conclusion"

    def test_decomposition_event_type(self):
        """An agent decomposition URN yields event_type "decomposition"."""
        evt = ProvenanceEvent(explain_id="urn:trustgraph:agent:decomposition:abc123")
        assert evt.event_type == "decomposition"

    def test_finding_event_type(self):
        """An agent finding URN yields event_type "finding"."""
        evt = ProvenanceEvent(explain_id="urn:trustgraph:agent:finding:abc123:0")
        assert evt.event_type == "finding"

    def test_plan_event_type(self):
        """An agent plan URN yields event_type "plan"."""
        evt = ProvenanceEvent(explain_id="urn:trustgraph:agent:plan:abc123")
        assert evt.event_type == "plan"

    def test_step_result_event_type(self):
        """An agent step-result URN yields event_type "step-result"."""
        evt = ProvenanceEvent(explain_id="urn:trustgraph:agent:step-result:abc123:0")
        assert evt.event_type == "step-result"

    def test_defaults(self):
        """With only explain_id given: no entity, no triples, empty explain_graph."""
        evt = ProvenanceEvent(explain_id="urn:trustgraph:question:abc123")
        assert evt.entity is None
        assert evt.triples == []
        assert evt.explain_graph == ""

    def test_with_triples(self):
        """Triples passed to the constructor are kept on the event."""
        wire_triple = {
            "s": {"t": "i", "i": "urn:x"},
            "p": {"t": "i", "i": "urn:y"},
            "o": {"t": "l", "v": "z"},
        }
        evt = ProvenanceEvent(
            explain_id="urn:trustgraph:question:abc123",
            triples=[wire_triple],
        )
        assert len(evt.triples) == 1
# --- Build ProvenanceEvent with entity parsing ---
class TestBuildProvenanceEvent:
    """Parsing of wire-format triples into an ExplainEntity."""

    def _make_client(self):
        """Placeholder helper: imports WebSocketClient but returns None.

        No client is instantiated; the test below exercises the parsing
        logic directly instead.
        """
        from trustgraph.api.socket_client import WebSocketClient
        # We can't instantiate WebSocketClient easily, so test the method logic directly
        return None

    def test_entity_parsed_from_wire_triples(self):
        """Test that wire-format triples are parsed into an ExplainEntity."""
        from trustgraph.api.explainability import ExplainEntity

        question_uri = "urn:trustgraph:question:abc123"
        wire_triples = [
            {
                "s": {"t": "i", "i": "urn:trustgraph:question:abc123"},
                "p": {"t": "i", "i": "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"},
                "o": {"t": "i", "i": "https://trustgraph.ai/ns/GraphRagQuestion"},
            },
            {
                "s": {"t": "i", "i": "urn:trustgraph:question:abc123"},
                "p": {"t": "i", "i": "https://trustgraph.ai/ns/query"},
                "o": {"t": "l", "v": "What is the internet?"},
            },
        ]

        def object_value(term):
            # IRI objects carry their value under "i", everything else under "v".
            if term.get("t") == "i":
                return term.get("i", "")
            return term.get("v", "")

        # Flatten each wire triple to an (s, p, o) string tuple; intended to
        # mirror _build_provenance_event's parsing (not verified by this test).
        parsed = [
            (
                triple.get("s", {}).get("i", ""),
                triple.get("p", {}).get("i", ""),
                object_value(triple.get("o", {})),
            )
            for triple in wire_triples
        ]

        entity = ExplainEntity.from_triples(question_uri, parsed)

        assert entity.entity_type == "question"
        assert entity.query == "What is the internet?"
        assert entity.question_type == "graph-rag"

View file

@ -25,7 +25,7 @@ from trustgraph.schema import (
class TestGraphRagResponseTranslator:
"""Test GraphRagResponseTranslator streaming behavior"""
def test_from_pulsar_with_empty_response(self):
def test_encode_with_empty_response(self):
"""Test that empty response strings are preserved"""
# Arrange
translator = GraphRagResponseTranslator()
@ -36,14 +36,14 @@ class TestGraphRagResponseTranslator:
)
# Act
result = translator.from_pulsar(response)
result = translator.encode(response)
# Assert - Empty string should be included in result
assert "response" in result
assert result["response"] == ""
assert result["end_of_stream"] is True
def test_from_pulsar_with_non_empty_response(self):
def test_encode_with_non_empty_response(self):
"""Test that non-empty responses work correctly"""
# Arrange
translator = GraphRagResponseTranslator()
@ -54,13 +54,13 @@ class TestGraphRagResponseTranslator:
)
# Act
result = translator.from_pulsar(response)
result = translator.encode(response)
# Assert
assert result["response"] == "Some text"
assert result["end_of_stream"] is False
def test_from_pulsar_with_none_response(self):
def test_encode_with_none_response(self):
"""Test that None response is handled correctly"""
# Arrange
translator = GraphRagResponseTranslator()
@ -71,14 +71,14 @@ class TestGraphRagResponseTranslator:
)
# Act
result = translator.from_pulsar(response)
result = translator.encode(response)
# Assert - None should not be included
assert "response" not in result
assert result["end_of_stream"] is True
def test_from_response_with_completion_returns_correct_flag(self):
"""Test that from_response_with_completion returns correct is_final flag"""
def test_encode_with_completion_returns_correct_flag(self):
"""Test that encode_with_completion returns correct is_final flag"""
# Arrange
translator = GraphRagResponseTranslator()
@ -90,7 +90,7 @@ class TestGraphRagResponseTranslator:
)
# Act
result, is_final = translator.from_response_with_completion(response_chunk)
result, is_final = translator.encode_with_completion(response_chunk)
# Assert
assert is_final is False
@ -105,7 +105,7 @@ class TestGraphRagResponseTranslator:
)
# Act
result, is_final = translator.from_response_with_completion(final_response)
result, is_final = translator.encode_with_completion(final_response)
# Assert - is_final is based on end_of_session, not end_of_stream
assert is_final is True
@ -116,7 +116,7 @@ class TestGraphRagResponseTranslator:
class TestDocumentRagResponseTranslator:
"""Test DocumentRagResponseTranslator streaming behavior"""
def test_from_pulsar_with_empty_response(self):
def test_encode_with_empty_response(self):
"""Test that empty response strings are preserved"""
# Arrange
translator = DocumentRagResponseTranslator()
@ -127,14 +127,14 @@ class TestDocumentRagResponseTranslator:
)
# Act
result = translator.from_pulsar(response)
result = translator.encode(response)
# Assert
assert "response" in result
assert result["response"] == ""
assert result["end_of_stream"] is True
def test_from_pulsar_with_non_empty_response(self):
def test_encode_with_non_empty_response(self):
"""Test that non-empty responses work correctly"""
# Arrange
translator = DocumentRagResponseTranslator()
@ -145,7 +145,7 @@ class TestDocumentRagResponseTranslator:
)
# Act
result = translator.from_pulsar(response)
result = translator.encode(response)
# Assert
assert result["response"] == "Document content"
@ -155,7 +155,7 @@ class TestDocumentRagResponseTranslator:
class TestPromptResponseTranslator:
"""Test PromptResponseTranslator streaming behavior"""
def test_from_pulsar_with_empty_text(self):
def test_encode_with_empty_text(self):
"""Test that empty text strings are preserved"""
# Arrange
translator = PromptResponseTranslator()
@ -167,14 +167,14 @@ class TestPromptResponseTranslator:
)
# Act
result = translator.from_pulsar(response)
result = translator.encode(response)
# Assert
assert "text" in result
assert result["text"] == ""
assert result["end_of_stream"] is True
def test_from_pulsar_with_non_empty_text(self):
def test_encode_with_non_empty_text(self):
"""Test that non-empty text works correctly"""
# Arrange
translator = PromptResponseTranslator()
@ -186,13 +186,13 @@ class TestPromptResponseTranslator:
)
# Act
result = translator.from_pulsar(response)
result = translator.encode(response)
# Assert
assert result["text"] == "Some prompt response"
assert result["end_of_stream"] is False
def test_from_pulsar_with_none_text(self):
def test_encode_with_none_text(self):
"""Test that None text is handled correctly"""
# Arrange
translator = PromptResponseTranslator()
@ -204,14 +204,14 @@ class TestPromptResponseTranslator:
)
# Act
result = translator.from_pulsar(response)
result = translator.encode(response)
# Assert
assert "text" not in result
assert "object" in result
assert result["end_of_stream"] is True
def test_from_pulsar_includes_end_of_stream(self):
def test_encode_includes_end_of_stream(self):
"""Test that end_of_stream flag is always included"""
# Arrange
translator = PromptResponseTranslator()
@ -225,7 +225,7 @@ class TestPromptResponseTranslator:
)
# Act
result = translator.from_pulsar(response)
result = translator.encode(response)
# Assert
assert "end_of_stream" in result
@ -235,7 +235,7 @@ class TestPromptResponseTranslator:
class TestTextCompletionResponseTranslator:
"""Test TextCompletionResponseTranslator streaming behavior"""
def test_from_pulsar_always_includes_response(self):
def test_encode_always_includes_response(self):
"""Test that response field is always included, even if empty"""
# Arrange
translator = TextCompletionResponseTranslator()
@ -249,13 +249,13 @@ class TestTextCompletionResponseTranslator:
)
# Act
result = translator.from_pulsar(response)
result = translator.encode(response)
# Assert - Response should always be present
assert "response" in result
assert result["response"] == ""
def test_from_response_with_completion_with_empty_final(self):
def test_encode_with_completion_with_empty_final(self):
"""Test that empty final response is handled correctly"""
# Arrange
translator = TextCompletionResponseTranslator()
@ -269,7 +269,7 @@ class TestTextCompletionResponseTranslator:
)
# Act
result, is_final = translator.from_response_with_completion(response)
result, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is True
@ -297,7 +297,7 @@ class TestStreamingProtocolCompliance:
response = response_class(**kwargs)
# Act
result = translator.from_pulsar(response)
result = translator.encode(response)
# Assert
assert field_name in result, f"{translator_class.__name__} should include '{field_name}' field even when empty"
@ -320,7 +320,7 @@ class TestStreamingProtocolCompliance:
response = response_class(**kwargs)
# Act
result = translator.from_pulsar(response)
result = translator.encode(response)
# Assert
assert "end_of_stream" in result, f"{translator_class.__name__} should include 'end_of_stream' flag"

View file

@ -0,0 +1,54 @@
"""
Unit tests for text document gateway translation compatibility.
"""
import base64
from trustgraph.messaging.translators.document_loading import TextDocumentTranslator
class TestTextDocumentTranslator:
    """TextDocumentTranslator.decode with base64-encoded and raw text payloads."""

    def test_decode_decodes_base64_text(self):
        """Base64 text is decoded to the original UTF-8 bytes; metadata is carried over."""
        payload = "Cancer survival: 2.74× higher hazard ratio"
        raw_bytes = payload.encode("utf-8")
        request = {
            "id": "doc-1",
            "user": "alice",
            "collection": "research",
            "charset": "utf-8",
            "text": base64.b64encode(raw_bytes).decode("ascii"),
        }

        msg = TextDocumentTranslator().decode(request)

        assert msg.metadata.id == "doc-1"
        assert msg.metadata.user == "alice"
        assert msg.metadata.collection == "research"
        assert msg.text == raw_bytes

    def test_decode_accepts_raw_utf8_text(self):
        """Raw text containing a non-ASCII character comes back as its UTF-8 bytes."""
        payload = "Cancer survival: 2.74× higher hazard ratio"
        request = {
            "charset": "utf-8",
            "text": payload,
        }

        msg = TextDocumentTranslator().decode(request)

        assert msg.text == payload.encode("utf-8")

    def test_decode_falls_back_to_raw_non_base64_ascii(self):
        """Plain ASCII text sent without base64 encoding comes back as its UTF-8 bytes."""
        payload = "plain-text payload"
        request = {
            "charset": "utf-8",
            "text": payload,
        }

        msg = TextDocumentTranslator().decode(request)

        assert msg.text == payload.encode("utf-8")

View file

@ -10,16 +10,19 @@ from trustgraph.schema import Triple, Term, IRI, LITERAL
from trustgraph.provenance.agent import (
agent_session_triples,
agent_iteration_triples,
agent_observation_triples,
agent_final_triples,
agent_synthesis_triples,
)
from trustgraph.provenance.namespaces import (
RDF_TYPE, RDFS_LABEL,
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
PROV_ENTITY, PROV_WAS_DERIVED_FROM,
PROV_STARTED_AT_TIME,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS,
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_TOOL_USE, TG_SYNTHESIS,
TG_AGENT_QUESTION,
)
@ -63,7 +66,7 @@ class TestAgentSessionTriples:
triples = agent_session_triples(
self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z"
)
assert has_type(triples, self.SESSION_URI, PROV_ACTIVITY)
assert has_type(triples, self.SESSION_URI, PROV_ENTITY)
assert has_type(triples, self.SESSION_URI, TG_QUESTION)
assert has_type(triples, self.SESSION_URI, TG_AGENT_QUESTION)
@ -103,6 +106,25 @@ class TestAgentSessionTriples:
)
assert len(triples) == 6
def test_session_parent_uri(self):
    """A session created with parent_uri is linked to it via wasDerivedFrom."""
    parent_iri = "urn:trustgraph:agent:parent/decompose"

    session_triples = agent_session_triples(
        self.SESSION_URI,
        "Q",
        "2024-01-01T00:00:00Z",
        parent_uri=parent_iri,
    )

    link = find_triple(
        session_triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI
    )
    assert link is not None
    assert link.o.iri == parent_iri
def test_session_no_parent_uri(self):
    """Without parent_uri the session carries no wasDerivedFrom triple."""
    session_triples = agent_session_triples(
        self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
    )

    link = find_triple(
        session_triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI
    )
    assert link is None
# ---------------------------------------------------------------------------
# agent_iteration_triples
@ -121,19 +143,17 @@ class TestAgentIterationTriples:
)
assert has_type(triples, self.ITER_URI, PROV_ENTITY)
assert has_type(triples, self.ITER_URI, TG_ANALYSIS)
assert has_type(triples, self.ITER_URI, TG_TOOL_USE)
def test_first_iteration_generated_by_question(self):
"""First iteration uses wasGeneratedBy to link to question activity."""
def test_first_iteration_derived_from_question(self):
"""First iteration uses wasDerivedFrom to link to question entity."""
triples = agent_iteration_triples(
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
assert gen is not None
assert gen.o.iri == self.SESSION_URI
# Should NOT have wasDerivedFrom
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
assert derived is None
assert derived is not None
assert derived.o.iri == self.SESSION_URI
def test_subsequent_iteration_derived_from_previous(self):
"""Subsequent iterations use wasDerivedFrom to link to previous iteration."""
@ -144,9 +164,6 @@ class TestAgentIterationTriples:
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
assert derived is not None
assert derived.o.iri == self.PREV_URI
# Should NOT have wasGeneratedBy
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
assert gen is None
def test_iteration_label_includes_action(self):
triples = agent_iteration_triples(
@ -174,40 +191,24 @@ class TestAgentIterationTriples:
# Thought has correct types
assert has_type(triples, thought_uri, TG_REFLECTION_TYPE)
assert has_type(triples, thought_uri, TG_THOUGHT_TYPE)
# Thought was generated by iteration
gen = find_triple(triples, PROV_WAS_GENERATED_BY, thought_uri)
assert gen is not None
assert gen.o.iri == self.ITER_URI
# Thought was derived from iteration
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, thought_uri)
assert derived is not None
assert derived.o.iri == self.ITER_URI
# Thought has document reference
doc = find_triple(triples, TG_DOCUMENT, thought_uri)
assert doc is not None
assert doc.o.iri == thought_doc
def test_iteration_observation_sub_entity(self):
"""Observation is a sub-entity with Reflection and Observation types."""
obs_uri = "urn:trustgraph:agent:test-session/i1/observation"
obs_doc = "urn:doc:obs-1"
def test_iteration_no_observation_sub_entity(self):
    """Iteration no longer embeds observation — it's a separate entity.

    None of the predicates emitted for the iteration may mention
    "observation", in any letter case.
    """
    triples = agent_iteration_triples(
        self.ITER_URI, question_uri=self.SESSION_URI,
        action="search",
    )

    # A disjunction of a case-insensitive and a case-sensitive test only
    # fails on a capitalised "Observation", so a lower-case observation
    # predicate would slip through. Check case-insensitively and report
    # every offending predicate at once.
    offending = [
        t.p.iri for t in triples
        if "observation" in t.p.iri.lower()
    ]
    assert offending == [], (
        f"Iteration emitted observation predicates: {offending}"
    )
def test_iteration_action_recorded(self):
triples = agent_iteration_triples(
@ -240,19 +241,17 @@ class TestAgentIterationTriples:
parsed = json.loads(arguments.o.value)
assert parsed == {}
def test_iteration_no_thought_or_observation(self):
"""Minimal iteration with just action — no thought or observation triples."""
def test_iteration_no_thought(self):
"""Minimal iteration with just action — no thought triples."""
triples = agent_iteration_triples(
self.ITER_URI, question_uri=self.SESSION_URI,
action="noop",
)
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert thought is None
assert obs is None
def test_iteration_chaining(self):
"""First iteration uses wasGeneratedBy, second uses wasDerivedFrom."""
"""Both first and second iterations use wasDerivedFrom."""
iter1_uri = "urn:trustgraph:agent:sess/i1"
iter2_uri = "urn:trustgraph:agent:sess/i2"
@ -263,13 +262,62 @@ class TestAgentIterationTriples:
iter2_uri, previous_uri=iter1_uri, action="step2",
)
gen1 = find_triple(triples1, PROV_WAS_GENERATED_BY, iter1_uri)
assert gen1.o.iri == self.SESSION_URI
derived1 = find_triple(triples1, PROV_WAS_DERIVED_FROM, iter1_uri)
assert derived1.o.iri == self.SESSION_URI
derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri)
assert derived2.o.iri == iter1_uri
# ---------------------------------------------------------------------------
# agent_observation_triples
# ---------------------------------------------------------------------------
class TestAgentObservationTriples:
    """Checks on the triples built by agent_observation_triples."""

    OBS_URI = "urn:trustgraph:agent:test-session/i1/observation"
    ITER_URI = "urn:trustgraph:agent:test-session/i1"

    def _emit(self, **extra):
        """Build observation triples for OBS_URI with ITER_URI as parent."""
        return agent_observation_triples(
            self.OBS_URI, self.ITER_URI, **extra
        )

    def test_observation_types(self):
        """The observation is typed as a PROV entity and a TG observation."""
        emitted = self._emit()
        assert has_type(emitted, self.OBS_URI, PROV_ENTITY)
        assert has_type(emitted, self.OBS_URI, TG_OBSERVATION_TYPE)

    def test_observation_derived_from_iteration(self):
        """The observation points at its iteration via wasDerivedFrom."""
        link = find_triple(
            self._emit(), PROV_WAS_DERIVED_FROM, self.OBS_URI
        )
        assert link is not None
        assert link.o.iri == self.ITER_URI

    def test_observation_label(self):
        """The rdfs:label of the observation is the literal "Observation"."""
        label_triple = find_triple(self._emit(), RDFS_LABEL, self.OBS_URI)
        assert label_triple is not None
        assert label_triple.o.value == "Observation"

    def test_observation_document(self):
        """A supplied document_id is recorded under TG_DOCUMENT."""
        document_iri = "urn:doc:obs-1"
        doc_triple = find_triple(
            self._emit(document_id=document_iri),
            TG_DOCUMENT,
            self.OBS_URI,
        )
        assert doc_triple is not None
        assert doc_triple.o.iri == document_iri

    def test_observation_no_document(self):
        """Without document_id no TG_DOCUMENT triple is produced."""
        doc_triple = find_triple(self._emit(), TG_DOCUMENT, self.OBS_URI)
        assert doc_triple is None
# ---------------------------------------------------------------------------
# agent_final_triples
# ---------------------------------------------------------------------------
@ -296,19 +344,15 @@ class TestAgentFinalTriples:
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived is not None
assert derived.o.iri == self.PREV_URI
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
assert gen is None
def test_final_generated_by_question_when_no_iterations(self):
"""When agent answers immediately, final uses wasGeneratedBy."""
def test_final_derived_from_question_when_no_iterations(self):
"""When agent answers immediately, final uses wasDerivedFrom to question."""
triples = agent_final_triples(
self.FINAL_URI, question_uri=self.SESSION_URI,
)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
assert gen is not None
assert gen.o.iri == self.SESSION_URI
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived is None
assert derived is not None
assert derived.o.iri == self.SESSION_URI
def test_final_label(self):
triples = agent_final_triples(
@ -334,3 +378,59 @@ class TestAgentFinalTriples:
)
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
assert doc is None
# ---------------------------------------------------------------------------
# agent_synthesis_triples
# ---------------------------------------------------------------------------
class TestAgentSynthesisTriples:
    """Checks on the triples built by agent_synthesis_triples."""

    SYNTH_URI = "urn:trustgraph:agent:test-session/synthesis"
    FINDING_0 = "urn:trustgraph:agent:test-session/finding/0"
    FINDING_1 = "urn:trustgraph:agent:test-session/finding/1"
    FINDING_2 = "urn:trustgraph:agent:test-session/finding/2"

    def _derivations(self, parent_arg):
        """Build synthesis triples and return SYNTH_URI's wasDerivedFrom triples."""
        emitted = agent_synthesis_triples(self.SYNTH_URI, parent_arg)
        return find_triples(emitted, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)

    def test_synthesis_types(self):
        """Synthesis is typed as PROV entity, TG synthesis and TG answer."""
        emitted = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
        assert has_type(emitted, self.SYNTH_URI, PROV_ENTITY)
        assert has_type(emitted, self.SYNTH_URI, TG_SYNTHESIS)
        assert has_type(emitted, self.SYNTH_URI, TG_ANSWER_TYPE)

    def test_synthesis_single_parent_string(self):
        """Single parent passed as string."""
        links = self._derivations(self.FINDING_0)
        assert len(links) == 1
        assert links[0].o.iri == self.FINDING_0

    def test_synthesis_multiple_parents(self):
        """Multiple parents for supervisor fan-in."""
        all_findings = [self.FINDING_0, self.FINDING_1, self.FINDING_2]
        links = self._derivations(all_findings)
        assert len(links) == 3
        linked_iris = {link.o.iri for link in links}
        assert linked_iris == set(all_findings)

    def test_synthesis_single_parent_as_list(self):
        """Single parent passed as list."""
        links = self._derivations([self.FINDING_0])
        assert len(links) == 1
        assert links[0].o.iri == self.FINDING_0

    def test_synthesis_document(self):
        """A supplied document_id is recorded under TG_DOCUMENT."""
        emitted = agent_synthesis_triples(
            self.SYNTH_URI,
            self.FINDING_0,
            document_id="urn:doc:synth",
        )
        doc_triple = find_triple(emitted, TG_DOCUMENT, self.SYNTH_URI)
        assert doc_triple is not None
        assert doc_triple.o.iri == "urn:doc:synth"

    def test_synthesis_label(self):
        """The rdfs:label of the synthesis is the literal "Synthesis"."""
        emitted = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
        label_triple = find_triple(emitted, RDFS_LABEL, self.SYNTH_URI)
        assert label_triple is not None
        assert label_triple.o.value == "Synthesis"

View file

@ -16,6 +16,7 @@ from trustgraph.api.explainability import (
Synthesis,
Reflection,
Analysis,
Observation,
Conclusion,
parse_edge_selection_triples,
extract_term_value,
@ -23,12 +24,12 @@ from trustgraph.api.explainability import (
ExplainabilityClient,
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_DOCUMENT, TG_CHUNK_COUNT, TG_CONCEPT, TG_ENTITY,
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS,
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION,
TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM,
RDF_TYPE, RDFS_LABEL,
)
@ -180,14 +181,30 @@ class TestExplainEntityFromTriples:
("urn:ana:1", TG_ACTION, "graph-rag-query"),
("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'),
("urn:ana:1", TG_THOUGHT, "urn:ref:thought-1"),
("urn:ana:1", TG_OBSERVATION, "urn:ref:obs-1"),
]
entity = ExplainEntity.from_triples("urn:ana:1", triples)
assert isinstance(entity, Analysis)
assert entity.action == "graph-rag-query"
assert entity.arguments == '{"query": "test"}'
assert entity.thought == "urn:ref:thought-1"
assert entity.observation == "urn:ref:obs-1"
def test_observation(self):
    """An observation-typed subject parses to Observation with its document."""
    subject = "urn:obs:1"
    rows = [
        (subject, RDF_TYPE, TG_OBSERVATION_TYPE),
        (subject, TG_DOCUMENT, "urn:doc:obs-content"),
    ]

    parsed = ExplainEntity.from_triples(subject, rows)

    assert isinstance(parsed, Observation)
    assert parsed.document == "urn:doc:obs-content"
    assert parsed.entity_type == "observation"
def test_observation_no_document(self):
    """An Observation without a document triple has an empty document field."""
    subject = "urn:obs:2"
    rows = [(subject, RDF_TYPE, TG_OBSERVATION_TYPE)]

    parsed = ExplainEntity.from_triples(subject, rows)

    assert isinstance(parsed, Observation)
    assert parsed.document == ""
def test_conclusion_with_document(self):
triples = [
@ -541,3 +558,96 @@ class TestExplainabilityClientDetectSessionType:
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
assert client.detect_session_type("urn:trustgraph:docrag:abc") == "docrag"
class TestChainWalkerFollowsSubTraceTerminal:
    """Test that _follow_provenance_chain continues from a sub-trace's
    Synthesis to find downstream entities like Observation."""

    def test_observation_found_via_subtrace_synthesis(self):
        """
        DAG: Question -> Analysis -> GraphRAG Question -> Synthesis -> Observation

        The walker should find Analysis, the sub-trace, then follow from
        Synthesis to discover Observation.
        """
        # Per-entity (s, p, o) rows, keyed by the entity's URI.
        graph_rows = {
            "urn:agent:q": [
                ("urn:agent:q", RDF_TYPE, TG_AGENT_QUESTION),
                ("urn:agent:q", TG_QUERY, "test"),
            ],
            "urn:agent:analysis": [
                ("urn:agent:analysis", RDF_TYPE, TG_ANALYSIS),
                ("urn:agent:analysis", PROV_WAS_DERIVED_FROM, "urn:agent:q"),
            ],
            "urn:graphrag:q": [
                ("urn:graphrag:q", RDF_TYPE, TG_QUESTION),
                ("urn:graphrag:q", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
                ("urn:graphrag:q", TG_QUERY, "test"),
                ("urn:graphrag:q", PROV_WAS_DERIVED_FROM, "urn:agent:analysis"),
            ],
            "urn:graphrag:synth": [
                ("urn:graphrag:synth", RDF_TYPE, TG_SYNTHESIS),
                ("urn:graphrag:synth", PROV_WAS_DERIVED_FROM, "urn:graphrag:q"),
            ],
            "urn:agent:obs": [
                ("urn:agent:obs", RDF_TYPE, TG_OBSERVATION_TYPE),
                ("urn:agent:obs", PROV_WAS_DERIVED_FROM, "urn:graphrag:synth"),
            ],
            "urn:agent:conclusion": [
                ("urn:agent:conclusion", RDF_TYPE, TG_CONCLUSION),
                ("urn:agent:conclusion", PROV_WAS_DERIVED_FROM, "urn:agent:obs"),
            ],
        }

        def fake_triples_query(s=None, p=None, o=None, **kwargs):
            """Answer the two query shapes this test's data supports."""
            # Subject-only lookup: hand back that entity's own rows.
            if s and not p:
                return _make_wire_triples(graph_rows.get(s, []))
            # Anything other than a reverse wasDerivedFrom lookup is empty.
            if not (p == PROV_WAS_DERIVED_FROM and o):
                return []
            # Reverse lookup: every entity that was derived from `o`.
            children = [
                (uri, pred, obj)
                for uri, rows in graph_rows.items()
                for _, pred, obj in rows
                if pred == PROV_WAS_DERIVED_FROM and obj == o
            ]
            return _make_wire_triples(children)

        flow = MagicMock()
        flow.triples_query.side_effect = fake_triples_query
        client = ExplainabilityClient(flow, retry_delay=0.0, max_retries=2)

        # Stub the graph-rag sub-trace fetch so it reports a synthesis.
        subtrace_question = Question(
            uri="urn:graphrag:q",
            entity_type="question",
            question_type="graph-rag",
        )
        subtrace_synthesis = Synthesis(
            uri="urn:graphrag:synth", entity_type="synthesis"
        )
        client.fetch_graphrag_trace = MagicMock(return_value={
            "question": subtrace_question,
            "synthesis": subtrace_synthesis,
        })

        trace = client.fetch_agent_trace(
            "urn:agent:q",
            graph="urn:graph:retrieval",
        )

        # Dict steps are labelled by their "type" key, entities by class name.
        step_types = [
            step.get("type") if isinstance(step, dict) else type(step).__name__
            for step in trace["steps"]
        ]
        assert "Analysis" in step_types, f"Missing Analysis in {step_types}"
        assert "sub-trace" in step_types, f"Missing sub-trace in {step_types}"
        assert "Observation" in step_types, f"Missing Observation in {step_types}"
        assert "Conclusion" in step_types, f"Missing Conclusion in {step_types}"

        # The observation must be reported after the sub-trace.
        subtrace_pos = step_types.index("sub-trace")
        observation_pos = step_types.index("Observation")
        assert observation_pos > subtrace_pos, "Observation should appear after sub-trace"

View file

@ -0,0 +1,295 @@
"""
Structural test for the graph-rag provenance chain.
Verifies that a complete graph-rag query produces the expected
provenance chain:
question → grounding → exploration → focus → synthesis
Each step must:
- Have the correct rdf:type
- Link to its predecessor via prov:wasDerivedFrom
- Carry expected domain-specific data
"""
import pytest
from trustgraph.provenance.triples import (
question_triples,
grounding_triples,
exploration_triples,
focus_triples,
synthesis_triples,
)
from trustgraph.provenance.uris import (
question_uri,
grounding_uri,
exploration_uri,
focus_uri,
synthesis_uri,
)
from trustgraph.provenance.namespaces import (
RDF_TYPE, RDFS_LABEL,
PROV_ENTITY, PROV_WAS_DERIVED_FROM,
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_GRAPH_RAG_QUESTION, TG_ANSWER_TYPE,
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_DOCUMENT,
PROV_STARTED_AT_TIME,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
SESSION_ID = "test-session-1234"
def find_triple(triples, predicate, subject=None):
    """Return the first triple whose predicate IRI equals ``predicate``.

    When ``subject`` is given, the triple's subject IRI must match as well.
    Returns None if nothing matches.
    """
    matches = (
        candidate
        for candidate in triples
        if candidate.p.iri == predicate
        and (subject is None or candidate.s.iri == subject)
    )
    return next(matches, None)
def find_triples(triples, predicate, subject=None):
    """Return every triple whose predicate IRI equals ``predicate``.

    When ``subject`` is given, only triples with that subject IRI are kept.
    Input order is preserved.
    """
    matched = []
    for candidate in triples:
        if candidate.p.iri != predicate:
            continue
        if subject is None or candidate.s.iri == subject:
            matched.append(candidate)
    return matched
def has_type(triples, subject, rdf_type):
    """Return True if some triple states ``subject`` rdf:type ``rdf_type``."""
    for candidate in triples:
        if (
            candidate.s.iri == subject
            and candidate.p.iri == RDF_TYPE
            and candidate.o.iri == rdf_type
        ):
            return True
    return False
def derived_from(triples, subject):
    """Return the object IRI of the subject's first wasDerivedFrom triple, or None."""
    link = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    if not link:
        return None
    return link.o.iri
# ---------------------------------------------------------------------------
# Build the full chain
# ---------------------------------------------------------------------------
@pytest.fixture
def chain():
    """Build the provenance triples of one complete graph-rag query.

    Returns a dict with three keys: ``uris`` (stage name to URI),
    ``triples`` (stage name to that stage's triple list) and ``all``
    (the five stage lists concatenated in chain order).
    """
    uris = {
        "question": question_uri(SESSION_ID),
        "grounding": grounding_uri(SESSION_ID),
        "exploration": exploration_uri(SESSION_ID),
        "focus": focus_uri(SESSION_ID),
        "synthesis": synthesis_uri(SESSION_ID),
    }

    # One edge with an IRI object and one with a literal-looking object.
    selected_edges = [
        {
            "edge": (
                "http://example.com/QuantumComputing",
                "http://schema.org/relatedTo",
                "http://example.com/Physics",
            ),
            "reasoning": "Directly relevant to the query",
        },
        {
            "edge": (
                "http://example.com/QuantumComputing",
                "http://schema.org/name",
                "Quantum Computing",
            ),
            "reasoning": "Provides the entity label",
        },
    ]

    # Each stage is built with its predecessor's URI as the parent.
    stage_triples = {
        "question": question_triples(
            uris["question"],
            "What is quantum computing?",
            "2026-01-01T00:00:00Z",
        ),
        "grounding": grounding_triples(
            uris["grounding"], uris["question"], ["quantum", "computing"]
        ),
        "exploration": exploration_triples(
            uris["exploration"],
            uris["grounding"],
            edge_count=42,
            entities=["urn:entity:1", "urn:entity:2"],
        ),
        "focus": focus_triples(
            uris["focus"],
            uris["exploration"],
            selected_edges_with_reasoning=selected_edges,
            session_id=SESSION_ID,
        ),
        "synthesis": synthesis_triples(
            uris["synthesis"], uris["focus"], document_id="urn:doc:answer-1"
        ),
    }

    combined = (
        stage_triples["question"]
        + stage_triples["grounding"]
        + stage_triples["exploration"]
        + stage_triples["focus"]
        + stage_triples["synthesis"]
    )

    return {"uris": uris, "triples": stage_triples, "all": combined}
# ---------------------------------------------------------------------------
# Chain structure tests
# ---------------------------------------------------------------------------
class TestGraphRagProvenanceChain:
    """Verify the full question → grounding → exploration → focus → synthesis chain."""

    def test_chain_has_five_stages(self, chain):
        """Each stage should produce at least some triples."""
        stage_names = ["question", "grounding", "exploration", "focus", "synthesis"]
        for stage in stage_names:
            produced = chain["triples"][stage]
            assert len(produced) > 0, f"{stage} produced no triples"

    def test_derivation_chain(self, chain):
        """
        Each stage must be derived from its predecessor: grounding from
        question, exploration from grounding, focus from exploration and
        synthesis from focus.
        """
        uris = chain["uris"]
        combined = chain["all"]
        child_to_parent = [
            ("grounding", "question"),
            ("exploration", "grounding"),
            ("focus", "exploration"),
            ("synthesis", "focus"),
        ]
        for child, parent in child_to_parent:
            assert derived_from(combined, uris[child]) == uris[parent]

    def test_question_has_no_parent(self, chain):
        """The root question should not derive from anything (no parent_uri)."""
        root = chain["uris"]["question"]
        assert derived_from(chain["all"], root) is None

    def test_question_with_parent(self):
        """When a parent_uri is given, question should derive from it."""
        child_uri = question_uri("child-session")
        parent_iri = "urn:trustgraph:agent:iteration:parent"
        child_triples = question_triples(
            child_uri,
            "sub-query",
            "2026-01-01T00:00:00Z",
            parent_uri=parent_iri,
        )
        assert derived_from(child_triples, child_uri) == parent_iri
# ---------------------------------------------------------------------------
# Type annotation tests
# ---------------------------------------------------------------------------
class TestGraphRagProvenanceTypes:
    """Each stage must have the correct rdf:type annotations."""

    def test_question_types(self, chain):
        """The question is a PROV entity and a graph-rag question."""
        subject = chain["uris"]["question"]
        stage = chain["triples"]["question"]
        for rdf_type in (PROV_ENTITY, TG_GRAPH_RAG_QUESTION):
            assert has_type(stage, subject, rdf_type)

    def test_grounding_types(self, chain):
        """The grounding is a PROV entity and a TG grounding."""
        subject = chain["uris"]["grounding"]
        stage = chain["triples"]["grounding"]
        for rdf_type in (PROV_ENTITY, TG_GROUNDING):
            assert has_type(stage, subject, rdf_type)

    def test_exploration_types(self, chain):
        """The exploration is a PROV entity and a TG exploration."""
        subject = chain["uris"]["exploration"]
        stage = chain["triples"]["exploration"]
        for rdf_type in (PROV_ENTITY, TG_EXPLORATION):
            assert has_type(stage, subject, rdf_type)

    def test_focus_types(self, chain):
        """The focus is a PROV entity and a TG focus."""
        subject = chain["uris"]["focus"]
        stage = chain["triples"]["focus"]
        for rdf_type in (PROV_ENTITY, TG_FOCUS):
            assert has_type(stage, subject, rdf_type)

    def test_synthesis_types(self, chain):
        """The synthesis is a PROV entity, a TG synthesis and a TG answer."""
        subject = chain["uris"]["synthesis"]
        stage = chain["triples"]["synthesis"]
        for rdf_type in (PROV_ENTITY, TG_SYNTHESIS, TG_ANSWER_TYPE):
            assert has_type(stage, subject, rdf_type)
# ---------------------------------------------------------------------------
# Domain-specific content tests
# ---------------------------------------------------------------------------
class TestGraphRagProvenanceContent:
    """Each stage should carry the expected domain data."""

    def test_question_has_query_text(self, chain):
        """The question records the query text under TG_QUERY."""
        subject = chain["uris"]["question"]
        query = find_triple(chain["triples"]["question"], TG_QUERY, subject)
        assert query is not None
        assert query.o.value == "What is quantum computing?"

    def test_question_has_timestamp(self, chain):
        """The question records its start time under PROV_STARTED_AT_TIME."""
        subject = chain["uris"]["question"]
        started = find_triple(
            chain["triples"]["question"], PROV_STARTED_AT_TIME, subject
        )
        assert started is not None
        assert started.o.value == "2026-01-01T00:00:00Z"

    def test_grounding_has_concepts(self, chain):
        """The grounding lists exactly the supplied concepts."""
        subject = chain["uris"]["grounding"]
        concept_triples = find_triples(
            chain["triples"]["grounding"], TG_CONCEPT, subject
        )
        found = {t.o.value for t in concept_triples}
        assert found == {"quantum", "computing"}

    def test_exploration_has_edge_count(self, chain):
        """The exploration stores the edge count as the string "42"."""
        subject = chain["uris"]["exploration"]
        count = find_triple(
            chain["triples"]["exploration"], TG_EDGE_COUNT, subject
        )
        assert count is not None
        assert count.o.value == "42"

    def test_exploration_has_entities(self, chain):
        """The exploration lists exactly the supplied entity IRIs."""
        subject = chain["uris"]["exploration"]
        entity_triples = find_triples(
            chain["triples"]["exploration"], TG_ENTITY, subject
        )
        found = {t.o.iri for t in entity_triples}
        assert found == {"urn:entity:1", "urn:entity:2"}

    def test_focus_has_selected_edges(self, chain):
        """The focus links to one selected-edge entity per supplied edge."""
        subject = chain["uris"]["focus"]
        selections = find_triples(
            chain["triples"]["focus"], TG_SELECTED_EDGE, subject
        )
        assert len(selections) == 2

    def test_focus_edges_have_quoted_triples(self, chain):
        """Each edge selection entity should have a tg:edge with a quoted triple."""
        edge_rows = find_triples(chain["triples"]["focus"], TG_EDGE)
        assert len(edge_rows) == 2
        # The object of every tg:edge row must be a quoted triple.
        for row in edge_rows:
            assert row.o.triple is not None, "tg:edge object should be a quoted triple"

    def test_focus_edges_have_reasoning(self, chain):
        """Each edge selection entity should have tg:reasoning."""
        reasoning_rows = find_triples(chain["triples"]["focus"], TG_REASONING)
        assert len(reasoning_rows) == 2
        texts = {row.o.value for row in reasoning_rows}
        assert "Directly relevant to the query" in texts
        assert "Provides the entity label" in texts

    def test_synthesis_has_document_ref(self, chain):
        """The synthesis references its answer document under TG_DOCUMENT."""
        subject = chain["uris"]["synthesis"]
        doc = find_triple(chain["triples"]["synthesis"], TG_DOCUMENT, subject)
        assert doc is not None
        assert doc.o.iri == "urn:doc:answer-1"

    def test_synthesis_has_labels(self, chain):
        """The rdfs:label of the synthesis is the literal "Synthesis"."""
        subject = chain["uris"]["synthesis"]
        label = find_triple(chain["triples"]["synthesis"], RDFS_LABEL, subject)
        assert label is not None
        assert label.o.value == "Synthesis"

View file

@ -500,7 +500,7 @@ class TestQuestionTriples:
def test_question_types(self):
triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z")
assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
assert has_type(triples, self.Q_URI, PROV_ENTITY)
assert has_type(triples, self.Q_URI, TG_QUESTION)
assert has_type(triples, self.Q_URI, TG_GRAPH_RAG_QUESTION)
@ -543,11 +543,11 @@ class TestGroundingTriples:
assert has_type(triples, self.GND_URI, PROV_ENTITY)
assert has_type(triples, self.GND_URI, TG_GROUNDING)
def test_grounding_generated_by_question(self):
def test_grounding_derived_from_question(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI"])
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.GND_URI)
assert gen is not None
assert gen.o.iri == self.Q_URI
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.GND_URI)
assert derived is not None
assert derived.o.iri == self.Q_URI
def test_grounding_concepts(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML", "robots"])
@ -730,7 +730,7 @@ class TestDocRagQuestionTriples:
def test_docrag_question_types(self):
triples = docrag_question_triples(self.Q_URI, "Find info", "2024-01-01T00:00:00Z")
assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
assert has_type(triples, self.Q_URI, PROV_ENTITY)
assert has_type(triples, self.Q_URI, TG_QUESTION)
assert has_type(triples, self.Q_URI, TG_DOC_RAG_QUESTION)

View file

@ -0,0 +1,164 @@
"""
Tests for queue naming and topic mapping.
"""
import pytest
import argparse
from trustgraph.schema.core.topic import queue
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
from trustgraph.base.pulsar_backend import PulsarBackend
class TestQueueFunction:
    """Names produced by queue() for each class and topicspace."""

    def test_flow_default(self):
        """With no class given the name is prefixed 'flow:tg:'."""
        name = queue('text-completion-request')
        assert name == 'flow:tg:text-completion-request'

    def test_request_class(self):
        """cls='request' yields a 'request:' prefix."""
        name = queue('config', cls='request')
        assert name == 'request:tg:config'

    def test_response_class(self):
        """cls='response' yields a 'response:' prefix."""
        name = queue('config', cls='response')
        assert name == 'response:tg:config'

    def test_notify_class(self):
        """cls='notify' yields a 'notify:' prefix."""
        name = queue('config', cls='notify')
        assert name == 'notify:tg:config'

    def test_custom_topicspace(self):
        """The topicspace argument replaces the default 'tg' segment."""
        name = queue('config', cls='request', topicspace='prod')
        assert name == 'request:prod:config'

    def test_default_class_is_flow(self):
        """Omitting cls gives a name starting with 'flow:'."""
        name = queue('something')
        assert name.startswith('flow:')
class TestPulsarMapTopic:
    """Mapping of queue names to Pulsar topic URIs by PulsarBackend.map_topic."""

    @pytest.fixture
    def backend(self):
        """Create a PulsarBackend without connecting."""
        # object.__new__ skips __init__, so no client is constructed.
        return object.__new__(PulsarBackend)

    def test_flow_maps_to_persistent(self, backend):
        """'flow' queues map to a persistent:// topic."""
        topic = backend.map_topic('flow:tg:text-completion-request')
        assert topic == 'persistent://tg/flow/text-completion-request'

    def test_notify_maps_to_non_persistent(self, backend):
        """'notify' queues map to a non-persistent:// topic."""
        topic = backend.map_topic('notify:tg:config')
        assert topic == 'non-persistent://tg/notify/config'

    def test_request_maps_to_non_persistent(self, backend):
        """'request' queues map to a non-persistent:// topic."""
        topic = backend.map_topic('request:tg:config')
        assert topic == 'non-persistent://tg/request/config'

    def test_response_maps_to_non_persistent(self, backend):
        """'response' queues map to a non-persistent:// topic."""
        topic = backend.map_topic('response:tg:librarian')
        assert topic == 'non-persistent://tg/response/librarian'

    def test_passthrough_pulsar_uri(self, backend):
        """A value that is already a Pulsar URI is returned unchanged."""
        uri = 'persistent://tg/flow/something'
        assert backend.map_topic(uri) == uri

    def test_invalid_format_raises(self, backend):
        """A name without the expected separators raises ValueError."""
        with pytest.raises(ValueError, match="Invalid queue format"):
            backend.map_topic('bad-format')

    def test_invalid_class_raises(self, backend):
        """An unrecognised class segment raises ValueError."""
        with pytest.raises(ValueError, match="Invalid queue class"):
            backend.map_topic('unknown:tg:topic')

    def test_custom_topicspace(self, backend):
        """The topicspace segment becomes the first path segment of the topic."""
        topic = backend.map_topic('flow:prod:my-queue')
        assert topic == 'persistent://prod/flow/my-queue'
class TestGetPubsubDispatch:
    """Backend selection performed by get_pubsub."""

    def test_unknown_backend_raises(self):
        """An unrecognised pubsub_backend value raises ValueError."""
        expected_message = "Unknown pub/sub backend"
        with pytest.raises(ValueError, match=expected_message):
            get_pubsub(pubsub_backend='redis')
class TestAddPubsubArgs:
    """Defaults and overrides of the arguments registered by add_pubsub_args."""

    @staticmethod
    def _parse(argv, **options):
        """Register the pub/sub arguments on a fresh parser and parse argv."""
        parser = argparse.ArgumentParser()
        add_pubsub_args(parser, **options)
        return parser.parse_args(argv)

    def test_standalone_defaults_to_localhost(self):
        """standalone=True defaults the Pulsar host and listener to localhost."""
        parsed = self._parse([], standalone=True)
        assert parsed.pulsar_host == 'pulsar://localhost:6650'
        assert parsed.pulsar_listener == 'localhost'

    def test_non_standalone_defaults_to_container(self):
        """standalone=False defaults to the 'pulsar' host with no listener."""
        parsed = self._parse([], standalone=False)
        assert 'pulsar:6650' in parsed.pulsar_host
        assert parsed.pulsar_listener is None

    def test_cli_override_respected(self):
        """An explicit --pulsar-host wins over the standalone default."""
        parsed = self._parse(
            ['--pulsar-host', 'pulsar://custom:6650'], standalone=True
        )
        assert parsed.pulsar_host == 'pulsar://custom:6650'

    def test_pubsub_backend_default(self):
        """The backend defaults to 'pulsar' when not given."""
        parsed = self._parse([])
        assert parsed.pubsub_backend == 'pulsar'
class TestAddPubsubArgsRabbitMQ:
    """RabbitMQ-related arguments registered by add_pubsub_args."""

    def test_rabbitmq_args_present(self):
        """RabbitMQ options are accepted and the port is parsed as an int."""
        cli = [
            '--pubsub-backend', 'rabbitmq',
            '--rabbitmq-host', 'myhost',
            '--rabbitmq-port', '5673',
        ]
        parser = argparse.ArgumentParser()
        add_pubsub_args(parser)
        parsed = parser.parse_args(cli)
        assert parsed.pubsub_backend == 'rabbitmq'
        assert parsed.rabbitmq_host == 'myhost'
        assert parsed.rabbitmq_port == 5673

    def test_rabbitmq_defaults_container(self):
        """Without standalone, the defaults are the 'rabbitmq' host and guest login."""
        parser = argparse.ArgumentParser()
        add_pubsub_args(parser)
        parsed = parser.parse_args([])
        assert parsed.rabbitmq_host == 'rabbitmq'
        assert parsed.rabbitmq_port == 5672
        assert parsed.rabbitmq_username == 'guest'
        assert parsed.rabbitmq_password == 'guest'
        assert parsed.rabbitmq_vhost == '/'

    def test_rabbitmq_standalone_defaults_to_localhost(self):
        """standalone=True switches the default RabbitMQ host to localhost."""
        parser = argparse.ArgumentParser()
        add_pubsub_args(parser, standalone=True)
        parsed = parser.parse_args([])
        assert parsed.rabbitmq_host == 'localhost'
class TestQueueDefinitions:
    """Verify the actual queue constants produce correct names."""

    # Imports are done inside each test, so a failing import only affects
    # the test that needs that module.

    def test_config_request(self):
        """config_request_queue is the 'request' class queue for config."""
        from trustgraph.schema.services.config import config_request_queue
        expected = 'request:tg:config'
        assert config_request_queue == expected

    def test_config_response(self):
        """config_response_queue is the 'response' class queue for config."""
        from trustgraph.schema.services.config import config_response_queue
        expected = 'response:tg:config'
        assert config_response_queue == expected

    def test_config_push(self):
        """config_push_queue is the 'notify' class queue for config."""
        from trustgraph.schema.services.config import config_push_queue
        expected = 'notify:tg:config'
        assert config_push_queue == expected

    def test_librarian_request(self):
        """librarian_request_queue is the 'request' class queue for librarian."""
        from trustgraph.schema.services.library import librarian_request_queue
        expected = 'request:tg:librarian'
        assert librarian_request_queue == expected

    def test_knowledge_request(self):
        """knowledge_request_queue is the 'request' class queue for knowledge."""
        from trustgraph.schema.knowledge.knowledge import knowledge_request_queue
        expected = 'request:tg:knowledge'
        assert knowledge_request_queue == expected

View file

@ -0,0 +1,107 @@
"""
Unit tests for RabbitMQ backend queue name mapping and factory dispatch.
Does not require a running RabbitMQ instance.
"""
import pytest
import argparse
pika = pytest.importorskip("pika", reason="pika not installed")
from trustgraph.base.rabbitmq_backend import RabbitMQBackend
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
class TestRabbitMQMapQueueName:
    """Mapping of queue names to (name, durable) by RabbitMQBackend.map_queue_name."""

    @pytest.fixture
    def backend(self):
        """Create a RabbitMQBackend instance without running __init__."""
        return object.__new__(RabbitMQBackend)

    def test_flow_is_durable(self, backend):
        """'flow' queues are durable and named '<space>.flow.<topic>'."""
        mapped, is_durable = backend.map_queue_name(
            'flow:tg:text-completion-request'
        )
        assert is_durable is True
        assert mapped == 'tg.flow.text-completion-request'

    def test_notify_is_not_durable(self, backend):
        """'notify' queues are not durable."""
        mapped, is_durable = backend.map_queue_name('notify:tg:config')
        assert is_durable is False
        assert mapped == 'tg.notify.config'

    def test_request_is_not_durable(self, backend):
        """'request' queues are not durable."""
        mapped, is_durable = backend.map_queue_name('request:tg:config')
        assert is_durable is False
        assert mapped == 'tg.request.config'

    def test_response_is_not_durable(self, backend):
        """'response' queues are not durable."""
        mapped, is_durable = backend.map_queue_name('response:tg:librarian')
        assert is_durable is False
        assert mapped == 'tg.response.librarian'

    def test_custom_topicspace(self, backend):
        """The topicspace segment becomes the first dotted component."""
        mapped, is_durable = backend.map_queue_name('flow:prod:my-queue')
        assert mapped == 'prod.flow.my-queue'
        assert is_durable is True

    def test_no_colon_defaults_to_flow(self, backend):
        """A name without colons maps to 'tg.<name>' and is not durable.

        NOTE(review): the test name says "defaults to flow", yet the expected
        result has no 'flow' segment and is non-durable, unlike flow queues
        above — confirm which behaviour is intended.
        """
        mapped, is_durable = backend.map_queue_name('simple-queue')
        assert mapped == 'tg.simple-queue'
        assert is_durable is False

    def test_invalid_class_raises(self, backend):
        """An unrecognised class segment raises ValueError."""
        with pytest.raises(ValueError, match="Invalid queue class"):
            backend.map_queue_name('unknown:tg:topic')

    def test_flow_with_flow_suffix(self, backend):
        """Queue names with flow suffix (e.g. :default) are preserved."""
        mapped, _ = backend.map_queue_name('request:tg:prompt:default')
        assert mapped == 'tg.request.prompt:default'
class TestGetPubsubRabbitMQ:
    """get_pubsub dispatch when pubsub_backend is 'rabbitmq'."""

    def test_factory_creates_rabbitmq_backend(self):
        created = get_pubsub(pubsub_backend='rabbitmq')
        assert isinstance(created, RabbitMQBackend)

    def test_factory_passes_config(self):
        settings = {
            'pubsub_backend': 'rabbitmq',
            'rabbitmq_host': 'myhost',
            'rabbitmq_port': 5673,
            'rabbitmq_username': 'user',
            'rabbitmq_password': 'pass',
            'rabbitmq_vhost': '/test',
        }
        created = get_pubsub(**settings)
        assert isinstance(created, RabbitMQBackend)

        # The keyword arguments should be reflected in the stored
        # connection parameters.
        conn = created._connection_params
        assert conn.host == 'myhost'
        assert conn.port == 5673
        assert conn.virtual_host == '/test'
class TestAddPubsubArgsRabbitMQ:
    """RabbitMQ command-line options registered by add_pubsub_args."""

    def test_rabbitmq_args_present(self):
        cli = argparse.ArgumentParser()
        add_pubsub_args(cli)
        argv = [
            '--pubsub-backend', 'rabbitmq',
            '--rabbitmq-host', 'myhost',
            '--rabbitmq-port', '5673',
        ]
        opts = cli.parse_args(argv)
        assert opts.pubsub_backend == 'rabbitmq'
        assert opts.rabbitmq_host == 'myhost'
        assert opts.rabbitmq_port == 5673

    def test_rabbitmq_defaults_container(self):
        cli = argparse.ArgumentParser()
        add_pubsub_args(cli)
        opts = cli.parse_args([])
        # Defaults when no arguments are supplied and standalone is not set.
        expected_defaults = (
            ('rabbitmq_host', 'rabbitmq'),
            ('rabbitmq_port', 5672),
            ('rabbitmq_username', 'guest'),
            ('rabbitmq_password', 'guest'),
            ('rabbitmq_vhost', '/'),
        )
        for attribute, wanted in expected_defaults:
            assert getattr(opts, attribute) == wanted

View file

@ -0,0 +1,424 @@
"""
Tests for SPARQL FILTER expression evaluator.
"""
import pytest
from trustgraph.schema import Term, IRI, LITERAL, BLANK
from trustgraph.query.sparql.expressions import (
evaluate_expression, _effective_boolean, _to_string, _to_numeric,
_comparable_value,
)
# --- Helpers ---
def iri(v):
    """Build a Term of type IRI whose ``iri`` field is ``v``."""
    term = Term(type=IRI, iri=v)
    return term
def lit(v, datatype="", language=""):
    """Build a Term of type LITERAL with the given value, datatype and language."""
    term = Term(
        type=LITERAL,
        value=v,
        datatype=datatype,
        language=language,
    )
    return term
def blank(v):
    """Build a Term of type BLANK whose ``id`` field is ``v``."""
    term = Term(type=BLANK, id=v)
    return term
# XML Schema namespace prefix; tests append a local name such as
# "integer" or "boolean" to form a datatype IRI.
XSD = "http://www.w3.org/2001/XMLSchema#"
class TestEvaluateExpression:
    """Evaluate leaf nodes: variables, rdflib constants and plain Python values."""

    def test_variable_bound(self):
        from rdflib.term import Variable
        target = Variable("x")
        bindings = {"x": lit("hello")}
        value = evaluate_expression(target, bindings)
        assert value.value == "hello"

    def test_variable_unbound(self):
        from rdflib.term import Variable
        target = Variable("x")
        value = evaluate_expression(target, {})
        assert value is None

    def test_uriref_constant(self):
        from rdflib import URIRef
        constant = URIRef("http://example.com/a")
        value = evaluate_expression(constant, {})
        assert value.type == IRI
        assert value.iri == "http://example.com/a"

    def test_literal_constant(self):
        from rdflib import Literal
        constant = Literal("hello")
        value = evaluate_expression(constant, {})
        assert value.type == LITERAL
        assert value.value == "hello"

    def test_boolean_constant(self):
        # Python booleans are returned as the very same singleton.
        for flag in (True, False):
            assert evaluate_expression(flag, {}) is flag

    def test_numeric_constant(self):
        for number in (42, 3.14):
            assert evaluate_expression(number, {}) == number

    def test_none_returns_true(self):
        value = evaluate_expression(None, {})
        assert value is True
class TestRelationalExpressions:
    """Comparison operators expressed as RelationalExpression CompValue nodes."""

    def _make_relational(self, left, op, right):
        """Build a RelationalExpression node for ``left <op> right``."""
        from rdflib.plugins.sparql.parserutils import CompValue
        fields = {"expr": left, "op": op, "other": right}
        return CompValue("RelationalExpression", **fields)

    def test_equal_literals(self):
        from rdflib import Literal
        node = self._make_relational(Literal("a"), "=", Literal("a"))
        outcome = evaluate_expression(node, {})
        assert outcome is True

    def test_not_equal_literals(self):
        from rdflib import Literal
        node = self._make_relational(Literal("a"), "!=", Literal("b"))
        outcome = evaluate_expression(node, {})
        assert outcome is True

    def test_less_than(self):
        from rdflib import Literal
        node = self._make_relational(Literal("a"), "<", Literal("b"))
        outcome = evaluate_expression(node, {})
        assert outcome is True

    def test_greater_than(self):
        from rdflib import Literal
        node = self._make_relational(Literal("b"), ">", Literal("a"))
        outcome = evaluate_expression(node, {})
        assert outcome is True

    def test_equal_with_variables(self):
        from rdflib.term import Variable
        node = self._make_relational(Variable("x"), "=", Variable("y"))
        bindings = {"x": lit("same"), "y": lit("same")}
        outcome = evaluate_expression(node, bindings)
        assert outcome is True

    def test_unequal_with_variables(self):
        from rdflib.term import Variable
        node = self._make_relational(Variable("x"), "=", Variable("y"))
        bindings = {"x": lit("one"), "y": lit("two")}
        outcome = evaluate_expression(node, bindings)
        assert outcome is False

    def test_none_operand_returns_false(self):
        from rdflib.term import Variable
        from rdflib import Literal
        # ?x is unbound in the empty solution, so one operand is None.
        node = self._make_relational(Variable("x"), "=", Literal("a"))
        outcome = evaluate_expression(node, {})
        assert outcome is False
class TestLogicalExpressions:
    """AND / OR / NOT nodes evaluated over Python boolean operands."""

    def _make_and(self, exprs):
        """Build a ConditionalAndExpression: first operand plus the rest."""
        from rdflib.plugins.sparql.parserutils import CompValue
        head, tail = exprs[0], exprs[1:]
        return CompValue("ConditionalAndExpression", expr=head, other=tail)

    def _make_or(self, exprs):
        """Build a ConditionalOrExpression: first operand plus the rest."""
        from rdflib.plugins.sparql.parserutils import CompValue
        head, tail = exprs[0], exprs[1:]
        return CompValue("ConditionalOrExpression", expr=head, other=tail)

    def _make_not(self, expr):
        """Build a UnaryNot node wrapping ``expr``."""
        from rdflib.plugins.sparql.parserutils import CompValue
        return CompValue("UnaryNot", expr=expr)

    def test_and_true_true(self):
        node = self._make_and([True, True])
        assert evaluate_expression(node, {}) is True

    def test_and_true_false(self):
        node = self._make_and([True, False])
        assert evaluate_expression(node, {}) is False

    def test_or_false_true(self):
        node = self._make_or([False, True])
        assert evaluate_expression(node, {}) is True

    def test_or_false_false(self):
        node = self._make_or([False, False])
        assert evaluate_expression(node, {}) is False

    def test_not_true(self):
        node = self._make_not(True)
        assert evaluate_expression(node, {}) is False

    def test_not_false(self):
        node = self._make_not(False)
        assert evaluate_expression(node, {}) is True
class TestBuiltinFunctions:
    """Evaluation of ``Builtin_*`` CompValue nodes (BOUND, STR, REGEX, ...)."""

    def _make_builtin(self, name, **kwargs):
        """Return a CompValue named ``Builtin_<name>`` carrying ``kwargs``."""
        from rdflib.plugins.sparql.parserutils import CompValue
        node_name = f"Builtin_{name}"
        return CompValue(node_name, **kwargs)

    # --- BOUND ---

    def test_bound_true(self):
        from rdflib.term import Variable
        call = self._make_builtin("BOUND", arg=Variable("x"))
        bindings = {"x": lit("hi")}
        assert evaluate_expression(call, bindings) is True

    def test_bound_false(self):
        from rdflib.term import Variable
        call = self._make_builtin("BOUND", arg=Variable("x"))
        assert evaluate_expression(call, {}) is False

    # --- term-type predicates ---

    def test_isiri_true(self):
        from rdflib.term import Variable
        call = self._make_builtin("isIRI", arg=Variable("x"))
        bindings = {"x": iri("http://x")}
        assert evaluate_expression(call, bindings) is True

    def test_isiri_false(self):
        from rdflib.term import Variable
        call = self._make_builtin("isIRI", arg=Variable("x"))
        bindings = {"x": lit("hello")}
        assert evaluate_expression(call, bindings) is False

    def test_isliteral_true(self):
        from rdflib.term import Variable
        call = self._make_builtin("isLITERAL", arg=Variable("x"))
        bindings = {"x": lit("hello")}
        assert evaluate_expression(call, bindings) is True

    def test_isliteral_false(self):
        from rdflib.term import Variable
        call = self._make_builtin("isLITERAL", arg=Variable("x"))
        bindings = {"x": iri("http://x")}
        assert evaluate_expression(call, bindings) is False

    def test_isblank_true(self):
        from rdflib.term import Variable
        call = self._make_builtin("isBLANK", arg=Variable("x"))
        bindings = {"x": blank("b1")}
        assert evaluate_expression(call, bindings) is True

    def test_isblank_false(self):
        from rdflib.term import Variable
        call = self._make_builtin("isBLANK", arg=Variable("x"))
        bindings = {"x": iri("http://x")}
        assert evaluate_expression(call, bindings) is False

    # --- accessors ---

    def test_str(self):
        from rdflib.term import Variable
        call = self._make_builtin("STR", arg=Variable("x"))
        bindings = {"x": iri("http://example.com/a")}
        out = evaluate_expression(call, bindings)
        assert out.type == LITERAL
        assert out.value == "http://example.com/a"

    def test_lang(self):
        from rdflib.term import Variable
        call = self._make_builtin("LANG", arg=Variable("x"))
        bindings = {"x": lit("hello", language="en")}
        out = evaluate_expression(call, bindings)
        assert out.value == "en"

    def test_lang_no_tag(self):
        from rdflib.term import Variable
        call = self._make_builtin("LANG", arg=Variable("x"))
        bindings = {"x": lit("hello")}
        out = evaluate_expression(call, bindings)
        assert out.value == ""

    def test_datatype(self):
        from rdflib.term import Variable
        call = self._make_builtin("DATATYPE", arg=Variable("x"))
        bindings = {"x": lit("42", datatype=XSD + "integer")}
        out = evaluate_expression(call, bindings)
        assert out.type == IRI
        assert out.iri == XSD + "integer"

    # --- string functions ---

    def test_strlen(self):
        from rdflib.term import Variable
        call = self._make_builtin("STRLEN", arg=Variable("x"))
        bindings = {"x": lit("hello")}
        out = evaluate_expression(call, bindings)
        assert out == 5

    def test_ucase(self):
        from rdflib.term import Variable
        call = self._make_builtin("UCASE", arg=Variable("x"))
        bindings = {"x": lit("hello")}
        out = evaluate_expression(call, bindings)
        assert out.value == "HELLO"

    def test_lcase(self):
        from rdflib.term import Variable
        call = self._make_builtin("LCASE", arg=Variable("x"))
        bindings = {"x": lit("HELLO")}
        out = evaluate_expression(call, bindings)
        assert out.value == "hello"

    def test_contains_true(self):
        from rdflib.term import Variable
        from rdflib import Literal
        call = self._make_builtin(
            "CONTAINS", arg1=Variable("x"), arg2=Literal("ell")
        )
        bindings = {"x": lit("hello")}
        assert evaluate_expression(call, bindings) is True

    def test_contains_false(self):
        from rdflib.term import Variable
        from rdflib import Literal
        call = self._make_builtin(
            "CONTAINS", arg1=Variable("x"), arg2=Literal("xyz")
        )
        bindings = {"x": lit("hello")}
        assert evaluate_expression(call, bindings) is False

    def test_strstarts_true(self):
        from rdflib.term import Variable
        from rdflib import Literal
        call = self._make_builtin(
            "STRSTARTS", arg1=Variable("x"), arg2=Literal("hel")
        )
        bindings = {"x": lit("hello")}
        assert evaluate_expression(call, bindings) is True

    def test_strends_true(self):
        from rdflib.term import Variable
        from rdflib import Literal
        call = self._make_builtin(
            "STRENDS", arg1=Variable("x"), arg2=Literal("llo")
        )
        bindings = {"x": lit("hello")}
        assert evaluate_expression(call, bindings) is True

    # --- REGEX ---

    def test_regex_match(self):
        from rdflib.term import Variable
        from rdflib import Literal
        call = self._make_builtin(
            "REGEX",
            text=Variable("x"),
            pattern=Literal("^hel"),
            flags=None,
        )
        bindings = {"x": lit("hello")}
        assert evaluate_expression(call, bindings) is True

    def test_regex_case_insensitive(self):
        from rdflib.term import Variable
        from rdflib import Literal
        call = self._make_builtin(
            "REGEX",
            text=Variable("x"),
            pattern=Literal("HELLO"),
            flags=Literal("i"),
        )
        bindings = {"x": lit("hello")}
        assert evaluate_expression(call, bindings) is True

    def test_regex_no_match(self):
        from rdflib.term import Variable
        from rdflib import Literal
        call = self._make_builtin(
            "REGEX",
            text=Variable("x"),
            pattern=Literal("^world"),
            flags=None,
        )
        bindings = {"x": lit("hello")}
        assert evaluate_expression(call, bindings) is False
class TestEffectiveBoolean:
    """_effective_boolean over Python scalars and Term instances."""

    def test_true(self):
        outcome = _effective_boolean(True)
        assert outcome is True

    def test_false(self):
        outcome = _effective_boolean(False)
        assert outcome is False

    def test_none(self):
        outcome = _effective_boolean(None)
        assert outcome is False

    def test_nonzero_int(self):
        outcome = _effective_boolean(42)
        assert outcome is True

    def test_zero_int(self):
        outcome = _effective_boolean(0)
        assert outcome is False

    def test_nonempty_string(self):
        outcome = _effective_boolean("hello")
        assert outcome is True

    def test_empty_string(self):
        outcome = _effective_boolean("")
        assert outcome is False

    def test_iri_term(self):
        term = iri("http://x")
        assert _effective_boolean(term) is True

    def test_nonempty_literal(self):
        term = lit("hello")
        assert _effective_boolean(term) is True

    def test_empty_literal(self):
        term = lit("")
        assert _effective_boolean(term) is False

    def test_boolean_literal_true(self):
        term = lit("true", datatype=XSD + "boolean")
        assert _effective_boolean(term) is True

    def test_boolean_literal_false(self):
        term = lit("false", datatype=XSD + "boolean")
        assert _effective_boolean(term) is False

    def test_numeric_literal_nonzero(self):
        term = lit("42", datatype=XSD + "integer")
        assert _effective_boolean(term) is True

    def test_numeric_literal_zero(self):
        term = lit("0", datatype=XSD + "integer")
        assert _effective_boolean(term) is False
class TestToString:
    """_to_string over None, plain strings and each Term kind."""

    def test_none(self):
        text = _to_string(None)
        assert text == ""

    def test_string(self):
        text = _to_string("hello")
        assert text == "hello"

    def test_iri_term(self):
        text = _to_string(iri("http://example.com"))
        assert text == "http://example.com"

    def test_literal_term(self):
        text = _to_string(lit("hello"))
        assert text == "hello"

    def test_blank_term(self):
        text = _to_string(blank("b1"))
        assert text == "b1"
class TestToNumeric:
    """_to_numeric over None, numbers, literals and plain strings."""

    def test_none(self):
        number = _to_numeric(None)
        assert number is None

    def test_int(self):
        number = _to_numeric(42)
        assert number == 42

    def test_float(self):
        number = _to_numeric(3.14)
        assert number == 3.14

    def test_integer_literal(self):
        number = _to_numeric(lit("42"))
        assert number == 42

    def test_decimal_literal(self):
        number = _to_numeric(lit("3.14"))
        assert number == 3.14

    def test_non_numeric_literal(self):
        number = _to_numeric(lit("hello"))
        assert number is None

    def test_numeric_string(self):
        number = _to_numeric("42")
        assert number == 42

    def test_non_numeric_string(self):
        number = _to_numeric("abc")
        assert number is None
class TestComparableValue:
    """_comparable_value sort keys, checked as (rank, value) tuples."""

    def test_none(self):
        key = _comparable_value(None)
        assert key == (0, "")

    def test_int(self):
        key = _comparable_value(42)
        assert key == (2, 42)

    def test_iri(self):
        key = _comparable_value(iri("http://x"))
        assert key == (4, "http://x")

    def test_literal(self):
        key = _comparable_value(lit("hello"))
        assert key == (3, "hello")

    def test_numeric_literal(self):
        key = _comparable_value(lit("42"))
        assert key == (2, 42)

    def test_ordering(self):
        unsorted_terms = [lit("b"), lit("a"), lit("c")]
        ordered = sorted(unsorted_terms, key=_comparable_value)
        for position, wanted in enumerate(["a", "b", "c"]):
            assert ordered[position].value == wanted

View file

@ -0,0 +1,205 @@
"""
Tests for the SPARQL parser module.
"""
import pytest
from trustgraph.query.sparql.parser import (
parse_sparql, ParseError, rdflib_term_to_term, term_to_rdflib,
)
from trustgraph.schema import Term, IRI, LITERAL, BLANK
class TestParseSparql:
    """parse_sparql across query forms, solution modifiers and bad input."""

    def test_select_query_type(self):
        text = "SELECT ?s ?p ?o WHERE { ?s ?p ?o }"
        query = parse_sparql(text)
        assert query.query_type == "select"

    def test_select_variables(self):
        text = "SELECT ?s ?p ?o WHERE { ?s ?p ?o }"
        query = parse_sparql(text)
        assert query.variables == ["s", "p", "o"]

    def test_select_subset_variables(self):
        text = "SELECT ?s ?o WHERE { ?s ?p ?o }"
        query = parse_sparql(text)
        assert query.variables == ["s", "o"]

    def test_ask_query_type(self):
        text = "ASK { <http://example.com/a> ?p ?o }"
        query = parse_sparql(text)
        assert query.query_type == "ask"

    def test_ask_no_variables(self):
        text = "ASK { <http://example.com/a> ?p ?o }"
        query = parse_sparql(text)
        assert query.variables == []

    def test_construct_query_type(self):
        text = (
            "CONSTRUCT { ?s <http://example.com/knows> ?o } "
            "WHERE { ?s <http://example.com/friendOf> ?o }"
        )
        query = parse_sparql(text)
        assert query.query_type == "construct"

    def test_describe_query_type(self):
        text = "DESCRIBE <http://example.com/alice>"
        query = parse_sparql(text)
        assert query.query_type == "describe"

    def test_select_with_limit(self):
        text = "SELECT ?s WHERE { ?s ?p ?o } LIMIT 10"
        query = parse_sparql(text)
        assert query.query_type == "select"
        assert query.variables == ["s"]

    def test_select_with_distinct(self):
        text = "SELECT DISTINCT ?s WHERE { ?s ?p ?o }"
        query = parse_sparql(text)
        assert query.query_type == "select"
        assert query.variables == ["s"]

    def test_select_with_filter(self):
        text = (
            'SELECT ?s ?label WHERE { '
            ' ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label . '
            ' FILTER(CONTAINS(STR(?label), "test")) '
            '}'
        )
        query = parse_sparql(text)
        assert query.query_type == "select"
        assert query.variables == ["s", "label"]

    def test_select_with_optional(self):
        text = (
            "SELECT ?s ?p ?o ?label WHERE { "
            " ?s ?p ?o . "
            " OPTIONAL { ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label } "
            "}"
        )
        query = parse_sparql(text)
        assert query.query_type == "select"
        # Only membership is checked here, not the variable order.
        assert set(query.variables) == {"s", "p", "o", "label"}

    def test_select_with_union(self):
        text = (
            "SELECT ?s ?label WHERE { "
            " { ?s <http://example.com/name> ?label } "
            " UNION "
            " { ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label } "
            "}"
        )
        query = parse_sparql(text)
        assert query.query_type == "select"

    def test_select_with_order_by(self):
        text = (
            "SELECT ?s ?label WHERE { ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label } "
            "ORDER BY ?label"
        )
        query = parse_sparql(text)
        assert query.query_type == "select"

    def test_select_with_group_by(self):
        text = (
            "SELECT ?p (COUNT(?o) AS ?count) WHERE { ?s ?p ?o } "
            "GROUP BY ?p ORDER BY DESC(?count)"
        )
        query = parse_sparql(text)
        assert query.query_type == "select"

    def test_select_with_prefixes(self):
        text = (
            "PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> "
            "SELECT ?s ?label WHERE { ?s rdfs:label ?label }"
        )
        query = parse_sparql(text)
        assert query.query_type == "select"
        assert query.variables == ["s", "label"]

    def test_algebra_not_none(self):
        text = "SELECT ?s WHERE { ?s ?p ?o }"
        query = parse_sparql(text)
        assert query.algebra is not None

    def test_parse_error_invalid_sparql(self):
        bad_text = "NOT VALID SPARQL AT ALL"
        with pytest.raises(ParseError):
            parse_sparql(bad_text)

    def test_parse_error_incomplete_query(self):
        bad_text = "SELECT ?s WHERE {"
        with pytest.raises(ParseError):
            parse_sparql(bad_text)

    def test_parse_error_message(self):
        bad_text = "GIBBERISH"
        with pytest.raises(ParseError, match="SPARQL parse error"):
            parse_sparql(bad_text)
class TestRdflibTermToTerm:
    """rdflib_term_to_term: rdflib nodes converted into schema Terms."""

    def test_uriref_to_term(self):
        from rdflib import URIRef
        source = URIRef("http://example.com/alice")
        converted = rdflib_term_to_term(source)
        assert converted.type == IRI
        assert converted.iri == "http://example.com/alice"

    def test_literal_to_term(self):
        from rdflib import Literal
        source = Literal("hello")
        converted = rdflib_term_to_term(source)
        assert converted.type == LITERAL
        assert converted.value == "hello"

    def test_typed_literal_to_term(self):
        from rdflib import Literal, URIRef
        source = Literal(
            "42",
            datatype=URIRef("http://www.w3.org/2001/XMLSchema#integer"),
        )
        converted = rdflib_term_to_term(source)
        assert converted.type == LITERAL
        assert converted.value == "42"
        assert converted.datatype == "http://www.w3.org/2001/XMLSchema#integer"

    def test_lang_literal_to_term(self):
        from rdflib import Literal
        source = Literal("hello", lang="en")
        converted = rdflib_term_to_term(source)
        assert converted.type == LITERAL
        assert converted.value == "hello"
        assert converted.language == "en"

    def test_bnode_to_term(self):
        from rdflib import BNode
        source = BNode("b1")
        converted = rdflib_term_to_term(source)
        assert converted.type == BLANK
        assert converted.id == "b1"
class TestTermToRdflib:
    """Tests for Term-to-rdflib conversion."""

    def test_iri_term_to_uriref(self):
        from rdflib import URIRef
        result = term_to_rdflib(Term(type=IRI, iri="http://example.com/x"))
        assert isinstance(result, URIRef)
        assert str(result) == "http://example.com/x"

    def test_literal_term_to_literal(self):
        from rdflib import Literal
        result = term_to_rdflib(Term(type=LITERAL, value="hello"))
        assert isinstance(result, Literal)
        assert str(result) == "hello"

    def test_typed_literal_roundtrip(self):
        """A typed literal survives Term -> rdflib -> Term with its datatype."""
        from rdflib import URIRef
        xsd_integer = "http://www.w3.org/2001/XMLSchema#integer"
        original = Term(type=LITERAL, value="42", datatype=xsd_integer)
        rdflib_term = term_to_rdflib(original)
        assert rdflib_term.datatype == URIRef(xsd_integer)
        # Convert back so the test really exercises both directions.
        restored = rdflib_term_to_term(rdflib_term)
        assert restored.type == LITERAL
        assert restored.value == "42"
        assert restored.datatype == xsd_integer

    def test_lang_literal_roundtrip(self):
        """A language-tagged literal survives Term -> rdflib -> Term."""
        original = Term(type=LITERAL, value="bonjour", language="fr")
        rdflib_term = term_to_rdflib(original)
        assert rdflib_term.language == "fr"
        # Convert back so the test really exercises both directions.
        restored = rdflib_term_to_term(rdflib_term)
        assert restored.type == LITERAL
        assert restored.value == "bonjour"
        assert restored.language == "fr"

    def test_blank_term_to_bnode(self):
        from rdflib import BNode
        result = term_to_rdflib(Term(type=BLANK, id="b1"))
        assert isinstance(result, BNode)

View file

@ -0,0 +1,345 @@
"""
Tests for SPARQL solution sequence operations.
"""
import pytest
from trustgraph.schema import Term, IRI, LITERAL
from trustgraph.query.sparql.solutions import (
hash_join, left_join, union, project, distinct,
order_by, slice_solutions, _terms_equal, _compatible,
)
# --- Test helpers ---
def iri(v):
    """Build a Term of type IRI whose ``iri`` field is ``v``."""
    term = Term(type=IRI, iri=v)
    return term
def lit(v):
    """Build a Term of type LITERAL whose ``value`` field is ``v``."""
    term = Term(type=LITERAL, value=v)
    return term
# --- Fixtures ---
@pytest.fixture
def alice():
    """IRI term for http://example.com/alice."""
    term = iri("http://example.com/alice")
    return term
@pytest.fixture
def bob():
    """IRI term for http://example.com/bob."""
    term = iri("http://example.com/bob")
    return term
@pytest.fixture
def carol():
    """IRI term for http://example.com/carol."""
    term = iri("http://example.com/carol")
    return term
@pytest.fixture
def knows():
    """IRI term for the http://example.com/knows predicate."""
    term = iri("http://example.com/knows")
    return term
@pytest.fixture
def name_alice():
    """Literal term with the value "Alice"."""
    term = lit("Alice")
    return term
@pytest.fixture
def name_bob():
    """Literal term with the value "Bob"."""
    term = lit("Bob")
    return term
class TestTermsEqual:
    """_terms_equal over IRIs, literals, mixed kinds and None."""

    def test_equal_iris(self):
        first, second = iri("http://x.com/a"), iri("http://x.com/a")
        assert _terms_equal(first, second)

    def test_unequal_iris(self):
        first, second = iri("http://x.com/a"), iri("http://x.com/b")
        assert not _terms_equal(first, second)

    def test_equal_literals(self):
        first, second = lit("hello"), lit("hello")
        assert _terms_equal(first, second)

    def test_unequal_literals(self):
        first, second = lit("hello"), lit("world")
        assert not _terms_equal(first, second)

    def test_iri_vs_literal(self):
        # Same text, different term kinds.
        first, second = iri("hello"), lit("hello")
        assert not _terms_equal(first, second)

    def test_none_none(self):
        assert _terms_equal(None, None)

    def test_none_vs_term(self):
        other = iri("http://x.com/a")
        assert not _terms_equal(None, other)
class TestCompatible:
    """_compatible: whether two solutions agree on the variables they share."""

    def test_no_shared_variables(self):
        first = {"a": iri("http://x")}
        second = {"b": iri("http://y")}
        assert _compatible(first, second)

    def test_shared_variable_same_value(self, alice):
        first = {"s": alice, "x": lit("1")}
        second = {"s": alice, "y": lit("2")}
        assert _compatible(first, second)

    def test_shared_variable_different_value(self, alice, bob):
        first = {"s": alice}
        second = {"s": bob}
        assert not _compatible(first, second)

    def test_empty_solutions(self):
        assert _compatible({}, {})

    def test_empty_vs_nonempty(self, alice):
        assert _compatible({}, {"s": alice})
class TestHashJoin:
    """Tests for hash_join over lists of solution dicts."""

    def test_join_on_shared_variable(
        self, alice, bob, knows, name_alice, name_bob
    ):
        # Use the shared ``knows`` fixture instead of re-typing the IRI.
        left = [
            {"s": alice, "p": knows, "o": bob},
            {"s": bob, "p": knows, "o": alice},
        ]
        right = [
            {"s": alice, "label": name_alice},
            {"s": bob, "label": name_bob},
        ]
        result = hash_join(left, right)
        assert len(result) == 2
        # Check that joined solutions have all variables
        for sol in result:
            assert "s" in sol
            assert "p" in sol
            assert "o" in sol
            assert "label" in sol

    def test_join_no_shared_variables_cross_product(self, alice, bob):
        left = [{"a": alice}]
        right = [{"b": bob}, {"b": alice}]
        result = hash_join(left, right)
        assert len(result) == 2

    def test_join_no_matches(self, alice, bob):
        left = [{"s": alice}]
        right = [{"s": bob}]
        result = hash_join(left, right)
        assert len(result) == 0

    def test_join_empty_left(self, alice):
        result = hash_join([], [{"s": alice}])
        assert len(result) == 0

    def test_join_empty_right(self, alice):
        result = hash_join([{"s": alice}], [])
        assert len(result) == 0

    def test_join_multiple_matches(self, alice, name_alice):
        left = [
            {"s": alice, "p": iri("http://e.com/a")},
            {"s": alice, "p": iri("http://e.com/b")},
        ]
        right = [{"s": alice, "label": name_alice}]
        result = hash_join(left, right)
        assert len(result) == 2

    def test_join_preserves_values(self, alice):
        # Only ``alice`` is needed; the unused ``name_alice`` fixture
        # request was dropped.
        left = [{"s": alice, "x": lit("1")}]
        right = [{"s": alice, "y": lit("2")}]
        result = hash_join(left, right)
        assert len(result) == 1
        assert result[0]["x"].value == "1"
        assert result[0]["y"].value == "2"
class TestLeftJoin:
    """left_join: left rows are kept even when nothing on the right matches."""

    def test_left_join_with_matches(self, alice, bob, name_alice):
        left_rows = [{"s": alice}, {"s": bob}]
        right_rows = [{"s": alice, "label": name_alice}]
        joined = left_join(left_rows, right_rows)
        assert len(joined) == 2

        # Alice has label
        for_alice = [
            row for row in joined
            if row["s"].iri == "http://example.com/alice"
        ]
        assert len(for_alice) == 1
        assert "label" in for_alice[0]

        # Bob preserved without label
        for_bob = [
            row for row in joined
            if row["s"].iri == "http://example.com/bob"
        ]
        assert len(for_bob) == 1
        assert "label" not in for_bob[0]

    def test_left_join_no_matches(self, alice, bob):
        left_rows = [{"s": alice}]
        right_rows = [{"s": bob, "label": lit("Bob")}]
        joined = left_join(left_rows, right_rows)
        assert len(joined) == 1
        only_row = joined[0]
        assert only_row["s"].iri == "http://example.com/alice"
        assert "label" not in only_row

    def test_left_join_empty_right(self, alice):
        left_rows = [{"s": alice}]
        joined = left_join(left_rows, [])
        assert len(joined) == 1

    def test_left_join_empty_left(self):
        right_rows = [{"s": iri("http://x")}]
        joined = left_join([], right_rows)
        assert len(joined) == 0

    def test_left_join_with_filter(self, alice, bob):
        left_rows = [{"s": alice}, {"s": bob}]
        right_rows = [
            {"s": alice, "val": lit("yes")},
            {"s": bob, "val": lit("no")},
        ]

        # Filter: only keep joins where val == "yes"
        def val_is_yes(sol):
            return sol.get("val") and sol["val"].value == "yes"

        joined = left_join(left_rows, right_rows, filter_fn=val_is_yes)
        assert len(joined) == 2

        # Alice matches filter
        for_alice = [
            row for row in joined
            if row["s"].iri == "http://example.com/alice"
        ]
        assert "val" in for_alice[0]
        assert for_alice[0]["val"].value == "yes"

        # Bob doesn't match filter, preserved without val
        for_bob = [
            row for row in joined
            if row["s"].iri == "http://example.com/bob"
        ]
        assert "val" not in for_bob[0]
class TestUnion:
    """union: left solutions followed by right solutions, duplicates kept."""

    def test_union_concatenates(self, alice, bob):
        combined = union([{"s": alice}], [{"s": bob}])
        assert len(combined) == 2

    def test_union_preserves_order(self, alice, bob):
        first_part = [{"s": alice}]
        second_part = [{"s": bob}]
        combined = union(first_part, second_part)
        assert combined[0]["s"].iri == "http://example.com/alice"
        assert combined[1]["s"].iri == "http://example.com/bob"

    def test_union_empty_left(self, alice):
        combined = union([], [{"s": alice}])
        assert len(combined) == 1

    def test_union_both_empty(self):
        combined = union([], [])
        assert len(combined) == 0

    def test_union_allows_duplicates(self, alice):
        first_part = [{"s": alice}]
        second_part = [{"s": alice}]
        combined = union(first_part, second_part)
        assert len(combined) == 2
class TestProject:
    """project: restrict each solution to the requested variables."""

    def test_project_keeps_selected(self, alice, name_alice):
        rows = [{"s": alice, "label": name_alice, "extra": lit("x")}]
        projected = project(rows, ["s", "label"])
        assert len(projected) == 1
        row = projected[0]
        assert "s" in row
        assert "label" in row
        assert "extra" not in row

    def test_project_missing_variable(self, alice):
        rows = [{"s": alice}]
        projected = project(rows, ["s", "missing"])
        assert len(projected) == 1
        row = projected[0]
        assert "s" in row
        # A requested variable that is unbound is simply absent.
        assert "missing" not in row

    def test_project_empty(self):
        projected = project([], ["s"])
        assert len(projected) == 0
class TestDistinct:
    """distinct: duplicate solutions collapse to a single entry."""

    def test_removes_duplicates(self, alice):
        rows = [{"s": alice}, {"s": alice}, {"s": alice}]
        unique_rows = distinct(rows)
        assert len(unique_rows) == 1

    def test_keeps_different(self, alice, bob):
        rows = [{"s": alice}, {"s": bob}]
        unique_rows = distinct(rows)
        assert len(unique_rows) == 2

    def test_empty(self):
        unique_rows = distinct([])
        assert len(unique_rows) == 0

    def test_multi_variable_distinct(self, alice, bob):
        # Two identical rows plus one that differs in ?o.
        rows = [
            {"s": alice, "o": bob},
            {"s": alice, "o": bob},
            {"s": alice, "o": alice},
        ]
        unique_rows = distinct(rows)
        assert len(unique_rows) == 2
class TestOrderBy:
    """order_by with (key function, ascending flag) pairs."""

    def test_order_by_ascending(self):
        rows = [
            {"label": lit("Charlie")},
            {"label": lit("Alice")},
            {"label": lit("Bob")},
        ]
        sort_keys = [(lambda sol: sol.get("label"), True)]
        ordered = order_by(rows, sort_keys)
        for position, wanted in enumerate(["Alice", "Bob", "Charlie"]):
            assert ordered[position]["label"].value == wanted

    def test_order_by_descending(self):
        rows = [
            {"label": lit("Alice")},
            {"label": lit("Charlie")},
            {"label": lit("Bob")},
        ]
        sort_keys = [(lambda sol: sol.get("label"), False)]
        ordered = order_by(rows, sort_keys)
        for position, wanted in enumerate(["Charlie", "Bob", "Alice"]):
            assert ordered[position]["label"].value == wanted

    def test_order_by_empty(self):
        sort_keys = [(lambda sol: sol.get("x"), True)]
        ordered = order_by([], sort_keys)
        assert len(ordered) == 0

    def test_order_by_no_keys(self, alice):
        rows = [{"s": alice}]
        ordered = order_by(rows, [])
        assert len(ordered) == 1
class TestSlice:
    """slice_solutions with limit and/or offset."""

    def test_limit(self, alice, bob, carol):
        rows = [{"s": alice}, {"s": bob}, {"s": carol}]
        page = slice_solutions(rows, limit=2)
        assert len(page) == 2

    def test_offset(self, alice, bob, carol):
        rows = [{"s": alice}, {"s": bob}, {"s": carol}]
        page = slice_solutions(rows, offset=1)
        assert len(page) == 2
        assert page[0]["s"].iri == "http://example.com/bob"

    def test_offset_and_limit(self, alice, bob, carol):
        rows = [{"s": alice}, {"s": bob}, {"s": carol}]
        page = slice_solutions(rows, offset=1, limit=1)
        assert len(page) == 1
        assert page[0]["s"].iri == "http://example.com/bob"

    def test_limit_zero(self, alice):
        rows = [{"s": alice}]
        page = slice_solutions(rows, limit=0)
        assert len(page) == 0

    def test_offset_beyond_length(self, alice):
        rows = [{"s": alice}]
        page = slice_solutions(rows, offset=10)
        assert len(page) == 0

    def test_no_slice(self, alice, bob):
        rows = [{"s": alice}, {"s": bob}]
        page = slice_solutions(rows)
        assert len(page) == 2

View file

@ -28,21 +28,21 @@ def triple_tx():
class TestTermTranslatorIri:
def test_iri_to_pulsar(self, term_tx):
def test_iri_decode(self, term_tx):
data = {"t": "i", "i": "http://example.org/Alice"}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.type == IRI
assert term.iri == "http://example.org/Alice"
def test_iri_from_pulsar(self, term_tx):
def test_iri_encode(self, term_tx):
term = Term(type=IRI, iri="http://example.org/Bob")
wire = term_tx.from_pulsar(term)
wire = term_tx.encode(term)
assert wire == {"t": "i", "i": "http://example.org/Bob"}
def test_iri_round_trip(self, term_tx):
original = Term(type=IRI, iri="http://example.org/round")
wire = term_tx.from_pulsar(original)
restored = term_tx.to_pulsar(wire)
wire = term_tx.encode(original)
restored = term_tx.decode(wire)
assert restored == original
@ -52,21 +52,21 @@ class TestTermTranslatorIri:
class TestTermTranslatorBlank:
def test_blank_to_pulsar(self, term_tx):
def test_blank_decode(self, term_tx):
data = {"t": "b", "d": "_:b42"}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.type == BLANK
assert term.id == "_:b42"
def test_blank_from_pulsar(self, term_tx):
def test_blank_encode(self, term_tx):
term = Term(type=BLANK, id="_:node1")
wire = term_tx.from_pulsar(term)
wire = term_tx.encode(term)
assert wire == {"t": "b", "d": "_:node1"}
def test_blank_round_trip(self, term_tx):
original = Term(type=BLANK, id="_:x")
wire = term_tx.from_pulsar(original)
restored = term_tx.to_pulsar(wire)
wire = term_tx.encode(original)
restored = term_tx.decode(wire)
assert restored == original
@ -76,29 +76,29 @@ class TestTermTranslatorBlank:
class TestTermTranslatorTypedLiteral:
def test_plain_literal_to_pulsar(self, term_tx):
def test_plain_literal_decode(self, term_tx):
data = {"t": "l", "v": "hello"}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.type == LITERAL
assert term.value == "hello"
assert term.datatype == ""
assert term.language == ""
def test_xsd_integer_to_pulsar(self, term_tx):
def test_xsd_integer_decode(self, term_tx):
data = {
"t": "l", "v": "42",
"dt": "http://www.w3.org/2001/XMLSchema#integer",
}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.value == "42"
assert term.datatype.endswith("#integer")
def test_typed_literal_from_pulsar(self, term_tx):
def test_typed_literal_encode(self, term_tx):
term = Term(
type=LITERAL, value="3.14",
datatype="http://www.w3.org/2001/XMLSchema#double",
)
wire = term_tx.from_pulsar(term)
wire = term_tx.encode(term)
assert wire["t"] == "l"
assert wire["v"] == "3.14"
assert wire["dt"] == "http://www.w3.org/2001/XMLSchema#double"
@ -109,13 +109,13 @@ class TestTermTranslatorTypedLiteral:
type=LITERAL, value="true",
datatype="http://www.w3.org/2001/XMLSchema#boolean",
)
wire = term_tx.from_pulsar(original)
restored = term_tx.to_pulsar(wire)
wire = term_tx.encode(original)
restored = term_tx.decode(wire)
assert restored == original
def test_plain_literal_omits_dt_and_ln(self, term_tx):
term = Term(type=LITERAL, value="x")
wire = term_tx.from_pulsar(term)
wire = term_tx.encode(term)
assert "dt" not in wire
assert "ln" not in wire
@ -126,22 +126,22 @@ class TestTermTranslatorTypedLiteral:
class TestTermTranslatorLangLiteral:
def test_language_tag_to_pulsar(self, term_tx):
def test_language_tag_decode(self, term_tx):
data = {"t": "l", "v": "bonjour", "ln": "fr"}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.value == "bonjour"
assert term.language == "fr"
def test_language_tag_from_pulsar(self, term_tx):
def test_language_tag_encode(self, term_tx):
term = Term(type=LITERAL, value="colour", language="en-GB")
wire = term_tx.from_pulsar(term)
wire = term_tx.encode(term)
assert wire["ln"] == "en-GB"
assert "dt" not in wire # No datatype
def test_language_tag_round_trip(self, term_tx):
original = Term(type=LITERAL, value="hola", language="es")
wire = term_tx.from_pulsar(original)
restored = term_tx.to_pulsar(wire)
wire = term_tx.encode(original)
restored = term_tx.decode(wire)
assert restored == original
@ -151,7 +151,7 @@ class TestTermTranslatorLangLiteral:
class TestTermTranslatorQuotedTriple:
def test_quoted_triple_to_pulsar(self, term_tx):
def test_quoted_triple_decode(self, term_tx):
data = {
"t": "t",
"tr": {
@ -160,20 +160,20 @@ class TestTermTranslatorQuotedTriple:
"o": {"t": "i", "i": "http://example.org/Bob"},
},
}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.type == TRIPLE
assert term.triple is not None
assert term.triple.s.iri == "http://example.org/Alice"
assert term.triple.o.iri == "http://example.org/Bob"
def test_quoted_triple_from_pulsar(self, term_tx):
def test_quoted_triple_encode(self, term_tx):
inner = Triple(
s=Term(type=IRI, iri="http://example.org/s"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="val"),
)
term = Term(type=TRIPLE, triple=inner)
wire = term_tx.from_pulsar(term)
wire = term_tx.encode(term)
assert wire["t"] == "t"
assert "tr" in wire
assert wire["tr"]["s"]["i"] == "http://example.org/s"
@ -186,18 +186,18 @@ class TestTermTranslatorQuotedTriple:
o=Term(type=LITERAL, value="C", language="en"),
)
original = Term(type=TRIPLE, triple=inner)
wire = term_tx.from_pulsar(original)
restored = term_tx.to_pulsar(wire)
wire = term_tx.encode(original)
restored = term_tx.decode(wire)
assert restored.type == TRIPLE
assert restored.triple.s == original.triple.s
assert restored.triple.o == original.triple.o
def test_quoted_triple_none_triple(self, term_tx):
term = Term(type=TRIPLE, triple=None)
wire = term_tx.from_pulsar(term)
wire = term_tx.encode(term)
assert wire == {"t": "t"}
# And back
restored = term_tx.to_pulsar(wire)
restored = term_tx.decode(wire)
assert restored.type == TRIPLE
assert restored.triple is None
@ -210,7 +210,7 @@ class TestTermTranslatorQuotedTriple:
"o": {"t": "l", "v": "A feeling of expectation"},
},
}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.triple.o.type == LITERAL
assert term.triple.o.value == "A feeling of expectation"
@ -223,22 +223,22 @@ class TestTermTranslatorEdgeCases:
def test_unknown_type(self, term_tx):
data = {"t": "z"}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.type == "z"
def test_empty_type(self, term_tx):
data = {}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.type == ""
def test_missing_iri_field(self, term_tx):
data = {"t": "i"}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.iri == ""
def test_missing_literal_fields(self, term_tx):
data = {"t": "l"}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.value == ""
assert term.datatype == ""
assert term.language == ""
@ -250,24 +250,24 @@ class TestTermTranslatorEdgeCases:
class TestTripleTranslator:
def test_triple_to_pulsar(self, triple_tx):
def test_triple_decode(self, triple_tx):
data = {
"s": {"t": "i", "i": "http://example.org/s"},
"p": {"t": "i", "i": "http://example.org/p"},
"o": {"t": "l", "v": "object"},
}
triple = triple_tx.to_pulsar(data)
triple = triple_tx.decode(data)
assert triple.s.iri == "http://example.org/s"
assert triple.o.value == "object"
assert triple.g is None
def test_triple_from_pulsar(self, triple_tx):
def test_triple_encode(self, triple_tx):
triple = Triple(
s=Term(type=IRI, iri="http://example.org/A"),
p=Term(type=IRI, iri="http://example.org/B"),
o=Term(type=LITERAL, value="C"),
)
wire = triple_tx.from_pulsar(triple)
wire = triple_tx.encode(triple)
assert wire["s"]["t"] == "i"
assert wire["o"]["v"] == "C"
assert "g" not in wire
@ -279,17 +279,17 @@ class TestTripleTranslator:
"o": {"t": "l", "v": "val"},
"g": "urn:graph:source",
}
quad = triple_tx.to_pulsar(data)
quad = triple_tx.decode(data)
assert quad.g == "urn:graph:source"
def test_quad_from_pulsar_includes_graph(self, triple_tx):
def test_quad_encode_includes_graph(self, triple_tx):
quad = Triple(
s=Term(type=IRI, iri="http://example.org/s"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="v"),
g="urn:graph:retrieval",
)
wire = triple_tx.from_pulsar(quad)
wire = triple_tx.encode(quad)
assert wire["g"] == "urn:graph:retrieval"
def test_quad_round_trip(self, triple_tx):
@ -299,8 +299,8 @@ class TestTripleTranslator:
o=Term(type=LITERAL, value="v"),
g="urn:graph:source",
)
wire = triple_tx.from_pulsar(original)
restored = triple_tx.to_pulsar(wire)
wire = triple_tx.encode(original)
restored = triple_tx.decode(wire)
assert restored == original
def test_none_graph_omitted_from_wire(self, triple_tx):
@ -310,12 +310,12 @@ class TestTripleTranslator:
o=Term(type=LITERAL, value="v"),
g=None,
)
wire = triple_tx.from_pulsar(triple)
wire = triple_tx.encode(triple)
assert "g" not in wire
def test_missing_terms_handled(self, triple_tx):
data = {}
triple = triple_tx.to_pulsar(data)
triple = triple_tx.decode(data)
assert triple.s is None
assert triple.p is None
assert triple.o is None
@ -342,16 +342,16 @@ class TestSubgraphTranslator:
g="urn:graph:source",
),
]
wire_list = tx.from_pulsar(triples)
wire_list = tx.encode(triples)
assert len(wire_list) == 2
assert wire_list[1]["g"] == "urn:graph:source"
restored = tx.to_pulsar(wire_list)
restored = tx.decode(wire_list)
assert len(restored) == 2
assert restored[0] == triples[0]
assert restored[1] == triples[1]
def test_empty_subgraph(self):
tx = SubgraphTranslator()
assert tx.to_pulsar([]) == []
assert tx.from_pulsar([]) == []
assert tx.decode([]) == []
assert tx.encode([]) == []

View file

@ -35,7 +35,7 @@ class TestDocumentMetadataTranslator:
"parent-id": "doc-100",
"document-type": "page",
}
obj = self.tx.to_pulsar(data)
obj = self.tx.decode(data)
assert obj.id == "doc-123"
assert obj.time == 1710000000
assert obj.kind == "application/pdf"
@ -45,14 +45,14 @@ class TestDocumentMetadataTranslator:
assert obj.parent_id == "doc-100"
assert obj.document_type == "page"
wire = self.tx.from_pulsar(obj)
wire = self.tx.encode(obj)
assert wire["id"] == "doc-123"
assert wire["user"] == "alice"
assert wire["parent-id"] == "doc-100"
assert wire["document-type"] == "page"
def test_defaults_for_missing_fields(self):
obj = self.tx.to_pulsar({})
obj = self.tx.decode({})
assert obj.parent_id == ""
assert obj.document_type == "source"
@ -63,25 +63,25 @@ class TestDocumentMetadataTranslator:
"o": {"t": "i", "i": "http://example.org/o"},
}]
data = {"metadata": triple_wire}
obj = self.tx.to_pulsar(data)
obj = self.tx.decode(data)
assert len(obj.metadata) == 1
assert obj.metadata[0].s.iri == "http://example.org/s"
def test_none_metadata_handled(self):
data = {"metadata": None}
obj = self.tx.to_pulsar(data)
obj = self.tx.decode(data)
assert obj.metadata == []
def test_empty_tags_preserved(self):
data = {"tags": []}
obj = self.tx.to_pulsar(data)
wire = self.tx.from_pulsar(obj)
obj = self.tx.decode(data)
wire = self.tx.encode(obj)
assert wire["tags"] == []
def test_falsy_fields_omitted_from_wire(self):
"""Empty string fields should be omitted from wire format."""
obj = DocumentMetadata(id="", time=0, user="")
wire = self.tx.from_pulsar(obj)
wire = self.tx.encode(obj)
assert "id" not in wire
assert "user" not in wire
@ -105,7 +105,7 @@ class TestProcessingMetadataTranslator:
"collection": "my-collection",
"tags": ["tag1"],
}
obj = self.tx.to_pulsar(data)
obj = self.tx.decode(data)
assert obj.id == "proc-1"
assert obj.document_id == "doc-123"
assert obj.flow == "default"
@ -113,32 +113,32 @@ class TestProcessingMetadataTranslator:
assert obj.collection == "my-collection"
assert obj.tags == ["tag1"]
wire = self.tx.from_pulsar(obj)
wire = self.tx.encode(obj)
assert wire["id"] == "proc-1"
assert wire["document-id"] == "doc-123"
assert wire["user"] == "alice"
assert wire["collection"] == "my-collection"
def test_missing_fields_use_defaults(self):
obj = self.tx.to_pulsar({})
obj = self.tx.decode({})
assert obj.id is None
assert obj.user is None
assert obj.collection is None
def test_tags_none_omitted(self):
obj = ProcessingMetadata(tags=None)
wire = self.tx.from_pulsar(obj)
wire = self.tx.encode(obj)
assert "tags" not in wire
def test_tags_empty_list_preserved(self):
obj = ProcessingMetadata(tags=[])
wire = self.tx.from_pulsar(obj)
wire = self.tx.encode(obj)
assert wire["tags"] == []
def test_user_and_collection_preserved(self):
"""Core pipeline routing fields must survive round-trip."""
data = {"user": "bob", "collection": "research"}
obj = self.tx.to_pulsar(data)
wire = self.tx.from_pulsar(obj)
obj = self.tx.decode(data)
wire = self.tx.encode(obj)
assert wire["user"] == "bob"
assert wire["collection"] == "research"

View file

@ -0,0 +1,380 @@
"""
Integration test: run a full DocumentRag.query() with mocked subsidiary
clients and verify the explain_callback receives the complete provenance
chain in the correct order with correct structure.
Document-RAG provenance chain (4 stages):
    question → grounding → exploration → synthesis
"""
import pytest
from unittest.mock import AsyncMock
from dataclasses import dataclass
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
TG_SYNTHESIS, TG_ANSWER_TYPE,
TG_QUERY, TG_CONCEPT,
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def find_triple(triples, predicate, subject=None):
    """Return the first triple whose predicate IRI equals *predicate*.

    Args:
        triples: Iterable of triple-like objects exposing ``.s.iri`` and
            ``.p.iri``.
        predicate: Predicate IRI to look for.
        subject: Optional subject IRI; when given, the triple's subject
            must match as well.

    Returns:
        The first matching triple in iteration order, or ``None`` when
        nothing matches.
    """
    for triple in triples:
        if triple.p.iri != predicate:
            continue
        # The subject is only inspected when the caller asked to filter on it.
        if subject is None or triple.s.iri == subject:
            return triple
    return None
def find_triples(triples, predicate, subject=None):
    """Return every triple whose predicate IRI equals *predicate*.

    Args:
        triples: Iterable of triple-like objects exposing ``.s.iri`` and
            ``.p.iri``.
        predicate: Predicate IRI to look for.
        subject: Optional subject IRI; when given, only triples with that
            subject are kept.

    Returns:
        A list of the matching triples, in input order (possibly empty).
    """
    return [
        triple
        for triple in triples
        if triple.p.iri == predicate
        and (subject is None or triple.s.iri == subject)
    ]
def has_type(triples, subject, rdf_type):
    """Return True if *triples* contains ``subject RDF_TYPE rdf_type``.

    All three positions are compared by their ``.iri`` attribute.
    """
    return any(
        triple.s.iri == subject
        and triple.p.iri == RDF_TYPE
        and triple.o.iri == rdf_type
        for triple in triples
    )
def derived_from(triples, subject):
    """Return the IRI that *subject* was derived from, or None.

    Looks up the first PROV_WAS_DERIVED_FROM triple with the given subject
    and returns its object IRI.
    """
    link = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    if link is None:
        return None
    return link.o.iri
@dataclass
class ChunkMatch:
    """Mimics the result from doc_embeddings_client.query()."""

    # Identifier of the matched chunk; the mocked fetch_chunk is keyed on it.
    chunk_id: str
# ---------------------------------------------------------------------------
# Mock setup
# ---------------------------------------------------------------------------
# Chunk identifiers handed back by the mocked doc_embeddings_client.query().
CHUNK_A = "urn:chunk:policy-doc-1:chunk-0"
CHUNK_B = "urn:chunk:policy-doc-1:chunk-1"
# Text the mocked fetch_chunk returns for each of the identifiers above.
CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase."
CHUNK_B_CONTENT = "Refunds are processed to the original payment method."
def build_mock_clients():
    """
    Build mock clients for a document-rag query.

    Client call sequence during query():
      1. prompt_client.prompt("extract-concepts", ...) -> concepts
      2. embeddings_client.embed(concepts) -> vectors
      3. doc_embeddings_client.query(vector, ...) -> chunk matches
      4. fetch_chunk(chunk_id, user) -> chunk content
      5. prompt_client.document_prompt(query, documents) -> answer

    Returns:
        A 4-tuple ``(prompt_client, embeddings_client,
        doc_embeddings_client, fetch_chunk)`` of AsyncMock objects, in the
        order the tests unpack them into ``DocumentRag(*clients)``.
    """
    prompt_client = AsyncMock()
    embeddings_client = AsyncMock()
    doc_embeddings_client = AsyncMock()
    fetch_chunk = AsyncMock()

    # 1. Concept extraction: one concept per line; any other template
    #    yields an empty string.
    async def mock_prompt(template_id, variables=None, **kwargs):
        if template_id == "extract-concepts":
            return "return policy\nrefund"
        return ""

    prompt_client.prompt.side_effect = mock_prompt

    # 2. Embedding vectors
    embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]

    # 3. Chunk matching
    doc_embeddings_client.query.return_value = [
        ChunkMatch(chunk_id=CHUNK_A),
        ChunkMatch(chunk_id=CHUNK_B),
    ]

    # 4. Chunk content. The lookup table is built once; an unknown chunk id
    #    raises KeyError, which surfaces unexpected fetches in the test.
    chunk_content = {
        CHUNK_A: CHUNK_A_CONTENT,
        CHUNK_B: CHUNK_B_CONTENT,
    }

    async def mock_fetch(chunk_id, user):
        return chunk_content[chunk_id]

    fetch_chunk.side_effect = mock_fetch

    # 5. Synthesis
    prompt_client.document_prompt.return_value = (
        "Items can be returned within 30 days for a full refund."
    )

    return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestDocumentRagQueryProvenance:
    """
    Run a real DocumentRag.query() and verify the provenance chain emitted
    via explain_callback.
    """

    # Query text sent in every test and the answer the mocked
    # document_prompt returns for it.
    QUERY = "What is the return policy?"
    ANSWER = "Items can be returned within 30 days for a full refund."

    async def _run_query_collecting_events(self):
        """Run query() against fresh mocks and return the explain events.

        Returns:
            A list of ``{"triples": ..., "explain_id": ...}`` dicts, one
            per explain_callback invocation, in call order.
        """
        rag = DocumentRag(*build_mock_clients())
        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(query=self.QUERY, explain_callback=explain_callback)
        return events

    @pytest.mark.asyncio
    async def test_explain_callback_receives_four_events(self):
        """query() should emit exactly 4 explain events."""
        events = await self._run_query_collecting_events()

        assert len(events) == 4, (
            f"Expected 4 explain events (question, grounding, exploration, "
            f"synthesis), got {len(events)}"
        )

    @pytest.mark.asyncio
    async def test_events_have_correct_types_in_order(self):
        """
        Events should arrive as:
        question, grounding, exploration, synthesis.
        """
        events = await self._run_query_collecting_events()

        expected_types = [
            TG_DOC_RAG_QUESTION,
            TG_GROUNDING,
            TG_EXPLORATION,
            TG_SYNTHESIS,
        ]

        # Fail with a readable assertion instead of an IndexError when
        # fewer events than stages were emitted.
        assert len(events) >= len(expected_types), (
            f"Expected at least {len(expected_types)} explain events, "
            f"got {len(events)}"
        )

        for i, expected_type in enumerate(expected_types):
            uri = events[i]["explain_id"]
            triples = events[i]["triples"]
            assert has_type(triples, uri, expected_type), (
                f"Event {i} (uri={uri}) should have type {expected_type}"
            )

    @pytest.mark.asyncio
    async def test_derivation_chain_links_correctly(self):
        """
        Each event's URI should link to the previous via wasDerivedFrom:
            question → (none)
            grounding → question
            exploration → grounding
            synthesis → exploration
        """
        events = await self._run_query_collecting_events()

        assert len(events) >= 4, (
            f"Expected at least 4 explain events, got {len(events)}"
        )

        uris = [e["explain_id"] for e in events]
        all_triples = []
        for e in events:
            all_triples.extend(e["triples"])

        # question has no parent
        assert derived_from(all_triples, uris[0]) is None
        # grounding → question
        assert derived_from(all_triples, uris[1]) == uris[0]
        # exploration → grounding
        assert derived_from(all_triples, uris[2]) == uris[1]
        # synthesis → exploration
        assert derived_from(all_triples, uris[3]) == uris[2]

    @pytest.mark.asyncio
    async def test_question_carries_query_text(self):
        """The question event should contain the original query string."""
        events = await self._run_query_collecting_events()

        q_uri = events[0]["explain_id"]
        q_triples = events[0]["triples"]

        t = find_triple(q_triples, TG_QUERY, q_uri)
        assert t is not None
        assert t.o.value == "What is the return policy?"

    @pytest.mark.asyncio
    async def test_grounding_carries_concepts(self):
        """The grounding event should list extracted concepts."""
        events = await self._run_query_collecting_events()

        gnd_uri = events[1]["explain_id"]
        gnd_triples = events[1]["triples"]

        concepts = find_triples(gnd_triples, TG_CONCEPT, gnd_uri)
        concept_values = {t.o.value for t in concepts}
        assert "return policy" in concept_values
        assert "refund" in concept_values

    @pytest.mark.asyncio
    async def test_exploration_has_chunk_count(self):
        """The exploration event should report the number of chunks retrieved."""
        events = await self._run_query_collecting_events()

        exp_uri = events[2]["explain_id"]
        exp_triples = events[2]["triples"]

        t = find_triple(exp_triples, TG_CHUNK_COUNT, exp_uri)
        assert t is not None
        assert int(t.o.value) == 2

    @pytest.mark.asyncio
    async def test_exploration_has_selected_chunks(self):
        """The exploration event should list the chunk IDs that were fetched."""
        events = await self._run_query_collecting_events()

        exp_uri = events[2]["explain_id"]
        exp_triples = events[2]["triples"]

        chunks = find_triples(exp_triples, TG_SELECTED_CHUNK, exp_uri)
        chunk_iris = {t.o.iri for t in chunks}
        assert CHUNK_A in chunk_iris
        assert CHUNK_B in chunk_iris

    @pytest.mark.asyncio
    async def test_synthesis_is_answer_type(self):
        """The synthesis event should have tg:Synthesis and tg:Answer types."""
        events = await self._run_query_collecting_events()

        syn_uri = events[3]["explain_id"]
        syn_triples = events[3]["triples"]

        assert has_type(syn_triples, syn_uri, TG_SYNTHESIS)
        assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE)

    @pytest.mark.asyncio
    async def test_query_returns_answer_text(self):
        """query() should return the synthesised answer."""
        rag = DocumentRag(*build_mock_clients())

        result = await rag.query(
            query=self.QUERY,
            explain_callback=AsyncMock(),
        )

        assert result == self.ANSWER

    @pytest.mark.asyncio
    async def test_no_explain_callback_still_works(self):
        """query() without explain_callback should return answer normally."""
        rag = DocumentRag(*build_mock_clients())

        result = await rag.query(query=self.QUERY)

        assert result == self.ANSWER

    @pytest.mark.asyncio
    async def test_all_triples_in_retrieval_graph(self):
        """All emitted triples should be in the urn:graph:retrieval graph."""
        events = await self._run_query_collecting_events()

        for event in events:
            for t in event["triples"]:
                assert t.g == "urn:graph:retrieval", (
                    f"Triple {t.s.iri} {t.p.iri} should be in "
                    f"urn:graph:retrieval, got {t.g}"
                )

View file

@ -465,12 +465,15 @@ class TestQuery:
return_value=(["entity1", "entity2"], ["concept1"])
)
query.follow_edges_batch = AsyncMock(return_value={
("entity1", "predicate1", "object1"),
("entity2", "predicate2", "object2")
})
query.follow_edges_batch = AsyncMock(return_value=(
{
("entity1", "predicate1", "object1"),
("entity2", "predicate2", "object2")
},
{}
))
subgraph, entities, concepts = await query.get_subgraph("test query")
subgraph, term_map, entities, concepts = await query.get_subgraph("test query")
query.get_entities.assert_called_once_with("test query")
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
@ -503,7 +506,7 @@ class TestQuery:
test_entities = ["entity1", "entity3"]
test_concepts = ["concept1"]
query.get_subgraph = AsyncMock(
return_value=(test_subgraph, test_entities, test_concepts)
return_value=(test_subgraph, {}, test_entities, test_concepts)
)
async def mock_maybe_label(entity):

View file

@ -0,0 +1,358 @@
"""
Tests that explain_triples are forwarded correctly through the graph-rag
service and client layers.
Covers:
- Service: explain messages include triples from the provenance callback
- Client: explain_callback receives explain_triples from the response
- End-to-end: triples survive the full service → client → callback chain
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.schema import (
GraphRagQuery, GraphRagResponse,
Triple, Term, IRI, LITERAL,
)
from trustgraph.base.graph_rag_client import GraphRagClient
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_triple(s_iri, p_iri, o_value, o_type=IRI):
    """Create a Triple with IRI subject/predicate and typed object.

    Args:
        s_iri: Subject IRI.
        p_iri: Predicate IRI.
        o_value: Object IRI, or the literal's value when *o_type* is not IRI.
        o_type: IRI (default) builds an IRI object term; any other value
            builds a LITERAL term.
    """
    if o_type == IRI:
        o = Term(type=IRI, iri=o_value)
    else:
        o = Term(type=LITERAL, value=o_value)

    return Triple(
        s=Term(type=IRI, iri=s_iri),
        p=Term(type=IRI, iri=p_iri),
        o=o,
    )
def sample_focus_triples():
    """Focus-style triples with a quoted triple (edge selection).

    Returns a fresh list of three IRI-object triples about
    ``urn:trustgraph:focus:abc``: its rdf:type, its prov:wasDerivedFrom
    link and one selectedEdge reference.
    """
    focus = "urn:trustgraph:focus:abc"
    return [
        make_triple(
            focus,
            "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
            "https://trustgraph.ai/ns/Focus",
        ),
        make_triple(
            focus,
            "http://www.w3.org/ns/prov#wasDerivedFrom",
            "urn:trustgraph:exploration:abc",
        ),
        make_triple(
            focus,
            "https://trustgraph.ai/ns/selectedEdge",
            "urn:trustgraph:edge-sel:abc:0",
        ),
    ]
def sample_question_triples():
    """Question-style triples.

    Returns a fresh list of two triples about
    ``urn:trustgraph:question:abc``: its rdf:type (IRI object) and the
    query text (LITERAL object).
    """
    question = "urn:trustgraph:question:abc"
    return [
        make_triple(
            question,
            "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
            "https://trustgraph.ai/ns/GraphRagQuestion",
        ),
        make_triple(
            question,
            "https://trustgraph.ai/ns/query",
            "What is quantum computing?",
            o_type=LITERAL,
        ),
    ]
# ---------------------------------------------------------------------------
# Service-level: explain messages carry triples
# ---------------------------------------------------------------------------
class TestGraphRagServiceExplainTriples:
    """Test that the graph-rag service includes explain_triples in messages."""

    @patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
    @pytest.mark.asyncio
    async def test_explain_messages_include_triples(self, mock_graph_rag_class):
        """
        When the provenance callback is invoked with triples, the service
        should include them in the explain response message.
        """
        from trustgraph.retrieval.graph_rag.rag import Processor

        processor = Processor(
            taskgroup=MagicMock(),
            id="test-processor",
            entity_limit=50,
            triple_limit=30,
            max_subgraph_size=150,
            max_path_length=2,
        )

        mock_rag_instance = AsyncMock()
        mock_graph_rag_class.return_value = mock_rag_instance

        question_triples = sample_question_triples()
        focus_triples = sample_focus_triples()

        # Stand-in for GraphRag.query(): fires the explain callback twice
        # (question, then focus) before returning the answer text.
        async def mock_query(**kwargs):
            explain_callback = kwargs.get('explain_callback')
            if explain_callback:
                await explain_callback(
                    question_triples, "urn:trustgraph:question:abc"
                )
                await explain_callback(
                    focus_triples, "urn:trustgraph:focus:abc"
                )
            return "The answer."

        mock_rag_instance.query.side_effect = mock_query

        msg = MagicMock()
        msg.value.return_value = GraphRagQuery(
            query="What is quantum computing?",
            user="trustgraph",
            collection="default",
            streaming=False,
        )
        msg.properties.return_value = {"id": "test-id"}

        consumer = MagicMock()
        flow = MagicMock()

        mock_response = AsyncMock()
        mock_provenance = AsyncMock()

        # flow(name) hands out the producer for the named queue; only the
        # "response" producer is inspected below.
        def flow_router(name):
            if name == "response":
                return mock_response
            if name == "explainability":
                return mock_provenance
            return AsyncMock()

        flow.side_effect = flow_router

        await processor.on_request(msg, consumer, flow)

        # Find the explain messages
        explain_msgs = [
            sent.args[0]
            for sent in mock_response.send.call_args_list
            if sent.args[0].message_type == "explain"
        ]

        assert len(explain_msgs) == 2

        # First explain message should carry question triples
        assert explain_msgs[0].explain_id == "urn:trustgraph:question:abc"
        assert explain_msgs[0].explain_triples == question_triples

        # Second explain message should carry focus triples
        assert explain_msgs[1].explain_id == "urn:trustgraph:focus:abc"
        assert explain_msgs[1].explain_triples == focus_triples
# ---------------------------------------------------------------------------
# Client-level: explain_callback receives triples
# ---------------------------------------------------------------------------
class TestGraphRagClientExplainForwarding:
    """Test that GraphRagClient.rag() forwards explain_triples to callback."""

    @staticmethod
    def _client_replaying(responses):
        """Return a GraphRagClient whose request() replays *responses*.

        The client is created with ``__new__`` so ``__init__`` is not run.
        Its ``request`` attribute is replaced by a coroutine that feeds each
        response to the ``recipient`` callback in order and returns the
        first response for which the recipient reports completion.
        """
        client = GraphRagClient.__new__(GraphRagClient)

        async def mock_request(req, timeout=600, recipient=None):
            for resp in responses:
                done = await recipient(resp)
                if done:
                    return resp

        client.request = mock_request
        return client

    @pytest.mark.asyncio
    async def test_explain_callback_receives_triples(self):
        """
        The explain_callback should receive (explain_id, explain_graph,
        explain_triples) not just (explain_id, explain_graph).
        """
        focus_triples = sample_focus_triples()

        # Simulate the response sequence the client would receive
        responses = [
            GraphRagResponse(
                message_type="explain",
                explain_id="urn:trustgraph:focus:abc",
                explain_graph="urn:graph:retrieval",
                explain_triples=focus_triples,
            ),
            GraphRagResponse(
                message_type="chunk",
                response="The answer.",
                end_of_stream=True,
            ),
            GraphRagResponse(
                message_type="chunk",
                response="",
                end_of_session=True,
            ),
        ]

        # Capture what the explain_callback receives
        received_calls = []

        async def explain_callback(explain_id, explain_graph, explain_triples):
            received_calls.append({
                "explain_id": explain_id,
                "explain_graph": explain_graph,
                "explain_triples": explain_triples,
            })

        client = self._client_replaying(responses)

        result = await client.rag(
            query="test",
            explain_callback=explain_callback,
        )

        assert result == "The answer."
        assert len(received_calls) == 1
        assert received_calls[0]["explain_id"] == "urn:trustgraph:focus:abc"
        assert received_calls[0]["explain_graph"] == "urn:graph:retrieval"
        assert received_calls[0]["explain_triples"] == focus_triples

    @pytest.mark.asyncio
    async def test_explain_callback_receives_empty_triples(self):
        """
        When an explain event has no triples, the callback should still
        receive an empty list (not None or missing).
        """
        responses = [
            GraphRagResponse(
                message_type="explain",
                explain_id="urn:trustgraph:question:abc",
                explain_graph="urn:graph:retrieval",
                explain_triples=[],
            ),
            GraphRagResponse(
                message_type="chunk",
                response="Answer.",
                end_of_stream=True,
                end_of_session=True,
            ),
        ]

        received_calls = []

        async def explain_callback(explain_id, explain_graph, explain_triples):
            received_calls.append(explain_triples)

        client = self._client_replaying(responses)

        await client.rag(query="test", explain_callback=explain_callback)

        assert len(received_calls) == 1
        assert received_calls[0] == []

    @pytest.mark.asyncio
    async def test_multiple_explain_events_all_forward_triples(self):
        """
        Each explain event in a session should forward its own triples.
        """
        q_triples = sample_question_triples()
        f_triples = sample_focus_triples()

        responses = [
            GraphRagResponse(
                message_type="explain",
                explain_id="urn:trustgraph:question:abc",
                explain_graph="urn:graph:retrieval",
                explain_triples=q_triples,
            ),
            GraphRagResponse(
                message_type="explain",
                explain_id="urn:trustgraph:focus:abc",
                explain_graph="urn:graph:retrieval",
                explain_triples=f_triples,
            ),
            GraphRagResponse(
                message_type="chunk",
                response="Answer.",
                end_of_stream=True,
                end_of_session=True,
            ),
        ]

        received_calls = []

        async def explain_callback(explain_id, explain_graph, explain_triples):
            received_calls.append({
                "explain_id": explain_id,
                "explain_triples": explain_triples,
            })

        client = self._client_replaying(responses)

        await client.rag(query="test", explain_callback=explain_callback)

        assert len(received_calls) == 2
        assert received_calls[0]["explain_id"] == "urn:trustgraph:question:abc"
        assert received_calls[0]["explain_triples"] == q_triples
        assert received_calls[1]["explain_id"] == "urn:trustgraph:focus:abc"
        assert received_calls[1]["explain_triples"] == f_triples

    @pytest.mark.asyncio
    async def test_no_explain_callback_does_not_error(self):
        """
        When no explain_callback is provided, explain events should be
        silently skipped without errors.
        """
        responses = [
            GraphRagResponse(
                message_type="explain",
                explain_id="urn:trustgraph:question:abc",
                explain_graph="urn:graph:retrieval",
                explain_triples=sample_question_triples(),
            ),
            GraphRagResponse(
                message_type="chunk",
                response="Answer.",
                end_of_stream=True,
                end_of_session=True,
            ),
        ]

        client = self._client_replaying(responses)

        result = await client.rag(query="test")

        assert result == "Answer."

View file

@ -0,0 +1,482 @@
"""
Integration test: run a full GraphRag.query() with mocked subsidiary clients
and verify the explain_callback receives the complete provenance chain
in the correct order with correct structure.
This tests the real query() method end-to-end, not just the triple builders.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock
from dataclasses import dataclass
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE,
TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT,
TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def find_triple(triples, predicate, subject=None):
    """Return the first triple whose predicate IRI equals *predicate*.

    Args:
        triples: Iterable of triple-like objects exposing ``.s.iri`` and
            ``.p.iri``.
        predicate: Predicate IRI to look for.
        subject: Optional subject IRI; when given, the triple's subject
            must match as well.

    Returns:
        The first matching triple in iteration order, or ``None`` when
        nothing matches.
    """
    for triple in triples:
        if triple.p.iri != predicate:
            continue
        # The subject is only inspected when the caller asked to filter on it.
        if subject is None or triple.s.iri == subject:
            return triple
    return None
def find_triples(triples, predicate, subject=None):
    """Return every triple whose predicate IRI equals *predicate*.

    Args:
        triples: Iterable of triple-like objects exposing ``.s.iri`` and
            ``.p.iri``.
        predicate: Predicate IRI to look for.
        subject: Optional subject IRI; when given, only triples with that
            subject are kept.

    Returns:
        A list of the matching triples, in input order (possibly empty).
    """
    return [
        triple
        for triple in triples
        if triple.p.iri == predicate
        and (subject is None or triple.s.iri == subject)
    ]
def has_type(triples, subject, rdf_type):
    """Return True if *triples* contains ``subject RDF_TYPE rdf_type``.

    All three positions are compared by their ``.iri`` attribute.
    """
    return any(
        triple.s.iri == subject
        and triple.p.iri == RDF_TYPE
        and triple.o.iri == rdf_type
        for triple in triples
    )
def derived_from(triples, subject):
    """Return the IRI that *subject* was derived from, or None.

    Looks up the first PROV_WAS_DERIVED_FROM triple with the given subject
    and returns its object IRI.
    """
    link = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    if link is None:
        return None
    return link.o.iri
@dataclass
class EmbeddingMatch:
    """Mimics the result from graph_embeddings_client.query()."""

    # The matched entity, carried as a schema Term.
    entity: Term
# ---------------------------------------------------------------------------
# Mock setup
# ---------------------------------------------------------------------------
# A tiny knowledge graph: 2 entities, 3 edges
# Entity IRIs returned by the mocked graph_embeddings_client.query().
ENTITY_A = "http://example.com/QuantumComputing"
ENTITY_B = "http://example.com/Physics"
# Edges as (subject, predicate, object) string tuples. EDGE_1 links the two
# entities; EDGE_2 and EDGE_3 carry plain-text names as their objects.
EDGE_1 = (ENTITY_A, "http://schema.org/relatedTo", ENTITY_B)
EDGE_2 = (ENTITY_A, "http://schema.org/name", "Quantum Computing")
EDGE_3 = (ENTITY_B, "http://schema.org/name", "Physics")
def make_schema_triple(s, p, o):
    """Create a SchemaTriple from string values.

    Subject and predicate are always IRI terms. The object is an IRI term
    when the string starts with "http" and a LITERAL term otherwise.
    """
    if o.startswith("http"):
        obj = Term(type=IRI, iri=o)
    else:
        obj = Term(type=LITERAL, value=o)

    return SchemaTriple(
        s=Term(type=IRI, iri=s),
        p=Term(type=IRI, iri=p),
        o=obj,
    )
def build_mock_clients():
    """
    Build mock clients that simulate a small knowledge graph query.

    Returns a 4-tuple in the order
    (prompt_client, embeddings_client, graph_embeddings_client,
    triples_client), ready to be splatted into GraphRag(...).

    Client call sequence during query():
    1. prompt_client.prompt("extract-concepts", ...) -> concepts
    2. embeddings_client.embed(concepts) -> vectors
    3. graph_embeddings_client.query(vector, ...) -> entity matches
    4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch)
    5. triples_client.query(s, LABEL, ...) -> labels (maybe_label)
    6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges
    7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning
    8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
    9. prompt_client.prompt("kg-synthesis", ...) -> final answer
    """
    prompt_client = AsyncMock()
    embeddings_client = AsyncMock()
    graph_embeddings_client = AsyncMock()
    triples_client = AsyncMock()

    # 2. Embedding vectors (simple fake vectors, one per concept)
    embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]

    # 3. Entity lookup - return our two entities
    graph_embeddings_client.query.return_value = [
        EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_A)),
        EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)),
    ]

    # 4. Triple queries (follow_edges_batch) - return our edges
    triples_client.query_stream.return_value = [
        make_schema_triple(*edge) for edge in (EDGE_1, EDGE_2, EDGE_3)
    ]

    # 5 + 8. Every triples_client.query() call returns no rows: label
    # lookups find nothing (per the original note, the caller then falls
    # back to the URI) and document tracing finds nothing either.
    async def mock_label_query(s=None, p=None, o=None, limit=1,
                               user=None, collection=None, g=None):
        return []

    triples_client.query.side_effect = mock_label_query

    # 1 + 9. Canned prompt outputs
    concepts_response = "quantum computing\nphysics"
    synthesis_answer = "Quantum computing applies physics principles to computation."

    def knowledge_edges(variables):
        # Tolerate prompt() being called without variables: the parameter
        # defaults to None, which has no .get().
        return (variables or {}).get("knowledge", [])

    # 6+7. Edge scoring and reasoning: dynamically score/reason about
    # whatever edges the query method sends us, since edge IDs are computed
    # from str(Term) representations which include the full dataclass repr.
    async def mock_prompt(template_id, variables=None, **kwargs):
        if template_id == "extract-concepts":
            return concepts_response

        if template_id == "kg-edge-scoring":
            # Score all edges highly, using the IDs that GraphRag computed
            return [
                {"id": edge["id"], "score": 10 - position}
                for position, edge in enumerate(knowledge_edges(variables))
            ]

        if template_id == "kg-edge-reasoning":
            # Provide reasoning for each edge
            return [
                {"id": edge["id"], "reasoning": f"Relevant edge {position}"}
                for position, edge in enumerate(knowledge_edges(variables))
            ]

        if template_id == "kg-synthesis":
            return synthesis_answer

        # Unknown template: empty response, as before.
        return ""

    prompt_client.prompt.side_effect = mock_prompt

    return prompt_client, embeddings_client, graph_embeddings_client, triples_client
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestGraphRagQueryProvenance:
    """
    Run a real GraphRag.query() against the mock clients and verify the
    provenance chain emitted via explain_callback.
    """

    # Question sent to GraphRag in every test.
    QUERY = "What is quantum computing?"

    # Answer the mocked "kg-synthesis" prompt returns.
    ANSWER = "Quantum computing applies physics principles to computation."

    async def _run_query(self, **extra_query_args):
        """
        Run GraphRag.query() on fresh mock clients, recording explain events.

        Each callback invocation is stored as a dict with "triples" and
        "explain_id" keys, in arrival order.  Extra keyword arguments are
        forwarded to query().

        Returns a (result, events) tuple.
        """
        rag = GraphRag(*build_mock_clients())
        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        result = await rag.query(
            query=self.QUERY,
            explain_callback=explain_callback,
            edge_score_limit=0,  # skip semantic pre-filter for simplicity
            **extra_query_args,
        )
        return result, events

    @staticmethod
    def _event(events, index):
        """
        Return (explain_id, triples) for the event at *index*.

        Asserts that the event exists so a missing event fails with a
        readable message instead of a bare IndexError.
        """
        assert len(events) > index, (
            f"Expected at least {index + 1} explain events, got {len(events)}"
        )
        event = events[index]
        return event["explain_id"], event["triples"]

    @pytest.mark.asyncio
    async def test_explain_callback_receives_five_events(self):
        """query() should emit exactly 5 explain events."""
        _, events = await self._run_query()

        assert len(events) == 5, (
            f"Expected 5 explain events (question, grounding, exploration, "
            f"focus, synthesis), got {len(events)}"
        )

    @pytest.mark.asyncio
    async def test_events_have_correct_types_in_order(self):
        """
        Events should arrive as:
        question, grounding, exploration, focus, synthesis.
        """
        _, events = await self._run_query()

        expected_types = [
            TG_GRAPH_RAG_QUESTION,
            TG_GROUNDING,
            TG_EXPLORATION,
            TG_FOCUS,
            TG_SYNTHESIS,
        ]

        for i, expected_type in enumerate(expected_types):
            uri, triples = self._event(events, i)
            assert has_type(triples, uri, expected_type), (
                f"Event {i} (uri={uri}) should have type {expected_type}"
            )

    @pytest.mark.asyncio
    async def test_derivation_chain_links_correctly(self):
        """
        Each event's URI should link to the previous via wasDerivedFrom:
            question    → (none)
            grounding   → question
            exploration → grounding
            focus       → exploration
            synthesis   → focus
        """
        _, events = await self._run_query()

        stages = ["question", "grounding", "exploration", "focus", "synthesis"]
        uris = [self._event(events, i)[0] for i, _ in enumerate(stages)]
        all_triples = [t for event in events for t in event["triples"]]

        # question has no parent
        assert derived_from(all_triples, uris[0]) is None, (
            "question should not be derived from anything"
        )

        # every later stage derives from the stage immediately before it
        links = zip(stages[1:], uris[1:], stages, uris)
        for child, child_uri, parent, parent_uri in links:
            assert derived_from(all_triples, child_uri) == parent_uri, (
                f"{child} should be derived from {parent}"
            )

    @pytest.mark.asyncio
    async def test_question_event_carries_query_text(self):
        """The question event should contain the original query string."""
        _, events = await self._run_query()

        q_uri, q_triples = self._event(events, 0)

        t = find_triple(q_triples, TG_QUERY, q_uri)
        assert t is not None
        assert t.o.value == self.QUERY

    @pytest.mark.asyncio
    async def test_grounding_carries_concepts(self):
        """The grounding event should list extracted concepts."""
        _, events = await self._run_query()

        gnd_uri, gnd_triples = self._event(events, 1)

        concepts = find_triples(gnd_triples, TG_CONCEPT, gnd_uri)
        concept_values = {t.o.value for t in concepts}
        assert "quantum computing" in concept_values
        assert "physics" in concept_values

    @pytest.mark.asyncio
    async def test_exploration_has_edge_count(self):
        """The exploration event should report how many edges were found."""
        _, events = await self._run_query()

        exp_uri, exp_triples = self._event(events, 2)

        t = find_triple(exp_triples, TG_EDGE_COUNT, exp_uri)
        assert t is not None
        # Should be non-zero (we provided 3 edges, label edges filtered)
        assert int(t.o.value) > 0

    @pytest.mark.asyncio
    async def test_focus_has_selected_edges_with_reasoning(self):
        """
        The focus event should carry selected edges as quoted triples
        with reasoning text.
        """
        _, events = await self._run_query()

        foc_uri, foc_triples = self._event(events, 3)

        # Should have selected edges
        selected = find_triples(foc_triples, TG_SELECTED_EDGE, foc_uri)
        assert len(selected) > 0, "Focus should have at least one selected edge"

        # Each edge selection should have a quoted triple
        edge_t = find_triples(foc_triples, TG_EDGE)
        assert len(edge_t) > 0, "Focus should have tg:edge with quoted triples"
        for t in edge_t:
            assert t.o.triple is not None, "tg:edge object must be a quoted triple"

        # Should have reasoning
        reasoning = find_triples(foc_triples, TG_REASONING)
        assert len(reasoning) > 0, "Focus should have reasoning for selected edges"
        reasoning_texts = {t.o.value for t in reasoning}
        assert any(reasoning_texts), "Reasoning should not be empty"

    @pytest.mark.asyncio
    async def test_synthesis_is_answer_type(self):
        """The synthesis event should have tg:Answer type."""
        _, events = await self._run_query()

        syn_uri, syn_triples = self._event(events, 4)

        assert has_type(syn_triples, syn_uri, TG_SYNTHESIS)
        assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE)

    @pytest.mark.asyncio
    async def test_query_returns_answer_text(self):
        """query() should still return the synthesised answer."""
        result, _ = await self._run_query()

        assert result == self.ANSWER

    @pytest.mark.asyncio
    async def test_parent_uri_links_question_to_parent(self):
        """When parent_uri is provided, question should derive from it."""
        parent = "urn:trustgraph:agent:iteration:xyz"

        _, events = await self._run_query(parent_uri=parent)

        q_uri, q_triples = self._event(events, 0)
        assert derived_from(q_triples, q_uri) == parent

    @pytest.mark.asyncio
    async def test_no_explain_callback_still_works(self):
        """query() without explain_callback should return answer normally."""
        # Deliberately bypasses _run_query(), which always passes a callback.
        rag = GraphRag(*build_mock_clients())

        result = await rag.query(
            query=self.QUERY,
            edge_score_limit=0,
        )

        assert result == self.ANSWER

    @pytest.mark.asyncio
    async def test_all_triples_in_retrieval_graph(self):
        """All emitted triples should be in the urn:graph:retrieval graph."""
        _, events = await self._run_query()

        for event in events:
            for t in event["triples"]:
                assert t.g == "urn:graph:retrieval", (
                    f"Triple {t.s.iri} {t.p.iri} should be in "
                    f"urn:graph:retrieval, got {t.g}"
                )

View file

@ -28,7 +28,7 @@ class TestRequestTranslation:
}
# Translate to Pulsar
pulsar_msg = translator.to_pulsar(api_data)
pulsar_msg = translator.decode(api_data)
assert pulsar_msg.operation == "schema-selection"
assert pulsar_msg.sample == "test data sample"
@ -46,7 +46,7 @@ class TestRequestTranslation:
"options": {"delimiter": ","}
}
pulsar_msg = translator.to_pulsar(api_data)
pulsar_msg = translator.decode(api_data)
assert pulsar_msg.operation == "generate-descriptor"
assert pulsar_msg.sample == "csv data"
@ -70,7 +70,7 @@ class TestResponseTranslation:
)
# Translate to API format
api_data = translator.from_pulsar(pulsar_response)
api_data = translator.encode(pulsar_response)
assert api_data["operation"] == "schema-selection"
assert api_data["schema-matches"] == ["products", "inventory", "catalog"]
@ -86,7 +86,7 @@ class TestResponseTranslation:
error=None
)
api_data = translator.from_pulsar(pulsar_response)
api_data = translator.encode(pulsar_response)
assert api_data["operation"] == "schema-selection"
assert api_data["schema-matches"] == []
@ -103,7 +103,7 @@ class TestResponseTranslation:
error=None
)
api_data = translator.from_pulsar(pulsar_response)
api_data = translator.encode(pulsar_response)
assert api_data["operation"] == "detect-type"
assert api_data["detected-type"] == "xml"
@ -123,7 +123,7 @@ class TestResponseTranslation:
)
)
api_data = translator.from_pulsar(pulsar_response)
api_data = translator.encode(pulsar_response)
assert api_data["operation"] == "schema-selection"
# Error objects are typically handled separately by the gateway
@ -146,7 +146,7 @@ class TestResponseTranslation:
error=None
)
api_data = translator.from_pulsar(pulsar_response)
api_data = translator.encode(pulsar_response)
assert api_data["operation"] == "diagnose"
assert api_data["detected-type"] == "csv"
@ -165,7 +165,7 @@ class TestResponseTranslation:
error=None
)
api_data, is_final = translator.from_response_with_completion(pulsar_response)
api_data, is_final = translator.encode_with_completion(pulsar_response)
assert is_final is True # Structured-diag responses are always final
assert api_data["operation"] == "schema-selection"