mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 09:29:38 +02:00
Merge remote-tracking branch 'origin/master' into ts-port
This commit is contained in:
commit
f4d6e49217
270 changed files with 19608 additions and 4096 deletions
|
|
@ -87,10 +87,11 @@ def sample_message_data():
|
|||
"history": []
|
||||
},
|
||||
"AgentResponse": {
|
||||
"answer": "Machine learning is a subset of AI.",
|
||||
"chunk_type": "answer",
|
||||
"content": "Machine learning is a subset of AI.",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": True,
|
||||
"error": None,
|
||||
"thought": "I need to provide information about machine learning.",
|
||||
"observation": None
|
||||
},
|
||||
"Metadata": {
|
||||
"id": "test-doc-123",
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
assert request.user == "test_user"
|
||||
assert request.collection == "test_collection"
|
||||
|
||||
def test_request_translator_to_pulsar(self):
|
||||
def test_request_translator_decode(self):
|
||||
"""Test request translator converts dict to Pulsar schema"""
|
||||
translator = DocumentEmbeddingsRequestTranslator()
|
||||
|
||||
|
|
@ -49,7 +49,7 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
"collection": "custom_collection"
|
||||
}
|
||||
|
||||
result = translator.to_pulsar(data)
|
||||
result = translator.decode(data)
|
||||
|
||||
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||
assert result.vector == [0.1, 0.2, 0.3, 0.4]
|
||||
|
|
@ -57,7 +57,7 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
assert result.user == "custom_user"
|
||||
assert result.collection == "custom_collection"
|
||||
|
||||
def test_request_translator_to_pulsar_with_defaults(self):
|
||||
def test_request_translator_decode_with_defaults(self):
|
||||
"""Test request translator uses correct defaults"""
|
||||
translator = DocumentEmbeddingsRequestTranslator()
|
||||
|
||||
|
|
@ -66,7 +66,7 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
# No limit, user, or collection provided
|
||||
}
|
||||
|
||||
result = translator.to_pulsar(data)
|
||||
result = translator.decode(data)
|
||||
|
||||
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||
assert result.vector == [0.1, 0.2]
|
||||
|
|
@ -74,7 +74,7 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
assert result.user == "trustgraph" # Default
|
||||
assert result.collection == "default" # Default
|
||||
|
||||
def test_request_translator_from_pulsar(self):
|
||||
def test_request_translator_encode(self):
|
||||
"""Test request translator converts Pulsar schema to dict"""
|
||||
translator = DocumentEmbeddingsRequestTranslator()
|
||||
|
||||
|
|
@ -85,7 +85,7 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
collection="test_collection"
|
||||
)
|
||||
|
||||
result = translator.from_pulsar(request)
|
||||
result = translator.encode(request)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["vector"] == [0.5, 0.6]
|
||||
|
|
@ -134,7 +134,7 @@ class TestDocumentEmbeddingsResponseContract:
|
|||
assert response.error == error
|
||||
assert response.chunks == []
|
||||
|
||||
def test_response_translator_from_pulsar_with_chunks(self):
|
||||
def test_response_translator_encode_with_chunks(self):
|
||||
"""Test response translator converts Pulsar schema with chunks to dict"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
|
|
@ -147,7 +147,7 @@ class TestDocumentEmbeddingsResponseContract:
|
|||
]
|
||||
)
|
||||
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "chunks" in result
|
||||
|
|
@ -155,7 +155,7 @@ class TestDocumentEmbeddingsResponseContract:
|
|||
assert result["chunks"][0]["chunk_id"] == "doc1/c1"
|
||||
assert result["chunks"][0]["score"] == 0.95
|
||||
|
||||
def test_response_translator_from_pulsar_with_empty_chunks(self):
|
||||
def test_response_translator_encode_with_empty_chunks(self):
|
||||
"""Test response translator handles empty chunks list"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
|
|
@ -164,25 +164,25 @@ class TestDocumentEmbeddingsResponseContract:
|
|||
chunks=[]
|
||||
)
|
||||
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "chunks" in result
|
||||
assert result["chunks"] == []
|
||||
|
||||
def test_response_translator_from_pulsar_with_none_chunks(self):
|
||||
def test_response_translator_encode_with_none_chunks(self):
|
||||
"""Test response translator handles None chunks"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
response = MagicMock()
|
||||
response.chunks = None
|
||||
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "chunks" not in result or result.get("chunks") is None
|
||||
|
||||
def test_response_translator_from_response_with_completion(self):
|
||||
def test_response_translator_encode_with_completion(self):
|
||||
"""Test response translator with completion flag"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
|
|
@ -194,7 +194,7 @@ class TestDocumentEmbeddingsResponseContract:
|
|||
]
|
||||
)
|
||||
|
||||
result, is_final = translator.from_response_with_completion(response)
|
||||
result, is_final = translator.encode_with_completion(response)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "chunks" in result
|
||||
|
|
@ -202,12 +202,12 @@ class TestDocumentEmbeddingsResponseContract:
|
|||
assert result["chunks"][0]["chunk_id"] == "chunk1"
|
||||
assert is_final is True # Document embeddings responses are always final
|
||||
|
||||
def test_response_translator_to_pulsar_not_implemented(self):
|
||||
"""Test that to_pulsar raises NotImplementedError for responses"""
|
||||
def test_response_translator_decode_not_implemented(self):
|
||||
"""Test that decode raises NotImplementedError for responses"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
translator.to_pulsar({"chunks": [{"chunk_id": "test", "score": 0.9}]})
|
||||
translator.decode({"chunks": [{"chunk_id": "test", "score": 0.9}]})
|
||||
|
||||
|
||||
class TestDocumentEmbeddingsMessageCompatibility:
|
||||
|
|
@ -225,7 +225,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
|||
|
||||
# Convert to Pulsar request
|
||||
req_translator = DocumentEmbeddingsRequestTranslator()
|
||||
pulsar_request = req_translator.to_pulsar(request_data)
|
||||
pulsar_request = req_translator.decode(request_data)
|
||||
|
||||
# Simulate service processing and creating response
|
||||
response = DocumentEmbeddingsResponse(
|
||||
|
|
@ -238,7 +238,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
|||
|
||||
# Convert response back to dict
|
||||
resp_translator = DocumentEmbeddingsResponseTranslator()
|
||||
response_data = resp_translator.from_pulsar(response)
|
||||
response_data = resp_translator.encode(response)
|
||||
|
||||
# Verify data integrity
|
||||
assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
|
||||
|
|
@ -261,7 +261,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
|||
|
||||
# Convert response to dict
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
response_data = translator.from_pulsar(response)
|
||||
response_data = translator.encode(response)
|
||||
|
||||
# Verify error handling
|
||||
assert isinstance(response_data, dict)
|
||||
|
|
|
|||
|
|
@ -212,10 +212,11 @@ class TestAgentMessageContracts:
|
|||
|
||||
# Test required fields
|
||||
response = AgentResponse(**response_data)
|
||||
assert hasattr(response, 'answer')
|
||||
assert hasattr(response, 'chunk_type')
|
||||
assert hasattr(response, 'content')
|
||||
assert hasattr(response, 'end_of_message')
|
||||
assert hasattr(response, 'end_of_dialog')
|
||||
assert hasattr(response, 'error')
|
||||
assert hasattr(response, 'thought')
|
||||
assert hasattr(response, 'observation')
|
||||
|
||||
def test_agent_step_schema_contract(self):
|
||||
"""Test AgentStep schema contract"""
|
||||
|
|
|
|||
177
tests/contract/test_orchestrator_contracts.py
Normal file
177
tests/contract/test_orchestrator_contracts.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
"""
|
||||
Contract tests for orchestrator message schemas.
|
||||
|
||||
Verifies that AgentRequest/AgentStep with orchestration fields
|
||||
serialise and deserialise correctly through the Pulsar schema layer.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
|
||||
from trustgraph.schema import AgentRequest, AgentStep, PlanStep
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestOrchestrationFieldContracts:
|
||||
"""Contract tests for orchestration fields on AgentRequest."""
|
||||
|
||||
def test_agent_request_orchestration_fields_roundtrip(self):
|
||||
req = AgentRequest(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
correlation_id="corr-123",
|
||||
parent_session_id="parent-sess",
|
||||
subagent_goal="What is X?",
|
||||
expected_siblings=4,
|
||||
pattern="react",
|
||||
task_type="research",
|
||||
framing="Focus on accuracy",
|
||||
conversation_id="conv-456",
|
||||
)
|
||||
|
||||
assert req.correlation_id == "corr-123"
|
||||
assert req.parent_session_id == "parent-sess"
|
||||
assert req.subagent_goal == "What is X?"
|
||||
assert req.expected_siblings == 4
|
||||
assert req.pattern == "react"
|
||||
assert req.task_type == "research"
|
||||
assert req.framing == "Focus on accuracy"
|
||||
assert req.conversation_id == "conv-456"
|
||||
|
||||
def test_agent_request_orchestration_fields_default_empty(self):
|
||||
req = AgentRequest(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
)
|
||||
|
||||
assert req.correlation_id == ""
|
||||
assert req.parent_session_id == ""
|
||||
assert req.subagent_goal == ""
|
||||
assert req.expected_siblings == 0
|
||||
assert req.pattern == ""
|
||||
assert req.task_type == ""
|
||||
assert req.framing == ""
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestSubagentCompletionStepContract:
|
||||
"""Contract tests for subagent-completion step type."""
|
||||
|
||||
def test_subagent_completion_step_fields(self):
|
||||
step = AgentStep(
|
||||
thought="Subagent completed",
|
||||
action="complete",
|
||||
arguments={},
|
||||
observation="The answer text",
|
||||
step_type="subagent-completion",
|
||||
)
|
||||
|
||||
assert step.step_type == "subagent-completion"
|
||||
assert step.observation == "The answer text"
|
||||
assert step.thought == "Subagent completed"
|
||||
assert step.action == "complete"
|
||||
|
||||
def test_subagent_completion_in_request_history(self):
|
||||
step = AgentStep(
|
||||
thought="Subagent completed",
|
||||
action="complete",
|
||||
arguments={},
|
||||
observation="answer",
|
||||
step_type="subagent-completion",
|
||||
)
|
||||
req = AgentRequest(
|
||||
question="goal",
|
||||
user="testuser",
|
||||
correlation_id="corr-123",
|
||||
history=[step],
|
||||
)
|
||||
|
||||
assert len(req.history) == 1
|
||||
assert req.history[0].step_type == "subagent-completion"
|
||||
assert req.history[0].observation == "answer"
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestSynthesisStepContract:
|
||||
"""Contract tests for synthesis step type with subagent_results."""
|
||||
|
||||
def test_synthesis_step_with_results(self):
|
||||
results = {"goal-a": "answer-a", "goal-b": "answer-b"}
|
||||
step = AgentStep(
|
||||
thought="All subagents completed",
|
||||
action="aggregate",
|
||||
arguments={},
|
||||
observation=json.dumps(results),
|
||||
step_type="synthesise",
|
||||
subagent_results=results,
|
||||
)
|
||||
|
||||
assert step.step_type == "synthesise"
|
||||
assert step.subagent_results == results
|
||||
assert json.loads(step.observation) == results
|
||||
|
||||
def test_synthesis_request_matches_supervisor_expectations(self):
|
||||
"""The synthesis request built by the aggregator must be
|
||||
recognisable by SupervisorPattern._synthesise()."""
|
||||
results = {"goal-a": "answer-a", "goal-b": "answer-b"}
|
||||
step = AgentStep(
|
||||
thought="All subagents completed",
|
||||
action="aggregate",
|
||||
arguments={},
|
||||
observation=json.dumps(results),
|
||||
step_type="synthesise",
|
||||
subagent_results=results,
|
||||
)
|
||||
|
||||
req = AgentRequest(
|
||||
question="Original question",
|
||||
user="testuser",
|
||||
pattern="supervisor",
|
||||
correlation_id="",
|
||||
session_id="parent-sess",
|
||||
history=[step],
|
||||
)
|
||||
|
||||
# SupervisorPattern checks for step_type='synthesise' with
|
||||
# subagent_results
|
||||
has_results = bool(
|
||||
req.history
|
||||
and any(
|
||||
getattr(h, 'step_type', '') == 'synthesise'
|
||||
and getattr(h, 'subagent_results', None)
|
||||
for h in req.history
|
||||
)
|
||||
)
|
||||
assert has_results
|
||||
|
||||
# Pattern must be supervisor
|
||||
assert req.pattern == "supervisor"
|
||||
|
||||
# Correlation ID must be empty (not re-intercepted)
|
||||
assert req.correlation_id == ""
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestPlanStepContract:
|
||||
"""Contract tests for plan steps in history."""
|
||||
|
||||
def test_plan_step_in_history(self):
|
||||
plan = [
|
||||
PlanStep(goal="Step 1", tool_hint="knowledge-query",
|
||||
depends_on=[], status="completed", result="done"),
|
||||
PlanStep(goal="Step 2", tool_hint="",
|
||||
depends_on=[0], status="pending", result=""),
|
||||
]
|
||||
step = AgentStep(
|
||||
thought="Created plan",
|
||||
action="plan",
|
||||
step_type="plan",
|
||||
plan=plan,
|
||||
)
|
||||
|
||||
assert step.step_type == "plan"
|
||||
assert len(step.plan) == 2
|
||||
assert step.plan[0].goal == "Step 1"
|
||||
assert step.plan[0].status == "completed"
|
||||
assert step.plan[1].depends_on == [0]
|
||||
129
tests/contract/test_provenance_wire_format.py
Normal file
129
tests/contract/test_provenance_wire_format.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""
|
||||
Contract tests for provenance triple wire format — verifies that triples
|
||||
built by the provenance library can be parsed by the explainability API
|
||||
through the wire format conversion.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
|
||||
from trustgraph.provenance import (
|
||||
agent_decomposition_triples,
|
||||
agent_finding_triples,
|
||||
agent_plan_triples,
|
||||
agent_step_result_triples,
|
||||
agent_synthesis_triples,
|
||||
)
|
||||
|
||||
from trustgraph.api.explainability import (
|
||||
ExplainEntity,
|
||||
Decomposition,
|
||||
Finding,
|
||||
Plan,
|
||||
StepResult,
|
||||
Synthesis,
|
||||
wire_triples_to_tuples,
|
||||
)
|
||||
|
||||
|
||||
def _triples_to_wire(triples):
|
||||
"""Convert provenance Triple objects to the wire format dicts
|
||||
that the gateway/socket client would produce."""
|
||||
wire = []
|
||||
for t in triples:
|
||||
entry = {
|
||||
"s": _term_to_wire(t.s),
|
||||
"p": _term_to_wire(t.p),
|
||||
"o": _term_to_wire(t.o),
|
||||
}
|
||||
wire.append(entry)
|
||||
return wire
|
||||
|
||||
|
||||
def _term_to_wire(term):
|
||||
"""Convert a Term to wire format dict."""
|
||||
if term.type == IRI:
|
||||
return {"t": "i", "i": term.iri}
|
||||
elif term.type == LITERAL:
|
||||
return {"t": "l", "v": term.value}
|
||||
return {"t": "l", "v": str(term)}
|
||||
|
||||
|
||||
def _roundtrip(triples, uri):
|
||||
"""Convert triples through wire format and parse via from_triples."""
|
||||
wire = _triples_to_wire(triples)
|
||||
tuples = wire_triples_to_tuples(wire)
|
||||
return ExplainEntity.from_triples(uri, tuples)
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestDecompositionWireFormat:
|
||||
|
||||
def test_roundtrip(self):
|
||||
triples = agent_decomposition_triples(
|
||||
"urn:decompose", "urn:session",
|
||||
["What is X?", "What is Y?"],
|
||||
)
|
||||
entity = _roundtrip(triples, "urn:decompose")
|
||||
|
||||
assert isinstance(entity, Decomposition)
|
||||
assert set(entity.goals) == {"What is X?", "What is Y?"}
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestFindingWireFormat:
|
||||
|
||||
def test_roundtrip(self):
|
||||
triples = agent_finding_triples(
|
||||
"urn:finding", "urn:decompose", "What is X?",
|
||||
document_id="urn:doc/finding",
|
||||
)
|
||||
entity = _roundtrip(triples, "urn:finding")
|
||||
|
||||
assert isinstance(entity, Finding)
|
||||
assert entity.goal == "What is X?"
|
||||
assert entity.document == "urn:doc/finding"
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestPlanWireFormat:
|
||||
|
||||
def test_roundtrip(self):
|
||||
triples = agent_plan_triples(
|
||||
"urn:plan", "urn:session",
|
||||
["Step 1", "Step 2", "Step 3"],
|
||||
)
|
||||
entity = _roundtrip(triples, "urn:plan")
|
||||
|
||||
assert isinstance(entity, Plan)
|
||||
assert set(entity.steps) == {"Step 1", "Step 2", "Step 3"}
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestStepResultWireFormat:
|
||||
|
||||
def test_roundtrip(self):
|
||||
triples = agent_step_result_triples(
|
||||
"urn:step", "urn:plan", "Define X",
|
||||
document_id="urn:doc/step",
|
||||
)
|
||||
entity = _roundtrip(triples, "urn:step")
|
||||
|
||||
assert isinstance(entity, StepResult)
|
||||
assert entity.step == "Define X"
|
||||
assert entity.document == "urn:doc/step"
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestSynthesisWireFormat:
|
||||
|
||||
def test_roundtrip(self):
|
||||
triples = agent_synthesis_triples(
|
||||
"urn:synthesis", "urn:previous",
|
||||
document_id="urn:doc/synthesis",
|
||||
)
|
||||
entity = _roundtrip(triples, "urn:synthesis")
|
||||
|
||||
assert isinstance(entity, Synthesis)
|
||||
assert entity.document == "urn:doc/synthesis"
|
||||
|
|
@ -33,7 +33,7 @@ class TestRAGTranslatorCompletionFlags:
|
|||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True, "is_final must be True when end_of_session=True"
|
||||
|
|
@ -57,7 +57,7 @@ class TestRAGTranslatorCompletionFlags:
|
|||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "is_final must be False when end_of_session=False"
|
||||
|
|
@ -80,7 +80,7 @@ class TestRAGTranslatorCompletionFlags:
|
|||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False
|
||||
|
|
@ -103,7 +103,7 @@ class TestRAGTranslatorCompletionFlags:
|
|||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
|
||||
|
|
@ -125,7 +125,7 @@ class TestRAGTranslatorCompletionFlags:
|
|||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True, "is_final must be True when end_of_session=True"
|
||||
|
|
@ -147,7 +147,7 @@ class TestRAGTranslatorCompletionFlags:
|
|||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
|
||||
|
|
@ -168,7 +168,7 @@ class TestRAGTranslatorCompletionFlags:
|
|||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "is_final must be False when end_of_stream=False"
|
||||
|
|
@ -188,20 +188,18 @@ class TestAgentTranslatorCompletionFlags:
|
|||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("agent")
|
||||
response = AgentResponse(
|
||||
answer="4",
|
||||
error=None,
|
||||
thought=None,
|
||||
observation=None,
|
||||
chunk_type="answer",
|
||||
content="4",
|
||||
end_of_message=True,
|
||||
end_of_dialog=True
|
||||
end_of_dialog=True,
|
||||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True, "is_final must be True when end_of_dialog=True"
|
||||
assert response_dict["answer"] == "4"
|
||||
assert response_dict["content"] == "4"
|
||||
assert response_dict["end_of_dialog"] is True
|
||||
|
||||
def test_agent_translator_is_final_with_end_of_dialog_false(self):
|
||||
|
|
@ -212,44 +210,20 @@ class TestAgentTranslatorCompletionFlags:
|
|||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("agent")
|
||||
response = AgentResponse(
|
||||
answer=None,
|
||||
error=None,
|
||||
thought="I need to solve this.",
|
||||
observation=None,
|
||||
chunk_type="thought",
|
||||
content="I need to solve this.",
|
||||
end_of_message=True,
|
||||
end_of_dialog=False
|
||||
end_of_dialog=False,
|
||||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "is_final must be False when end_of_dialog=False"
|
||||
assert response_dict["thought"] == "I need to solve this."
|
||||
assert response_dict["content"] == "I need to solve this."
|
||||
assert response_dict["end_of_dialog"] is False
|
||||
|
||||
def test_agent_translator_is_final_fallback_with_answer(self):
|
||||
"""
|
||||
Test that AgentResponseTranslator returns is_final=True
|
||||
when answer is present (fallback for legacy responses).
|
||||
"""
|
||||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("agent")
|
||||
# Legacy response without end_of_dialog flag
|
||||
response = AgentResponse(
|
||||
answer="4",
|
||||
error=None,
|
||||
thought=None,
|
||||
observation=None
|
||||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True, "is_final must be True when answer is present (legacy fallback)"
|
||||
assert response_dict["answer"] == "4"
|
||||
|
||||
def test_agent_translator_intermediate_message_is_not_final(self):
|
||||
"""
|
||||
Test that intermediate messages (thought/observation) return is_final=False.
|
||||
|
|
@ -259,32 +233,28 @@ class TestAgentTranslatorCompletionFlags:
|
|||
|
||||
# Test thought message
|
||||
thought_response = AgentResponse(
|
||||
answer=None,
|
||||
error=None,
|
||||
thought="Processing...",
|
||||
observation=None,
|
||||
chunk_type="thought",
|
||||
content="Processing...",
|
||||
end_of_message=True,
|
||||
end_of_dialog=False
|
||||
end_of_dialog=False,
|
||||
)
|
||||
|
||||
# Act
|
||||
thought_dict, thought_is_final = translator.from_response_with_completion(thought_response)
|
||||
thought_dict, thought_is_final = translator.encode_with_completion(thought_response)
|
||||
|
||||
# Assert
|
||||
assert thought_is_final is False, "Thought message must not be final"
|
||||
|
||||
# Test observation message
|
||||
observation_response = AgentResponse(
|
||||
answer=None,
|
||||
error=None,
|
||||
thought=None,
|
||||
observation="Result found",
|
||||
chunk_type="observation",
|
||||
content="Result found",
|
||||
end_of_message=True,
|
||||
end_of_dialog=False
|
||||
end_of_dialog=False,
|
||||
)
|
||||
|
||||
# Act
|
||||
obs_dict, obs_is_final = translator.from_response_with_completion(observation_response)
|
||||
obs_dict, obs_is_final = translator.encode_with_completion(observation_response)
|
||||
|
||||
# Assert
|
||||
assert obs_is_final is False, "Observation message must not be final"
|
||||
|
|
@ -302,14 +272,10 @@ class TestAgentTranslatorCompletionFlags:
|
|||
content="",
|
||||
end_of_message=True,
|
||||
end_of_dialog=True,
|
||||
answer=None,
|
||||
error=None,
|
||||
thought=None,
|
||||
observation=None
|
||||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True, "Streaming format must use end_of_dialog for is_final"
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ Following the TEST_STRATEGY.md approach for integration testing.
|
|||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, ANY, patch
|
||||
|
||||
from trustgraph.agent.react.agent_manager import AgentManager
|
||||
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
|
||||
|
|
@ -187,7 +187,7 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
|
|||
|
||||
# Verify tool was executed
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="default")
|
||||
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="default", explain_callback=ANY, parent_uri=ANY)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
|
||||
|
|
@ -272,7 +272,7 @@ Args: {{
|
|||
|
||||
# Verify correct service was called
|
||||
if tool_name == "knowledge_query":
|
||||
mock_flow_context("graph-rag-request").rag.assert_called_with("test question", collection="default")
|
||||
mock_flow_context("graph-rag-request").rag.assert_called_with("test question", collection="default", explain_callback=ANY, parent_uri=ANY)
|
||||
elif tool_name == "text_completion":
|
||||
mock_flow_context("prompt-request").question.assert_called()
|
||||
|
||||
|
|
@ -726,7 +726,7 @@ Final Answer: {
|
|||
|
||||
# Assert
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("What is AI?", collection="default")
|
||||
graph_rag_client.rag.assert_called_once_with("What is AI?", collection="default", explain_callback=ANY, parent_uri=ANY)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_query_with_custom_collection(self, mock_flow_context):
|
||||
|
|
@ -739,7 +739,7 @@ Final Answer: {
|
|||
|
||||
# Assert
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="custom_collection")
|
||||
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="custom_collection", explain_callback=ANY, parent_uri=ANY)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_query_with_none_collection(self, mock_flow_context):
|
||||
|
|
@ -752,7 +752,7 @@ Final Answer: {
|
|||
|
||||
# Assert
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("Explain neural networks", collection="default")
|
||||
graph_rag_client.rag.assert_called_once_with("Explain neural networks", collection="default", explain_callback=ANY, parent_uri=ANY)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_knowledge_query_collection_integration(self, mock_flow_context):
|
||||
|
|
@ -810,7 +810,7 @@ Args: {
|
|||
|
||||
# Verify the custom collection was used
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("Latest AI research?", collection="research_papers")
|
||||
graph_rag_client.rag.assert_called_once_with("Latest AI research?", collection="research_papers", explain_callback=ANY, parent_uri=ANY)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_query_multiple_collections(self, mock_flow_context):
|
||||
|
|
@ -840,4 +840,4 @@ Args: {
|
|||
|
||||
# Verify correct collection was used
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with(question, collection=expected_collection)
|
||||
graph_rag_client.rag.assert_called_once_with(question, collection=expected_collection, explain_callback=ANY, parent_uri=ANY)
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class TestAgentServiceNonStreaming:
|
|||
mock_agent_manager_class.return_value = mock_agent_instance
|
||||
|
||||
# Mock react to call think and observe callbacks
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming):
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
|
||||
await think("I need to solve this.", is_final=True)
|
||||
await observe("The answer is 4.", is_final=True)
|
||||
return Final(thought="Final answer", final="4")
|
||||
|
|
@ -76,22 +76,33 @@ class TestAgentServiceNonStreaming:
|
|||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: should have 3 responses (thought, observation, answer)
|
||||
assert len(sent_responses) == 3, f"Expected 3 responses, got {len(sent_responses)}"
|
||||
# Filter out explain events — those are always sent now
|
||||
content_responses = [
|
||||
r for r in sent_responses if r.chunk_type != "explain"
|
||||
]
|
||||
explain_responses = [
|
||||
r for r in sent_responses if r.chunk_type == "explain"
|
||||
]
|
||||
|
||||
# Should have explain events for session, iteration, observation, and final
|
||||
assert len(explain_responses) >= 1, "Expected at least 1 explain event"
|
||||
|
||||
# Should have 3 content responses (thought, observation, answer)
|
||||
assert len(content_responses) == 3, f"Expected 3 content responses, got {len(content_responses)}"
|
||||
|
||||
# Check thought message
|
||||
thought_response = sent_responses[0]
|
||||
thought_response = content_responses[0]
|
||||
assert isinstance(thought_response, AgentResponse)
|
||||
assert thought_response.thought == "I need to solve this."
|
||||
assert thought_response.answer is None
|
||||
assert thought_response.chunk_type == "thought"
|
||||
assert thought_response.content == "I need to solve this."
|
||||
assert thought_response.end_of_message is True, "Thought message must have end_of_message=True"
|
||||
assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False"
|
||||
|
||||
# Check observation message
|
||||
observation_response = sent_responses[1]
|
||||
observation_response = content_responses[1]
|
||||
assert isinstance(observation_response, AgentResponse)
|
||||
assert observation_response.observation == "The answer is 4."
|
||||
assert observation_response.answer is None
|
||||
assert observation_response.chunk_type == "observation"
|
||||
assert observation_response.content == "The answer is 4."
|
||||
assert observation_response.end_of_message is True, "Observation message must have end_of_message=True"
|
||||
assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False"
|
||||
|
||||
|
|
@ -120,7 +131,7 @@ class TestAgentServiceNonStreaming:
|
|||
mock_agent_manager_class.return_value = mock_agent_instance
|
||||
|
||||
# Mock react to return Final directly
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming):
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
|
||||
return Final(thought="Final answer", final="4")
|
||||
|
||||
mock_agent_instance.react = mock_react
|
||||
|
|
@ -155,15 +166,25 @@ class TestAgentServiceNonStreaming:
|
|||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: should have 1 response (final answer)
|
||||
assert len(sent_responses) == 1, f"Expected 1 response, got {len(sent_responses)}"
|
||||
# Filter out explain events — those are always sent now
|
||||
content_responses = [
|
||||
r for r in sent_responses if r.chunk_type != "explain"
|
||||
]
|
||||
explain_responses = [
|
||||
r for r in sent_responses if r.chunk_type == "explain"
|
||||
]
|
||||
|
||||
# Should have explain events for session and final
|
||||
assert len(explain_responses) >= 1, "Expected at least 1 explain event"
|
||||
|
||||
# Should have 1 content response (final answer)
|
||||
assert len(content_responses) == 1, f"Expected 1 content response, got {len(content_responses)}"
|
||||
|
||||
# Check final answer message
|
||||
answer_response = sent_responses[0]
|
||||
answer_response = content_responses[0]
|
||||
assert isinstance(answer_response, AgentResponse)
|
||||
assert answer_response.answer == "4"
|
||||
assert answer_response.thought is None
|
||||
assert answer_response.observation is None
|
||||
assert answer_response.chunk_type == "answer"
|
||||
assert answer_response.content == "4"
|
||||
assert answer_response.end_of_message is True, "Final answer must have end_of_message=True"
|
||||
assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True"
|
||||
|
||||
|
|
|
|||
216
tests/unit/test_agent/test_aggregator.py
Normal file
216
tests/unit/test_agent/test_aggregator.py
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
"""
|
||||
Unit tests for the Aggregator — tracks fan-out correlations and triggers
|
||||
synthesis when all subagents complete.
|
||||
"""
|
||||
|
||||
import time
|
||||
import pytest
|
||||
|
||||
from trustgraph.schema import AgentRequest, AgentStep
|
||||
|
||||
from trustgraph.agent.orchestrator.aggregator import Aggregator
|
||||
|
||||
|
||||
def _make_request(question="Test question", user="testuser",
|
||||
collection="default", streaming=False,
|
||||
session_id="parent-session", task_type="research",
|
||||
framing="test framing", conversation_id="conv-1"):
|
||||
return AgentRequest(
|
||||
question=question,
|
||||
user=user,
|
||||
collection=collection,
|
||||
streaming=streaming,
|
||||
session_id=session_id,
|
||||
task_type=task_type,
|
||||
framing=framing,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
|
||||
class TestRegisterFanout:
|
||||
|
||||
def test_stores_correlation_entry(self):
|
||||
agg = Aggregator()
|
||||
agg.register_fanout("corr-1", "parent-1", 3)
|
||||
|
||||
assert "corr-1" in agg.correlations
|
||||
entry = agg.correlations["corr-1"]
|
||||
assert entry["parent_session_id"] == "parent-1"
|
||||
assert entry["expected"] == 3
|
||||
assert entry["results"] == {}
|
||||
|
||||
def test_stores_request_template(self):
|
||||
agg = Aggregator()
|
||||
template = _make_request()
|
||||
agg.register_fanout("corr-1", "parent-1", 2,
|
||||
request_template=template)
|
||||
|
||||
entry = agg.correlations["corr-1"]
|
||||
assert entry["request_template"] is template
|
||||
|
||||
def test_records_creation_time(self):
|
||||
agg = Aggregator()
|
||||
before = time.time()
|
||||
agg.register_fanout("corr-1", "parent-1", 2)
|
||||
after = time.time()
|
||||
|
||||
created = agg.correlations["corr-1"]["created_at"]
|
||||
assert before <= created <= after
|
||||
|
||||
|
||||
class TestRecordCompletion:
|
||||
|
||||
def test_returns_false_until_all_done(self):
|
||||
agg = Aggregator()
|
||||
agg.register_fanout("corr-1", "parent-1", 3)
|
||||
|
||||
assert agg.record_completion("corr-1", "goal-a", "answer-a") is False
|
||||
assert agg.record_completion("corr-1", "goal-b", "answer-b") is False
|
||||
assert agg.record_completion("corr-1", "goal-c", "answer-c") is True
|
||||
|
||||
def test_returns_none_for_unknown_correlation(self):
|
||||
agg = Aggregator()
|
||||
result = agg.record_completion("unknown", "goal", "answer")
|
||||
assert result is None
|
||||
|
||||
def test_stores_results_by_goal(self):
|
||||
agg = Aggregator()
|
||||
agg.register_fanout("corr-1", "parent-1", 2)
|
||||
|
||||
agg.record_completion("corr-1", "goal-a", "answer-a")
|
||||
agg.record_completion("corr-1", "goal-b", "answer-b")
|
||||
|
||||
results = agg.correlations["corr-1"]["results"]
|
||||
assert results["goal-a"] == "answer-a"
|
||||
assert results["goal-b"] == "answer-b"
|
||||
|
||||
def test_single_subagent(self):
|
||||
agg = Aggregator()
|
||||
agg.register_fanout("corr-1", "parent-1", 1)
|
||||
|
||||
assert agg.record_completion("corr-1", "goal-a", "answer") is True
|
||||
|
||||
|
||||
class TestGetOriginalRequest:
|
||||
|
||||
def test_peeks_without_consuming(self):
|
||||
agg = Aggregator()
|
||||
template = _make_request()
|
||||
agg.register_fanout("corr-1", "parent-1", 2,
|
||||
request_template=template)
|
||||
|
||||
result = agg.get_original_request("corr-1")
|
||||
assert result is template
|
||||
# Entry still exists
|
||||
assert "corr-1" in agg.correlations
|
||||
|
||||
def test_returns_none_for_unknown(self):
|
||||
agg = Aggregator()
|
||||
assert agg.get_original_request("unknown") is None
|
||||
|
||||
|
||||
class TestBuildSynthesisRequest:
|
||||
|
||||
def test_builds_correct_request(self):
|
||||
agg = Aggregator()
|
||||
template = _make_request(
|
||||
question="Original question",
|
||||
streaming=True,
|
||||
task_type="risk-assessment",
|
||||
framing="Assess risks",
|
||||
)
|
||||
agg.register_fanout("corr-1", "parent-1", 2,
|
||||
request_template=template)
|
||||
agg.record_completion("corr-1", "goal-a", "answer-a")
|
||||
agg.record_completion("corr-1", "goal-b", "answer-b")
|
||||
|
||||
req = agg.build_synthesis_request(
|
||||
"corr-1",
|
||||
original_question="Original question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
|
||||
assert req.question == "Original question"
|
||||
assert req.pattern == "supervisor"
|
||||
assert req.session_id == "parent-1"
|
||||
assert req.correlation_id == "" # Must be empty
|
||||
assert req.streaming == True
|
||||
assert req.task_type == "risk-assessment"
|
||||
assert req.framing == "Assess risks"
|
||||
|
||||
def test_synthesis_step_in_history(self):
|
||||
agg = Aggregator()
|
||||
template = _make_request()
|
||||
agg.register_fanout("corr-1", "parent-1", 2,
|
||||
request_template=template)
|
||||
agg.record_completion("corr-1", "goal-a", "answer-a")
|
||||
agg.record_completion("corr-1", "goal-b", "answer-b")
|
||||
|
||||
req = agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
)
|
||||
|
||||
# Last history step should be the synthesis step
|
||||
assert len(req.history) >= 1
|
||||
synth_step = req.history[-1]
|
||||
assert synth_step.step_type == "synthesise"
|
||||
assert synth_step.subagent_results == {
|
||||
"goal-a": "answer-a",
|
||||
"goal-b": "answer-b",
|
||||
}
|
||||
|
||||
def test_consumes_correlation_entry(self):
|
||||
agg = Aggregator()
|
||||
template = _make_request()
|
||||
agg.register_fanout("corr-1", "parent-1", 1,
|
||||
request_template=template)
|
||||
agg.record_completion("corr-1", "goal-a", "answer-a")
|
||||
|
||||
agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
)
|
||||
|
||||
# Entry should be removed
|
||||
assert "corr-1" not in agg.correlations
|
||||
|
||||
def test_raises_for_unknown_correlation(self):
|
||||
agg = Aggregator()
|
||||
with pytest.raises(RuntimeError, match="No results"):
|
||||
agg.build_synthesis_request(
|
||||
"unknown", "question", "user", "default",
|
||||
)
|
||||
|
||||
|
||||
class TestCleanupStale:
|
||||
|
||||
def test_removes_entries_older_than_timeout(self):
|
||||
agg = Aggregator(timeout=1)
|
||||
agg.register_fanout("corr-1", "parent-1", 2)
|
||||
|
||||
# Backdate the creation time
|
||||
agg.correlations["corr-1"]["created_at"] = time.time() - 2
|
||||
|
||||
stale = agg.cleanup_stale()
|
||||
assert "corr-1" in stale
|
||||
assert "corr-1" not in agg.correlations
|
||||
|
||||
def test_keeps_recent_entries(self):
|
||||
agg = Aggregator(timeout=300)
|
||||
agg.register_fanout("corr-1", "parent-1", 2)
|
||||
|
||||
stale = agg.cleanup_stale()
|
||||
assert stale == []
|
||||
assert "corr-1" in agg.correlations
|
||||
|
||||
def test_mixed_stale_and_fresh(self):
|
||||
agg = Aggregator(timeout=1)
|
||||
agg.register_fanout("stale", "parent-1", 2)
|
||||
agg.register_fanout("fresh", "parent-2", 2)
|
||||
|
||||
agg.correlations["stale"]["created_at"] = time.time() - 2
|
||||
|
||||
stale = agg.cleanup_stale()
|
||||
assert "stale" in stale
|
||||
assert "stale" not in agg.correlations
|
||||
assert "fresh" in agg.correlations
|
||||
122
tests/unit/test_agent/test_callback_message_id.py
Normal file
122
tests/unit/test_agent/test_callback_message_id.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""
|
||||
Tests that streaming callbacks set message_id on AgentResponse.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.agent.orchestrator.pattern_base import PatternBase
|
||||
from trustgraph.schema import AgentResponse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pattern():
|
||||
processor = MagicMock()
|
||||
return PatternBase(processor)
|
||||
|
||||
|
||||
class TestThinkCallbackMessageId:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_think_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/i1/thought"
|
||||
think = pattern.make_think_callback(capture, streaming=True, message_id=msg_id)
|
||||
await think("hello", is_final=False)
|
||||
|
||||
assert len(responses) == 1
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "thought"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_think_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/i1/thought"
|
||||
think = pattern.make_think_callback(capture, streaming=False, message_id=msg_id)
|
||||
await think("hello")
|
||||
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].end_of_message is True
|
||||
|
||||
|
||||
class TestObserveCallbackMessageId:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_observe_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/i1/observation"
|
||||
observe = pattern.make_observe_callback(capture, streaming=True, message_id=msg_id)
|
||||
await observe("result", is_final=True)
|
||||
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "observation"
|
||||
|
||||
|
||||
class TestAnswerCallbackMessageId:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_answer_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/final"
|
||||
answer = pattern.make_answer_callback(capture, streaming=True, message_id=msg_id)
|
||||
await answer("the answer")
|
||||
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "answer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_message_id_default(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
answer = pattern.make_answer_callback(capture, streaming=True)
|
||||
await answer("the answer")
|
||||
|
||||
assert responses[0].message_id == ""
|
||||
|
||||
|
||||
class TestSendFinalResponseMessageId:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_final_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/final"
|
||||
await pattern.send_final_response(
|
||||
capture, streaming=True, answer_text="answer",
|
||||
message_id=msg_id,
|
||||
)
|
||||
|
||||
# Should get content chunk + end-of-dialog marker
|
||||
assert all(r.message_id == msg_id for r in responses)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_final_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/final"
|
||||
await pattern.send_final_response(
|
||||
capture, streaming=False, answer_text="answer",
|
||||
message_id=msg_id,
|
||||
)
|
||||
|
||||
assert len(responses) == 1
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].end_of_dialog is True
|
||||
174
tests/unit/test_agent/test_completion_dispatch.py
Normal file
174
tests/unit/test_agent/test_completion_dispatch.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
"""
|
||||
Unit tests for completion dispatch — verifies that agent_request() in the
|
||||
orchestrator service correctly intercepts subagent completion messages and
|
||||
routes them to _handle_subagent_completion.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.schema import AgentRequest, AgentStep
|
||||
|
||||
from trustgraph.agent.orchestrator.aggregator import Aggregator
|
||||
|
||||
|
||||
def _make_request(**kwargs):
|
||||
defaults = dict(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return AgentRequest(**defaults)
|
||||
|
||||
|
||||
def _make_completion_request(correlation_id, goal, answer):
|
||||
"""Build a completion request as emit_subagent_completion would."""
|
||||
step = AgentStep(
|
||||
thought="Subagent completed",
|
||||
action="complete",
|
||||
arguments={},
|
||||
observation=answer,
|
||||
step_type="subagent-completion",
|
||||
)
|
||||
return _make_request(
|
||||
correlation_id=correlation_id,
|
||||
parent_session_id="parent-sess",
|
||||
subagent_goal=goal,
|
||||
expected_siblings=2,
|
||||
history=[step],
|
||||
)
|
||||
|
||||
|
||||
class TestCompletionDetection:
|
||||
"""Test that completion messages are correctly identified."""
|
||||
|
||||
def test_is_completion_when_correlation_id_and_step_type(self):
|
||||
req = _make_completion_request("corr-1", "goal-a", "answer-a")
|
||||
|
||||
has_correlation = bool(getattr(req, 'correlation_id', ''))
|
||||
is_completion = any(
|
||||
getattr(h, 'step_type', '') == 'subagent-completion'
|
||||
for h in req.history
|
||||
)
|
||||
|
||||
assert has_correlation
|
||||
assert is_completion
|
||||
|
||||
def test_not_completion_without_correlation_id(self):
|
||||
step = AgentStep(
|
||||
step_type="subagent-completion",
|
||||
observation="answer",
|
||||
)
|
||||
req = _make_request(
|
||||
correlation_id="",
|
||||
history=[step],
|
||||
)
|
||||
|
||||
has_correlation = bool(getattr(req, 'correlation_id', ''))
|
||||
assert not has_correlation
|
||||
|
||||
def test_not_completion_without_step_type(self):
|
||||
step = AgentStep(
|
||||
step_type="react",
|
||||
observation="answer",
|
||||
)
|
||||
req = _make_request(
|
||||
correlation_id="corr-1",
|
||||
history=[step],
|
||||
)
|
||||
|
||||
is_completion = any(
|
||||
getattr(h, 'step_type', '') == 'subagent-completion'
|
||||
for h in req.history
|
||||
)
|
||||
assert not is_completion
|
||||
|
||||
def test_not_completion_with_empty_history(self):
|
||||
req = _make_request(
|
||||
correlation_id="corr-1",
|
||||
history=[],
|
||||
)
|
||||
assert not req.history
|
||||
|
||||
|
||||
class TestAggregatorIntegration:
|
||||
"""Test the aggregator flow as used by _handle_subagent_completion."""
|
||||
|
||||
def test_full_completion_flow(self):
|
||||
"""Simulates the flow: register, record completions, build synthesis."""
|
||||
agg = Aggregator()
|
||||
template = _make_request(
|
||||
question="Original question",
|
||||
streaming=True,
|
||||
task_type="risk-assessment",
|
||||
framing="Assess risks",
|
||||
session_id="parent-sess",
|
||||
)
|
||||
|
||||
# Register fan-out
|
||||
agg.register_fanout("corr-1", "parent-sess", 2,
|
||||
request_template=template)
|
||||
|
||||
# First completion — not all done
|
||||
all_done = agg.record_completion(
|
||||
"corr-1", "goal-a", "answer-a",
|
||||
)
|
||||
assert all_done is False
|
||||
|
||||
# Second completion — all done
|
||||
all_done = agg.record_completion(
|
||||
"corr-1", "goal-b", "answer-b",
|
||||
)
|
||||
assert all_done is True
|
||||
|
||||
# Peek at template
|
||||
peeked = agg.get_original_request("corr-1")
|
||||
assert peeked.question == "Original question"
|
||||
|
||||
# Build synthesis request
|
||||
synth = agg.build_synthesis_request(
|
||||
"corr-1",
|
||||
original_question="Original question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
|
||||
# Verify synthesis request
|
||||
assert synth.pattern == "supervisor"
|
||||
assert synth.correlation_id == ""
|
||||
assert synth.session_id == "parent-sess"
|
||||
assert synth.streaming is True
|
||||
|
||||
# Verify synthesis history has results
|
||||
synth_steps = [
|
||||
s for s in synth.history
|
||||
if getattr(s, 'step_type', '') == 'synthesise'
|
||||
]
|
||||
assert len(synth_steps) == 1
|
||||
assert synth_steps[0].subagent_results == {
|
||||
"goal-a": "answer-a",
|
||||
"goal-b": "answer-b",
|
||||
}
|
||||
|
||||
def test_synthesis_request_not_detected_as_completion(self):
|
||||
"""The synthesis request must not be intercepted as a completion."""
|
||||
agg = Aggregator()
|
||||
template = _make_request(session_id="parent-sess")
|
||||
agg.register_fanout("corr-1", "parent-sess", 1,
|
||||
request_template=template)
|
||||
agg.record_completion("corr-1", "goal", "answer")
|
||||
|
||||
synth = agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
)
|
||||
|
||||
# correlation_id must be empty so it's not intercepted
|
||||
assert synth.correlation_id == ""
|
||||
|
||||
# Even if we check for completion step, shouldn't match
|
||||
is_completion = any(
|
||||
getattr(h, 'step_type', '') == 'subagent-completion'
|
||||
for h in synth.history
|
||||
)
|
||||
assert not is_completion
|
||||
177
tests/unit/test_agent/test_explainability_parsing.py
Normal file
177
tests/unit/test_agent/test_explainability_parsing.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
"""
|
||||
Unit tests for explainability API parsing — verifies that from_triples()
|
||||
correctly dispatches and parses the new orchestrator entity types.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.api.explainability import (
|
||||
ExplainEntity,
|
||||
Decomposition,
|
||||
Finding,
|
||||
Plan,
|
||||
StepResult,
|
||||
Synthesis,
|
||||
Analysis,
|
||||
Observation,
|
||||
Conclusion,
|
||||
TG_DECOMPOSITION,
|
||||
TG_FINDING,
|
||||
TG_PLAN_TYPE,
|
||||
TG_STEP_RESULT,
|
||||
TG_SYNTHESIS,
|
||||
TG_ANSWER_TYPE,
|
||||
TG_OBSERVATION_TYPE,
|
||||
TG_TOOL_USE,
|
||||
TG_ANALYSIS,
|
||||
TG_CONCLUSION,
|
||||
TG_DOCUMENT,
|
||||
TG_SUBAGENT_GOAL,
|
||||
TG_PLAN_STEP,
|
||||
RDF_TYPE,
|
||||
)
|
||||
|
||||
PROV_ENTITY = "http://www.w3.org/ns/prov#Entity"
|
||||
|
||||
|
||||
def _make_triples(uri, types, extras=None):
|
||||
"""Build a list of (s, p, o) tuples for testing."""
|
||||
triples = [(uri, RDF_TYPE, t) for t in types]
|
||||
if extras:
|
||||
triples.extend((uri, p, o) for p, o in extras)
|
||||
return triples
|
||||
|
||||
|
||||
class TestFromTriplesDispatch:
|
||||
|
||||
def test_dispatches_decomposition(self):
|
||||
triples = _make_triples("urn:d", [PROV_ENTITY, TG_DECOMPOSITION])
|
||||
entity = ExplainEntity.from_triples("urn:d", triples)
|
||||
assert isinstance(entity, Decomposition)
|
||||
|
||||
def test_dispatches_finding(self):
|
||||
triples = _make_triples("urn:f",
|
||||
[PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE])
|
||||
entity = ExplainEntity.from_triples("urn:f", triples)
|
||||
assert isinstance(entity, Finding)
|
||||
|
||||
def test_dispatches_plan(self):
|
||||
triples = _make_triples("urn:p", [PROV_ENTITY, TG_PLAN_TYPE])
|
||||
entity = ExplainEntity.from_triples("urn:p", triples)
|
||||
assert isinstance(entity, Plan)
|
||||
|
||||
def test_dispatches_step_result(self):
|
||||
triples = _make_triples("urn:sr",
|
||||
[PROV_ENTITY, TG_STEP_RESULT, TG_ANSWER_TYPE])
|
||||
entity = ExplainEntity.from_triples("urn:sr", triples)
|
||||
assert isinstance(entity, StepResult)
|
||||
|
||||
def test_dispatches_synthesis(self):
|
||||
triples = _make_triples("urn:s",
|
||||
[PROV_ENTITY, TG_SYNTHESIS, TG_ANSWER_TYPE])
|
||||
entity = ExplainEntity.from_triples("urn:s", triples)
|
||||
assert isinstance(entity, Synthesis)
|
||||
|
||||
def test_dispatches_analysis_unchanged(self):
|
||||
triples = _make_triples("urn:a", [PROV_ENTITY, TG_ANALYSIS])
|
||||
entity = ExplainEntity.from_triples("urn:a", triples)
|
||||
assert isinstance(entity, Analysis)
|
||||
|
||||
def test_dispatches_analysis_with_tooluse(self):
|
||||
"""Analysis+ToolUse mixin still dispatches to Analysis."""
|
||||
triples = _make_triples("urn:a",
|
||||
[PROV_ENTITY, TG_ANALYSIS, TG_TOOL_USE])
|
||||
entity = ExplainEntity.from_triples("urn:a", triples)
|
||||
assert isinstance(entity, Analysis)
|
||||
|
||||
def test_dispatches_observation(self):
|
||||
triples = _make_triples("urn:o", [PROV_ENTITY, TG_OBSERVATION_TYPE])
|
||||
entity = ExplainEntity.from_triples("urn:o", triples)
|
||||
assert isinstance(entity, Observation)
|
||||
|
||||
def test_dispatches_conclusion_unchanged(self):
|
||||
triples = _make_triples("urn:c",
|
||||
[PROV_ENTITY, TG_CONCLUSION, TG_ANSWER_TYPE])
|
||||
entity = ExplainEntity.from_triples("urn:c", triples)
|
||||
assert isinstance(entity, Conclusion)
|
||||
|
||||
def test_finding_takes_precedence_over_synthesis(self):
|
||||
"""Finding has Answer mixin but should dispatch to Finding, not
|
||||
Synthesis, because Finding is checked first."""
|
||||
triples = _make_triples("urn:f",
|
||||
[PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE])
|
||||
entity = ExplainEntity.from_triples("urn:f", triples)
|
||||
assert isinstance(entity, Finding)
|
||||
assert not isinstance(entity, Synthesis)
|
||||
|
||||
|
||||
class TestDecompositionParsing:
|
||||
|
||||
def test_parses_goals(self):
|
||||
triples = _make_triples("urn:d", [TG_DECOMPOSITION], [
|
||||
(TG_SUBAGENT_GOAL, "What is X?"),
|
||||
(TG_SUBAGENT_GOAL, "What is Y?"),
|
||||
])
|
||||
entity = Decomposition.from_triples("urn:d", triples)
|
||||
assert set(entity.goals) == {"What is X?", "What is Y?"}
|
||||
|
||||
def test_entity_type_field(self):
|
||||
triples = _make_triples("urn:d", [TG_DECOMPOSITION])
|
||||
entity = Decomposition.from_triples("urn:d", triples)
|
||||
assert entity.entity_type == "decomposition"
|
||||
|
||||
def test_empty_goals(self):
|
||||
triples = _make_triples("urn:d", [TG_DECOMPOSITION])
|
||||
entity = Decomposition.from_triples("urn:d", triples)
|
||||
assert entity.goals == []
|
||||
|
||||
|
||||
class TestFindingParsing:
|
||||
|
||||
def test_parses_goal_and_document(self):
|
||||
triples = _make_triples("urn:f", [TG_FINDING, TG_ANSWER_TYPE], [
|
||||
(TG_SUBAGENT_GOAL, "What is X?"),
|
||||
(TG_DOCUMENT, "urn:doc/finding"),
|
||||
])
|
||||
entity = Finding.from_triples("urn:f", triples)
|
||||
assert entity.goal == "What is X?"
|
||||
assert entity.document == "urn:doc/finding"
|
||||
|
||||
def test_entity_type_field(self):
|
||||
triples = _make_triples("urn:f", [TG_FINDING])
|
||||
entity = Finding.from_triples("urn:f", triples)
|
||||
assert entity.entity_type == "finding"
|
||||
|
||||
|
||||
class TestPlanParsing:
|
||||
|
||||
def test_parses_steps(self):
|
||||
triples = _make_triples("urn:p", [TG_PLAN_TYPE], [
|
||||
(TG_PLAN_STEP, "Define X"),
|
||||
(TG_PLAN_STEP, "Research Y"),
|
||||
(TG_PLAN_STEP, "Analyse Z"),
|
||||
])
|
||||
entity = Plan.from_triples("urn:p", triples)
|
||||
assert set(entity.steps) == {"Define X", "Research Y", "Analyse Z"}
|
||||
|
||||
def test_entity_type_field(self):
|
||||
triples = _make_triples("urn:p", [TG_PLAN_TYPE])
|
||||
entity = Plan.from_triples("urn:p", triples)
|
||||
assert entity.entity_type == "plan"
|
||||
|
||||
|
||||
class TestStepResultParsing:
|
||||
|
||||
def test_parses_step_and_document(self):
|
||||
triples = _make_triples("urn:sr", [TG_STEP_RESULT, TG_ANSWER_TYPE], [
|
||||
(TG_PLAN_STEP, "Define X"),
|
||||
(TG_DOCUMENT, "urn:doc/step"),
|
||||
])
|
||||
entity = StepResult.from_triples("urn:sr", triples)
|
||||
assert entity.step == "Define X"
|
||||
assert entity.document == "urn:doc/step"
|
||||
|
||||
def test_entity_type_field(self):
|
||||
triples = _make_triples("urn:sr", [TG_STEP_RESULT])
|
||||
entity = StepResult.from_triples("urn:sr", triples)
|
||||
assert entity.entity_type == "step-result"
|
||||
289
tests/unit/test_agent/test_meta_router.py
Normal file
289
tests/unit/test_agent/test_meta_router.py
Normal file
|
|
@ -0,0 +1,289 @@
|
|||
"""
|
||||
Unit tests for the MetaRouter — task type identification and pattern selection.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.agent.orchestrator.meta_router import (
|
||||
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
|
||||
)
|
||||
|
||||
|
||||
def _make_config(patterns=None, task_types=None):
|
||||
"""Build a config dict as the config service would provide."""
|
||||
config = {}
|
||||
if patterns:
|
||||
config["agent-pattern"] = {
|
||||
pid: json.dumps(pdata) for pid, pdata in patterns.items()
|
||||
}
|
||||
if task_types:
|
||||
config["agent-task-type"] = {
|
||||
tid: json.dumps(tdata) for tid, tdata in task_types.items()
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def _make_context(prompt_response):
|
||||
"""Build a mock context that returns a mock prompt client."""
|
||||
client = AsyncMock()
|
||||
client.prompt = AsyncMock(return_value=prompt_response)
|
||||
|
||||
def context(service_name):
|
||||
return client
|
||||
|
||||
return context
|
||||
|
||||
|
||||
SAMPLE_PATTERNS = {
|
||||
"react": {"name": "react", "description": "ReAct pattern"},
|
||||
"plan-then-execute": {"name": "plan-then-execute", "description": "Plan pattern"},
|
||||
"supervisor": {"name": "supervisor", "description": "Supervisor pattern"},
|
||||
}
|
||||
|
||||
SAMPLE_TASK_TYPES = {
|
||||
"general": {
|
||||
"name": "general",
|
||||
"description": "General queries",
|
||||
"valid_patterns": ["react", "plan-then-execute", "supervisor"],
|
||||
"framing": "",
|
||||
},
|
||||
"research": {
|
||||
"name": "research",
|
||||
"description": "Research queries",
|
||||
"valid_patterns": ["react", "plan-then-execute"],
|
||||
"framing": "Focus on gathering information.",
|
||||
},
|
||||
"summarisation": {
|
||||
"name": "summarisation",
|
||||
"description": "Summarisation queries",
|
||||
"valid_patterns": ["react"],
|
||||
"framing": "Focus on concise synthesis.",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestMetaRouterInit:
|
||||
|
||||
def test_defaults_when_no_config(self):
|
||||
router = MetaRouter()
|
||||
assert "react" in router.patterns
|
||||
assert "general" in router.task_types
|
||||
|
||||
def test_loads_patterns_from_config(self):
|
||||
config = _make_config(patterns=SAMPLE_PATTERNS)
|
||||
router = MetaRouter(config=config)
|
||||
assert set(router.patterns.keys()) == {"react", "plan-then-execute", "supervisor"}
|
||||
|
||||
def test_loads_task_types_from_config(self):
|
||||
config = _make_config(task_types=SAMPLE_TASK_TYPES)
|
||||
router = MetaRouter(config=config)
|
||||
assert set(router.task_types.keys()) == {"general", "research", "summarisation"}
|
||||
|
||||
def test_handles_invalid_json_in_config(self):
|
||||
config = {
|
||||
"agent-pattern": {"react": "not valid json"},
|
||||
}
|
||||
router = MetaRouter(config=config)
|
||||
assert "react" in router.patterns
|
||||
assert router.patterns["react"]["name"] == "react"
|
||||
|
||||
|
||||
class TestIdentifyTaskType:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_llm_when_single_task_type(self):
|
||||
router = MetaRouter() # Only "general"
|
||||
context = _make_context("should not be called")
|
||||
|
||||
task_type, framing = await router.identify_task_type(
|
||||
"test question", context,
|
||||
)
|
||||
|
||||
assert task_type == "general"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_llm_when_multiple_task_types(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
context = _make_context("research")
|
||||
|
||||
task_type, framing = await router.identify_task_type(
|
||||
"Research the topic", context,
|
||||
)
|
||||
|
||||
assert task_type == "research"
|
||||
assert framing == "Focus on gathering information."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_llm_returning_quoted_type(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
context = _make_context('"summarisation"')
|
||||
|
||||
task_type, _ = await router.identify_task_type(
|
||||
"Summarise this", context,
|
||||
)
|
||||
|
||||
assert task_type == "summarisation"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_on_unknown_type(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
context = _make_context("nonexistent-type")
|
||||
|
||||
task_type, _ = await router.identify_task_type(
|
||||
"test question", context,
|
||||
)
|
||||
|
||||
assert task_type == DEFAULT_TASK_TYPE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_on_llm_error(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
|
||||
client = AsyncMock()
|
||||
client.prompt = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
context = lambda name: client
|
||||
|
||||
task_type, _ = await router.identify_task_type(
|
||||
"test question", context,
|
||||
)
|
||||
|
||||
assert task_type == DEFAULT_TASK_TYPE
|
||||
|
||||
|
||||
class TestSelectPattern:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_llm_when_single_valid_pattern(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
context = _make_context("should not be called")
|
||||
|
||||
# summarisation only has ["react"]
|
||||
pattern = await router.select_pattern(
|
||||
"Summarise this", "summarisation", context,
|
||||
)
|
||||
|
||||
assert pattern == "react"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_llm_when_multiple_valid_patterns(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
context = _make_context("plan-then-execute")
|
||||
|
||||
# research has ["react", "plan-then-execute"]
|
||||
pattern = await router.select_pattern(
|
||||
"Research this", "research", context,
|
||||
)
|
||||
|
||||
assert pattern == "plan-then-execute"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_respects_valid_patterns_constraint(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
# LLM returns supervisor, but research doesn't allow it
|
||||
context = _make_context("supervisor")
|
||||
|
||||
pattern = await router.select_pattern(
|
||||
"Research this", "research", context,
|
||||
)
|
||||
|
||||
# Should fall back to first valid pattern
|
||||
assert pattern == "react"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_on_llm_error(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
|
||||
client = AsyncMock()
|
||||
client.prompt = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
context = lambda name: client
|
||||
|
||||
# general has ["react", "plan-then-execute", "supervisor"]
|
||||
pattern = await router.select_pattern(
|
||||
"test", "general", context,
|
||||
)
|
||||
|
||||
# Falls back to first valid pattern
|
||||
assert pattern == "react"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_to_default_for_unknown_task_type(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
context = _make_context("react")
|
||||
|
||||
# Unknown task type — valid_patterns falls back to all patterns
|
||||
pattern = await router.select_pattern(
|
||||
"test", "unknown-type", context,
|
||||
)
|
||||
|
||||
assert pattern == "react"
|
||||
|
||||
|
||||
class TestRoute:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_routing_pipeline(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
|
||||
# Mock context where prompt returns different values per call
|
||||
client = AsyncMock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_prompt(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return "research" # task type
|
||||
return "plan-then-execute" # pattern
|
||||
|
||||
client.prompt = mock_prompt
|
||||
context = lambda name: client
|
||||
|
||||
pattern, task_type, framing = await router.route(
|
||||
"Research the relationships", context,
|
||||
)
|
||||
|
||||
assert task_type == "research"
|
||||
assert pattern == "plan-then-execute"
|
||||
assert framing == "Focus on gathering information."
|
||||
132
tests/unit/test_agent/test_on_action_callback.py
Normal file
132
tests/unit/test_agent/test_on_action_callback.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
"""
|
||||
Tests for the on_action callback in react() — verifies that it fires
|
||||
after action selection but before tool execution.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.agent.react.agent_manager import AgentManager
|
||||
from trustgraph.agent.react.types import Action, Final, Tool, Argument
|
||||
|
||||
|
||||
class TestOnActionCallback:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_action_called_for_tool_use(self):
|
||||
"""on_action fires when react() selects a tool (not Final)."""
|
||||
call_log = []
|
||||
|
||||
async def fake_on_action(act):
|
||||
call_log.append(("on_action", act.name))
|
||||
|
||||
# Tool that records when it's invoked
|
||||
async def tool_invoke(**kwargs):
|
||||
call_log.append(("tool_invoke",))
|
||||
return "tool result"
|
||||
|
||||
tool_impl = MagicMock()
|
||||
tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke)
|
||||
|
||||
tools = {
|
||||
"search": Tool(
|
||||
name="search",
|
||||
description="Search",
|
||||
implementation=tool_impl,
|
||||
arguments=[Argument(name="query", type="string", description="q")],
|
||||
config={},
|
||||
),
|
||||
}
|
||||
|
||||
agent = AgentManager(tools=tools)
|
||||
|
||||
# Mock reason() to return an Action
|
||||
action = Action(thought="thinking", name="search", arguments={"query": "test"}, observation="")
|
||||
agent.reason = AsyncMock(return_value=action)
|
||||
|
||||
think = AsyncMock()
|
||||
observe = AsyncMock()
|
||||
context = MagicMock()
|
||||
|
||||
await agent.react(
|
||||
question="test",
|
||||
history=[],
|
||||
think=think,
|
||||
observe=observe,
|
||||
context=context,
|
||||
on_action=fake_on_action,
|
||||
)
|
||||
|
||||
# on_action should fire before tool_invoke
|
||||
assert len(call_log) == 2
|
||||
assert call_log[0] == ("on_action", "search")
|
||||
assert call_log[1] == ("tool_invoke",)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_action_not_called_for_final(self):
|
||||
"""on_action does not fire when react() returns Final."""
|
||||
called = []
|
||||
|
||||
async def fake_on_action(act):
|
||||
called.append(act)
|
||||
|
||||
agent = AgentManager(tools={})
|
||||
agent.reason = AsyncMock(
|
||||
return_value=Final(thought="done", final="answer")
|
||||
)
|
||||
|
||||
think = AsyncMock()
|
||||
observe = AsyncMock()
|
||||
context = MagicMock()
|
||||
|
||||
result = await agent.react(
|
||||
question="test",
|
||||
history=[],
|
||||
think=think,
|
||||
observe=observe,
|
||||
context=context,
|
||||
on_action=fake_on_action,
|
||||
)
|
||||
|
||||
assert isinstance(result, Final)
|
||||
assert len(called) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_action_none_accepted(self):
|
||||
"""react() works fine when on_action is None (default)."""
|
||||
async def tool_invoke(**kwargs):
|
||||
return "result"
|
||||
|
||||
tool_impl = MagicMock()
|
||||
tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke)
|
||||
|
||||
tools = {
|
||||
"search": Tool(
|
||||
name="search",
|
||||
description="Search",
|
||||
implementation=tool_impl,
|
||||
arguments=[],
|
||||
config={},
|
||||
),
|
||||
}
|
||||
|
||||
agent = AgentManager(tools=tools)
|
||||
agent.reason = AsyncMock(
|
||||
return_value=Action(thought="t", name="search", arguments={}, observation="")
|
||||
)
|
||||
|
||||
think = AsyncMock()
|
||||
observe = AsyncMock()
|
||||
context = MagicMock()
|
||||
|
||||
result = await agent.react(
|
||||
question="test",
|
||||
history=[],
|
||||
think=think,
|
||||
observe=observe,
|
||||
context=context,
|
||||
# on_action not passed — defaults to None
|
||||
)
|
||||
|
||||
assert isinstance(result, Action)
|
||||
assert result.observation == "result"
|
||||
|
|
@ -0,0 +1,655 @@
|
|||
"""
|
||||
Integration tests for agent-orchestrator provenance chains.
|
||||
|
||||
Tests all three patterns by calling iterate() with mocked dependencies
|
||||
and verifying the explain events emitted via respond().
|
||||
|
||||
Provenance chains:
|
||||
React: session → iteration → (observation or final)
|
||||
Plan: session → plan → step-result(s) → synthesis
|
||||
Supervisor: session → decomposition → finding(s) → synthesis
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from trustgraph.schema import (
|
||||
AgentRequest, AgentResponse, AgentStep, PlanStep,
|
||||
)
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
GRAPH_RETRIEVAL,
|
||||
)
|
||||
|
||||
# Agent provenance type constants
|
||||
from trustgraph.provenance.namespaces import (
|
||||
TG_AGENT_QUESTION,
|
||||
TG_ANALYSIS,
|
||||
TG_TOOL_USE,
|
||||
TG_OBSERVATION_TYPE,
|
||||
TG_CONCLUSION,
|
||||
TG_DECOMPOSITION,
|
||||
TG_FINDING,
|
||||
TG_PLAN_TYPE,
|
||||
TG_STEP_RESULT,
|
||||
TG_SYNTHESIS as TG_AGENT_SYNTHESIS,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_triple(triples, predicate, subject=None):
|
||||
for t in triples:
|
||||
if t.p.iri == predicate:
|
||||
if subject is None or t.s.iri == subject:
|
||||
return t
|
||||
return None
|
||||
|
||||
|
||||
def has_type(triples, subject, rdf_type):
|
||||
return any(
|
||||
t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
|
||||
for t in triples
|
||||
)
|
||||
|
||||
|
||||
def derived_from(triples, subject):
|
||||
t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
|
||||
return t.o.iri if t else None
|
||||
|
||||
|
||||
def collect_explain_events(respond_mock):
|
||||
"""Extract explain events from a respond mock's call history."""
|
||||
events = []
|
||||
for call in respond_mock.call_args_list:
|
||||
resp = call[0][0]
|
||||
if isinstance(resp, AgentResponse) and resp.chunk_type == "explain":
|
||||
events.append({
|
||||
"explain_id": resp.explain_id,
|
||||
"explain_graph": resp.explain_graph,
|
||||
"triples": resp.explain_triples,
|
||||
})
|
||||
return events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock processor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_mock_processor(tools=None):
|
||||
"""Build a mock processor with the minimal interface patterns need."""
|
||||
processor = MagicMock()
|
||||
processor.max_iterations = 10
|
||||
processor.save_answer_content = AsyncMock()
|
||||
|
||||
# provenance_session_uri must return a real URI
|
||||
def mock_session_uri(session_id):
|
||||
return f"urn:trustgraph:agent:session:{session_id}"
|
||||
processor.provenance_session_uri.side_effect = mock_session_uri
|
||||
|
||||
# Agent with tools
|
||||
agent = MagicMock()
|
||||
agent.tools = tools or {}
|
||||
agent.additional_context = ""
|
||||
processor.agent = agent
|
||||
|
||||
# Aggregator for supervisor
|
||||
processor.aggregator = MagicMock()
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
def make_mock_flow():
|
||||
"""Build a mock flow that returns async mock producers."""
|
||||
producers = {}
|
||||
|
||||
def flow_factory(name):
|
||||
if name not in producers:
|
||||
producers[name] = AsyncMock()
|
||||
return producers[name]
|
||||
|
||||
flow = MagicMock(side_effect=flow_factory)
|
||||
flow._producers = producers
|
||||
return flow
|
||||
|
||||
|
||||
def make_base_request(**kwargs):
|
||||
"""Build a minimal AgentRequest."""
|
||||
defaults = dict(
|
||||
question="What is quantum computing?",
|
||||
state="",
|
||||
group=[],
|
||||
history=[],
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id="test-session-123",
|
||||
conversation_id="",
|
||||
pattern="react",
|
||||
task_type="",
|
||||
framing="",
|
||||
correlation_id="",
|
||||
parent_session_id="",
|
||||
subagent_goal="",
|
||||
expected_siblings=0,
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return AgentRequest(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# React pattern tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReactPatternProvenance:
|
||||
"""
|
||||
React pattern chain: session → iteration → final
|
||||
(single iteration ending in Final answer)
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_iteration_final_answer(self):
|
||||
"""
|
||||
A single react iteration that produces a Final answer should emit:
|
||||
session, iteration, final — in that order.
|
||||
"""
|
||||
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
|
||||
from trustgraph.agent.react.types import Action, Final
|
||||
|
||||
processor = make_mock_processor()
|
||||
pattern = ReactPattern(processor)
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = make_mock_flow()
|
||||
|
||||
request = make_base_request()
|
||||
|
||||
# Mock AgentManager.react to call on_action then return Final
|
||||
with patch(
|
||||
'trustgraph.agent.orchestrator.react_pattern.AgentManager'
|
||||
) as MockAM:
|
||||
mock_am = AsyncMock()
|
||||
MockAM.return_value = mock_am
|
||||
|
||||
final = Final(
|
||||
thought="I know the answer",
|
||||
final="Quantum computing uses qubits.",
|
||||
)
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
# Simulate the on_action callback before returning Final
|
||||
if on_action:
|
||||
await on_action(Action(
|
||||
thought="I know the answer",
|
||||
name="final",
|
||||
arguments={},
|
||||
observation="",
|
||||
))
|
||||
return final
|
||||
|
||||
mock_am.react.side_effect = mock_react
|
||||
|
||||
await pattern.iterate(request, respond, next_fn, flow)
|
||||
|
||||
events = collect_explain_events(respond)
|
||||
|
||||
# Should have 3 events: session, iteration, final
|
||||
assert len(events) == 3, (
|
||||
f"Expected 3 explain events (session, iteration, final), "
|
||||
f"got {len(events)}: {[e['explain_id'] for e in events]}"
|
||||
)
|
||||
|
||||
# Check types
|
||||
assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
|
||||
assert has_type(events[1]["triples"], events[1]["explain_id"], TG_ANALYSIS)
|
||||
assert has_type(events[2]["triples"], events[2]["explain_id"], TG_CONCLUSION)
|
||||
|
||||
# Check derivation chain
|
||||
all_triples = []
|
||||
for e in events:
|
||||
all_triples.extend(e["triples"])
|
||||
|
||||
uris = [e["explain_id"] for e in events]
|
||||
|
||||
# iteration derives from session
|
||||
assert derived_from(all_triples, uris[1]) == uris[0]
|
||||
# final derives from session (first iteration, no prior observation)
|
||||
assert derived_from(all_triples, uris[2]) == uris[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_iteration_with_tool_call(self):
|
||||
"""
|
||||
A react iteration that calls a tool (not Final) should emit:
|
||||
session, iteration, observation — then call next() for continuation.
|
||||
"""
|
||||
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
|
||||
from trustgraph.agent.react.types import Action
|
||||
|
||||
# Create a mock tool
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "knowledge-query"
|
||||
mock_tool.description = "Query the knowledge base"
|
||||
mock_tool.arguments = []
|
||||
mock_tool.groups = []
|
||||
mock_tool.states = {}
|
||||
mock_tool_impl = AsyncMock(return_value="The answer is 42")
|
||||
mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
|
||||
|
||||
processor = make_mock_processor(
|
||||
tools={"knowledge-query": mock_tool}
|
||||
)
|
||||
pattern = ReactPattern(processor)
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = make_mock_flow()
|
||||
|
||||
request = make_base_request()
|
||||
|
||||
action = Action(
|
||||
thought="I need to look this up",
|
||||
name="knowledge-query",
|
||||
arguments={"question": "What is quantum computing?"},
|
||||
observation="Quantum computing uses qubits.",
|
||||
)
|
||||
|
||||
with patch(
|
||||
'trustgraph.agent.orchestrator.react_pattern.AgentManager'
|
||||
) as MockAM:
|
||||
mock_am = AsyncMock()
|
||||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
if on_action:
|
||||
await on_action(action)
|
||||
return action
|
||||
|
||||
mock_am.react.side_effect = mock_react
|
||||
|
||||
await pattern.iterate(request, respond, next_fn, flow)
|
||||
|
||||
events = collect_explain_events(respond)
|
||||
|
||||
# Should have 3 events: session, iteration, observation
|
||||
assert len(events) == 3, (
|
||||
f"Expected 3 explain events (session, iteration, observation), "
|
||||
f"got {len(events)}: {[e['explain_id'] for e in events]}"
|
||||
)
|
||||
|
||||
assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
|
||||
assert has_type(events[1]["triples"], events[1]["explain_id"], TG_ANALYSIS)
|
||||
assert has_type(events[2]["triples"], events[2]["explain_id"], TG_OBSERVATION_TYPE)
|
||||
|
||||
# next() should have been called to continue the loop
|
||||
assert next_fn.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_in_retrieval_graph(self):
|
||||
"""All explain triples should be in urn:graph:retrieval."""
|
||||
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
|
||||
from trustgraph.agent.react.types import Action, Final
|
||||
|
||||
processor = make_mock_processor()
|
||||
pattern = ReactPattern(processor)
|
||||
respond = AsyncMock()
|
||||
flow = make_mock_flow()
|
||||
|
||||
with patch(
|
||||
'trustgraph.agent.orchestrator.react_pattern.AgentManager'
|
||||
) as MockAM:
|
||||
mock_am = AsyncMock()
|
||||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
if on_action:
|
||||
await on_action(Action(
|
||||
thought="done", name="final",
|
||||
arguments={}, observation="",
|
||||
))
|
||||
return Final(thought="done", final="answer")
|
||||
|
||||
mock_am.react.side_effect = mock_react
|
||||
await pattern.iterate(
|
||||
make_base_request(), respond, AsyncMock(), flow,
|
||||
)
|
||||
|
||||
for event in collect_explain_events(respond):
|
||||
for t in event["triples"]:
|
||||
assert t.g == GRAPH_RETRIEVAL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Plan-then-execute pattern tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPlanPatternProvenance:
|
||||
"""
|
||||
Plan pattern chain:
|
||||
Planning iteration: session → plan
|
||||
Execution iterations: step-result(s) → synthesis
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_planning_iteration_emits_session_and_plan(self):
|
||||
"""
|
||||
The first iteration (planning) should emit:
|
||||
session, plan — then call next() with the plan in history.
|
||||
"""
|
||||
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
|
||||
|
||||
processor = make_mock_processor()
|
||||
pattern = PlanThenExecutePattern(processor)
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = make_mock_flow()
|
||||
|
||||
# Mock prompt client for plan creation
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = [
|
||||
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
|
||||
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
|
||||
]
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
return mock_prompt_client
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_factory
|
||||
|
||||
request = make_base_request(pattern="plan")
|
||||
|
||||
await pattern.iterate(request, respond, next_fn, flow)
|
||||
|
||||
events = collect_explain_events(respond)
|
||||
|
||||
# Should have 2 events: session, plan
|
||||
assert len(events) == 2, (
|
||||
f"Expected 2 explain events (session, plan), "
|
||||
f"got {len(events)}: {[e['explain_id'] for e in events]}"
|
||||
)
|
||||
|
||||
assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
|
||||
assert has_type(events[1]["triples"], events[1]["explain_id"], TG_PLAN_TYPE)
|
||||
|
||||
# Plan should derive from session
|
||||
all_triples = []
|
||||
for e in events:
|
||||
all_triples.extend(e["triples"])
|
||||
assert derived_from(all_triples, events[1]["explain_id"]) == events[0]["explain_id"]
|
||||
|
||||
# next() should have been called with plan in history
|
||||
assert next_fn.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execution_iteration_emits_step_result(self):
|
||||
"""
|
||||
An execution iteration should emit a step-result event.
|
||||
"""
|
||||
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
|
||||
|
||||
# Create a mock tool
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "knowledge-query"
|
||||
mock_tool.description = "Query KB"
|
||||
mock_tool.arguments = []
|
||||
mock_tool.groups = []
|
||||
mock_tool.states = {}
|
||||
mock_tool_impl = AsyncMock(return_value="Found the answer")
|
||||
mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
|
||||
|
||||
processor = make_mock_processor(
|
||||
tools={"knowledge-query": mock_tool}
|
||||
)
|
||||
pattern = PlanThenExecutePattern(processor)
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = make_mock_flow()
|
||||
|
||||
# Mock prompt for step execution
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = {
|
||||
"tool": "knowledge-query",
|
||||
"arguments": {"question": "quantum computing"},
|
||||
}
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
return mock_prompt_client
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_factory
|
||||
|
||||
# Request with plan already in history (second iteration)
|
||||
plan_step = AgentStep(
|
||||
thought="Created plan",
|
||||
action="plan",
|
||||
arguments={},
|
||||
observation="[]",
|
||||
step_type="plan",
|
||||
plan=[
|
||||
PlanStep(goal="Find info", tool_hint="knowledge-query",
|
||||
depends_on=[], status="pending", result=""),
|
||||
],
|
||||
)
|
||||
request = make_base_request(
|
||||
pattern="plan",
|
||||
history=[plan_step],
|
||||
)
|
||||
|
||||
await pattern.iterate(request, respond, next_fn, flow)
|
||||
|
||||
events = collect_explain_events(respond)
|
||||
|
||||
# Should have step-result (no session on iteration > 1)
|
||||
step_events = [
|
||||
e for e in events
|
||||
if has_type(e["triples"], e["explain_id"], TG_STEP_RESULT)
|
||||
]
|
||||
assert len(step_events) == 1, (
|
||||
f"Expected 1 step-result event, got {len(step_events)}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_after_all_steps_complete(self):
|
||||
"""
|
||||
When all plan steps are completed, synthesis should be emitted.
|
||||
"""
|
||||
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
|
||||
|
||||
processor = make_mock_processor()
|
||||
pattern = PlanThenExecutePattern(processor)
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = make_mock_flow()
|
||||
|
||||
# Mock prompt for synthesis
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = "The synthesised answer."
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
return mock_prompt_client
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_factory
|
||||
|
||||
# Request with all steps completed
|
||||
exec_step = AgentStep(
|
||||
thought="Executing step",
|
||||
action="knowledge-query",
|
||||
arguments={},
|
||||
observation="Result",
|
||||
step_type="execute",
|
||||
plan=[
|
||||
PlanStep(goal="Find info", tool_hint="knowledge-query",
|
||||
depends_on=[], status="completed", result="Found it"),
|
||||
],
|
||||
)
|
||||
request = make_base_request(
|
||||
pattern="plan",
|
||||
history=[exec_step],
|
||||
)
|
||||
|
||||
await pattern.iterate(request, respond, next_fn, flow)
|
||||
|
||||
events = collect_explain_events(respond)
|
||||
|
||||
# Should have synthesis event
|
||||
synth_events = [
|
||||
e for e in events
|
||||
if has_type(e["triples"], e["explain_id"], TG_AGENT_SYNTHESIS)
|
||||
]
|
||||
assert len(synth_events) == 1, (
|
||||
f"Expected 1 synthesis event, got {len(synth_events)}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Supervisor pattern tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSupervisorPatternProvenance:
|
||||
"""
|
||||
Supervisor pattern chain:
|
||||
Decompose: session → decomposition
|
||||
(Fan-out to subagents happens externally)
|
||||
Synthesise: synthesis (derives from findings)
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_emits_session_and_decomposition(self):
|
||||
"""
|
||||
The decompose phase should emit: session, decomposition.
|
||||
"""
|
||||
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
|
||||
|
||||
processor = make_mock_processor()
|
||||
pattern = SupervisorPattern(processor)
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = make_mock_flow()
|
||||
|
||||
# Mock prompt for decomposition
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = [
|
||||
"What is quantum computing?",
|
||||
"What are qubits?",
|
||||
]
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
return mock_prompt_client
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_factory
|
||||
|
||||
request = make_base_request(pattern="supervisor")
|
||||
|
||||
await pattern.iterate(request, respond, next_fn, flow)
|
||||
|
||||
events = collect_explain_events(respond)
|
||||
|
||||
# Should have 2 events: session, decomposition
|
||||
assert len(events) == 2, (
|
||||
f"Expected 2 explain events (session, decomposition), "
|
||||
f"got {len(events)}: {[e['explain_id'] for e in events]}"
|
||||
)
|
||||
|
||||
assert has_type(events[0]["triples"], events[0]["explain_id"], TG_AGENT_QUESTION)
|
||||
assert has_type(events[1]["triples"], events[1]["explain_id"], TG_DECOMPOSITION)
|
||||
|
||||
# Decomposition derives from session
|
||||
all_triples = []
|
||||
for e in events:
|
||||
all_triples.extend(e["triples"])
|
||||
assert derived_from(all_triples, events[1]["explain_id"]) == events[0]["explain_id"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_emits_after_subagent_results(self):
|
||||
"""
|
||||
When subagent results arrive, synthesis should be emitted.
|
||||
"""
|
||||
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
|
||||
|
||||
processor = make_mock_processor()
|
||||
pattern = SupervisorPattern(processor)
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = make_mock_flow()
|
||||
|
||||
# Mock prompt for synthesis
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = "The combined answer."
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
return mock_prompt_client
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_factory
|
||||
|
||||
# Request with subagent results in history
|
||||
synth_step = AgentStep(
|
||||
thought="",
|
||||
action="synthesise",
|
||||
arguments={},
|
||||
observation="",
|
||||
step_type="synthesise",
|
||||
subagent_results={
|
||||
"What is quantum computing?": "It uses qubits",
|
||||
"What are qubits?": "Quantum bits",
|
||||
},
|
||||
)
|
||||
request = make_base_request(
|
||||
pattern="supervisor",
|
||||
history=[synth_step],
|
||||
)
|
||||
|
||||
await pattern.iterate(request, respond, next_fn, flow)
|
||||
|
||||
events = collect_explain_events(respond)
|
||||
|
||||
# Should have synthesis event (no session on iteration > 1)
|
||||
synth_events = [
|
||||
e for e in events
|
||||
if has_type(e["triples"], e["explain_id"], TG_AGENT_SYNTHESIS)
|
||||
]
|
||||
assert len(synth_events) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_fans_out_subagents(self):
|
||||
"""The decompose phase should call next() for each subagent goal."""
|
||||
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
|
||||
|
||||
processor = make_mock_processor()
|
||||
pattern = SupervisorPattern(processor)
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = make_mock_flow()
|
||||
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
return mock_prompt_client
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_factory
|
||||
|
||||
request = make_base_request(pattern="supervisor")
|
||||
|
||||
await pattern.iterate(request, respond, next_fn, flow)
|
||||
|
||||
# 3 subagent requests fanned out
|
||||
assert next_fn.call_count == 3
|
||||
74
tests/unit/test_agent/test_parse_chunk_message_id.py
Normal file
74
tests/unit/test_agent/test_parse_chunk_message_id.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
"""
|
||||
Tests that _parse_chunk propagates message_id from wire format
|
||||
to AgentThought, AgentObservation, and AgentAnswer.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.api.socket_client import SocketClient
|
||||
from trustgraph.api.types import AgentThought, AgentObservation, AgentAnswer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
# We only need _parse_chunk — don't connect
|
||||
c = object.__new__(SocketClient)
|
||||
return c
|
||||
|
||||
|
||||
class TestParseChunkMessageId:
|
||||
|
||||
def test_thought_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "thought",
|
||||
"content": "thinking...",
|
||||
"end_of_message": False,
|
||||
"message_id": "urn:trustgraph:agent:sess/i1/thought",
|
||||
}
|
||||
chunk = client._parse_chunk(resp)
|
||||
assert isinstance(chunk, AgentThought)
|
||||
assert chunk.message_id == "urn:trustgraph:agent:sess/i1/thought"
|
||||
|
||||
def test_observation_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "observation",
|
||||
"content": "result",
|
||||
"end_of_message": True,
|
||||
"message_id": "urn:trustgraph:agent:sess/i1/observation",
|
||||
}
|
||||
chunk = client._parse_chunk(resp)
|
||||
assert isinstance(chunk, AgentObservation)
|
||||
assert chunk.message_id == "urn:trustgraph:agent:sess/i1/observation"
|
||||
|
||||
def test_answer_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "answer",
|
||||
"content": "the answer",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False,
|
||||
"message_id": "urn:trustgraph:agent:sess/final",
|
||||
}
|
||||
chunk = client._parse_chunk(resp)
|
||||
assert isinstance(chunk, AgentAnswer)
|
||||
assert chunk.message_id == "urn:trustgraph:agent:sess/final"
|
||||
|
||||
def test_thought_missing_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "thought",
|
||||
"content": "thinking...",
|
||||
"end_of_message": False,
|
||||
}
|
||||
chunk = client._parse_chunk(resp)
|
||||
assert isinstance(chunk, AgentThought)
|
||||
assert chunk.message_id == ""
|
||||
|
||||
def test_answer_missing_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "answer",
|
||||
"content": "answer",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": True,
|
||||
}
|
||||
chunk = client._parse_chunk(resp)
|
||||
assert isinstance(chunk, AgentAnswer)
|
||||
assert chunk.message_id == ""
|
||||
144
tests/unit/test_agent/test_pattern_base_subagent.py
Normal file
144
tests/unit/test_agent/test_pattern_base_subagent.py
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
"""
|
||||
Unit tests for PatternBase subagent helpers — is_subagent() and
|
||||
emit_subagent_completion().
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.schema import AgentRequest
|
||||
|
||||
from trustgraph.agent.orchestrator.pattern_base import PatternBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockProcessor:
|
||||
"""Minimal processor mock for PatternBase."""
|
||||
pass
|
||||
|
||||
|
||||
def _make_request(**kwargs):
|
||||
defaults = dict(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return AgentRequest(**defaults)
|
||||
|
||||
|
||||
def _make_pattern():
|
||||
return PatternBase(MockProcessor())
|
||||
|
||||
|
||||
class TestIsSubagent:
|
||||
|
||||
def test_returns_true_when_correlation_id_set(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request(correlation_id="corr-123")
|
||||
assert pattern.is_subagent(request) is True
|
||||
|
||||
def test_returns_false_when_correlation_id_empty(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request(correlation_id="")
|
||||
assert pattern.is_subagent(request) is False
|
||||
|
||||
def test_returns_false_when_correlation_id_missing(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request()
|
||||
assert pattern.is_subagent(request) is False
|
||||
|
||||
|
||||
class TestEmitSubagentCompletion:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_next_with_completion_request(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request(
|
||||
correlation_id="corr-123",
|
||||
parent_session_id="parent-sess",
|
||||
subagent_goal="What is X?",
|
||||
expected_siblings=4,
|
||||
)
|
||||
next_fn = AsyncMock()
|
||||
|
||||
await pattern.emit_subagent_completion(
|
||||
request, next_fn, "The answer is Y",
|
||||
)
|
||||
|
||||
next_fn.assert_called_once()
|
||||
completion_req = next_fn.call_args[0][0]
|
||||
assert isinstance(completion_req, AgentRequest)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_has_correct_step_type(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request(
|
||||
correlation_id="corr-123",
|
||||
subagent_goal="What is X?",
|
||||
)
|
||||
next_fn = AsyncMock()
|
||||
|
||||
await pattern.emit_subagent_completion(
|
||||
request, next_fn, "answer text",
|
||||
)
|
||||
|
||||
completion_req = next_fn.call_args[0][0]
|
||||
assert len(completion_req.history) == 1
|
||||
step = completion_req.history[0]
|
||||
assert step.step_type == "subagent-completion"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_carries_answer_in_observation(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request(
|
||||
correlation_id="corr-123",
|
||||
subagent_goal="What is X?",
|
||||
)
|
||||
next_fn = AsyncMock()
|
||||
|
||||
await pattern.emit_subagent_completion(
|
||||
request, next_fn, "The answer is Y",
|
||||
)
|
||||
|
||||
completion_req = next_fn.call_args[0][0]
|
||||
step = completion_req.history[0]
|
||||
assert step.observation == "The answer is Y"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_preserves_correlation_fields(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request(
|
||||
correlation_id="corr-123",
|
||||
parent_session_id="parent-sess",
|
||||
subagent_goal="What is X?",
|
||||
expected_siblings=4,
|
||||
)
|
||||
next_fn = AsyncMock()
|
||||
|
||||
await pattern.emit_subagent_completion(
|
||||
request, next_fn, "answer",
|
||||
)
|
||||
|
||||
completion_req = next_fn.call_args[0][0]
|
||||
assert completion_req.correlation_id == "corr-123"
|
||||
assert completion_req.parent_session_id == "parent-sess"
|
||||
assert completion_req.subagent_goal == "What is X?"
|
||||
assert completion_req.expected_siblings == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_has_empty_pattern(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request(
|
||||
correlation_id="corr-123",
|
||||
subagent_goal="goal",
|
||||
)
|
||||
next_fn = AsyncMock()
|
||||
|
||||
await pattern.emit_subagent_completion(
|
||||
request, next_fn, "answer",
|
||||
)
|
||||
|
||||
completion_req = next_fn.call_args[0][0]
|
||||
assert completion_req.pattern == ""
|
||||
226
tests/unit/test_agent/test_provenance_triples.py
Normal file
226
tests/unit/test_agent/test_provenance_triples.py
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
"""
|
||||
Unit tests for orchestrator provenance triple builders.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.provenance import (
|
||||
agent_decomposition_triples,
|
||||
agent_finding_triples,
|
||||
agent_plan_triples,
|
||||
agent_step_result_triples,
|
||||
agent_synthesis_triples,
|
||||
)
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
|
||||
TG_SYNTHESIS, TG_ANSWER_TYPE, TG_DOCUMENT,
|
||||
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
|
||||
)
|
||||
|
||||
|
||||
def _triple_set(triples):
|
||||
"""Convert triples to a set of (s_iri, p_iri, o_value) for easy assertion."""
|
||||
result = set()
|
||||
for t in triples:
|
||||
s = t.s.iri
|
||||
p = t.p.iri
|
||||
o = t.o.iri if t.o.iri else t.o.value
|
||||
result.add((s, p, o))
|
||||
return result
|
||||
|
||||
|
||||
def _has_type(triples, uri, rdf_type):
|
||||
"""Check if a URI has a given rdf:type in the triples."""
|
||||
return (uri, RDF_TYPE, rdf_type) in _triple_set(triples)
|
||||
|
||||
|
||||
def _get_values(triples, uri, predicate):
|
||||
"""Get all object values for a given subject + predicate."""
|
||||
ts = _triple_set(triples)
|
||||
return [o for s, p, o in ts if s == uri and p == predicate]
|
||||
|
||||
|
||||
class TestDecompositionTriples:
|
||||
|
||||
def test_has_correct_types(self):
|
||||
triples = agent_decomposition_triples(
|
||||
"urn:decompose", "urn:session", ["goal-a", "goal-b"],
|
||||
)
|
||||
assert _has_type(triples, "urn:decompose", PROV_ENTITY)
|
||||
assert _has_type(triples, "urn:decompose", TG_DECOMPOSITION)
|
||||
|
||||
def test_not_answer_type(self):
|
||||
triples = agent_decomposition_triples(
|
||||
"urn:decompose", "urn:session", ["goal-a"],
|
||||
)
|
||||
assert not _has_type(triples, "urn:decompose", TG_ANSWER_TYPE)
|
||||
|
||||
def test_links_to_session(self):
|
||||
triples = agent_decomposition_triples(
|
||||
"urn:decompose", "urn:session", ["goal-a"],
|
||||
)
|
||||
ts = _triple_set(triples)
|
||||
assert ("urn:decompose", PROV_WAS_DERIVED_FROM, "urn:session") in ts
|
||||
|
||||
def test_includes_goals(self):
|
||||
goals = ["What is X?", "What is Y?", "What is Z?"]
|
||||
triples = agent_decomposition_triples(
|
||||
"urn:decompose", "urn:session", goals,
|
||||
)
|
||||
values = _get_values(triples, "urn:decompose", TG_SUBAGENT_GOAL)
|
||||
assert set(values) == set(goals)
|
||||
|
||||
def test_label_includes_count(self):
|
||||
triples = agent_decomposition_triples(
|
||||
"urn:decompose", "urn:session", ["a", "b", "c"],
|
||||
)
|
||||
labels = _get_values(triples, "urn:decompose", RDFS_LABEL)
|
||||
assert any("3" in label for label in labels)
|
||||
|
||||
|
||||
class TestFindingTriples:
|
||||
|
||||
def test_has_correct_types(self):
|
||||
triples = agent_finding_triples(
|
||||
"urn:finding", "urn:decompose", "What is X?",
|
||||
)
|
||||
assert _has_type(triples, "urn:finding", PROV_ENTITY)
|
||||
assert _has_type(triples, "urn:finding", TG_FINDING)
|
||||
assert _has_type(triples, "urn:finding", TG_ANSWER_TYPE)
|
||||
|
||||
def test_links_to_decomposition(self):
|
||||
triples = agent_finding_triples(
|
||||
"urn:finding", "urn:decompose", "What is X?",
|
||||
)
|
||||
ts = _triple_set(triples)
|
||||
assert ("urn:finding", PROV_WAS_DERIVED_FROM, "urn:decompose") in ts
|
||||
|
||||
def test_includes_goal(self):
|
||||
triples = agent_finding_triples(
|
||||
"urn:finding", "urn:decompose", "What is X?",
|
||||
)
|
||||
values = _get_values(triples, "urn:finding", TG_SUBAGENT_GOAL)
|
||||
assert "What is X?" in values
|
||||
|
||||
def test_includes_document_when_provided(self):
|
||||
triples = agent_finding_triples(
|
||||
"urn:finding", "urn:decompose", "goal",
|
||||
document_id="urn:doc/1",
|
||||
)
|
||||
values = _get_values(triples, "urn:finding", TG_DOCUMENT)
|
||||
assert "urn:doc/1" in values
|
||||
|
||||
def test_no_document_when_none(self):
|
||||
triples = agent_finding_triples(
|
||||
"urn:finding", "urn:decompose", "goal",
|
||||
)
|
||||
values = _get_values(triples, "urn:finding", TG_DOCUMENT)
|
||||
assert values == []
|
||||
|
||||
|
||||
class TestPlanTriples:
|
||||
|
||||
def test_has_correct_types(self):
|
||||
triples = agent_plan_triples(
|
||||
"urn:plan", "urn:session", ["step-a"],
|
||||
)
|
||||
assert _has_type(triples, "urn:plan", PROV_ENTITY)
|
||||
assert _has_type(triples, "urn:plan", TG_PLAN_TYPE)
|
||||
|
||||
def test_not_answer_type(self):
|
||||
triples = agent_plan_triples(
|
||||
"urn:plan", "urn:session", ["step-a"],
|
||||
)
|
||||
assert not _has_type(triples, "urn:plan", TG_ANSWER_TYPE)
|
||||
|
||||
def test_links_to_session(self):
|
||||
triples = agent_plan_triples(
|
||||
"urn:plan", "urn:session", ["step-a"],
|
||||
)
|
||||
ts = _triple_set(triples)
|
||||
assert ("urn:plan", PROV_WAS_DERIVED_FROM, "urn:session") in ts
|
||||
|
||||
def test_includes_steps(self):
|
||||
steps = ["Define X", "Research Y", "Analyse Z"]
|
||||
triples = agent_plan_triples(
|
||||
"urn:plan", "urn:session", steps,
|
||||
)
|
||||
values = _get_values(triples, "urn:plan", TG_PLAN_STEP)
|
||||
assert set(values) == set(steps)
|
||||
|
||||
def test_label_includes_count(self):
|
||||
triples = agent_plan_triples(
|
||||
"urn:plan", "urn:session", ["a", "b"],
|
||||
)
|
||||
labels = _get_values(triples, "urn:plan", RDFS_LABEL)
|
||||
assert any("2" in label for label in labels)
|
||||
|
||||
|
||||
class TestStepResultTriples:
|
||||
|
||||
def test_has_correct_types(self):
|
||||
triples = agent_step_result_triples(
|
||||
"urn:step", "urn:plan", "Define X",
|
||||
)
|
||||
assert _has_type(triples, "urn:step", PROV_ENTITY)
|
||||
assert _has_type(triples, "urn:step", TG_STEP_RESULT)
|
||||
assert _has_type(triples, "urn:step", TG_ANSWER_TYPE)
|
||||
|
||||
def test_links_to_plan(self):
|
||||
triples = agent_step_result_triples(
|
||||
"urn:step", "urn:plan", "Define X",
|
||||
)
|
||||
ts = _triple_set(triples)
|
||||
assert ("urn:step", PROV_WAS_DERIVED_FROM, "urn:plan") in ts
|
||||
|
||||
def test_includes_goal(self):
|
||||
triples = agent_step_result_triples(
|
||||
"urn:step", "urn:plan", "Define X",
|
||||
)
|
||||
values = _get_values(triples, "urn:step", TG_PLAN_STEP)
|
||||
assert "Define X" in values
|
||||
|
||||
def test_includes_document_when_provided(self):
|
||||
triples = agent_step_result_triples(
|
||||
"urn:step", "urn:plan", "goal",
|
||||
document_id="urn:doc/step",
|
||||
)
|
||||
values = _get_values(triples, "urn:step", TG_DOCUMENT)
|
||||
assert "urn:doc/step" in values
|
||||
|
||||
|
||||
class TestSynthesisTriples:
|
||||
|
||||
def test_has_correct_types(self):
|
||||
triples = agent_synthesis_triples(
|
||||
"urn:synthesis", "urn:previous",
|
||||
)
|
||||
assert _has_type(triples, "urn:synthesis", PROV_ENTITY)
|
||||
assert _has_type(triples, "urn:synthesis", TG_SYNTHESIS)
|
||||
assert _has_type(triples, "urn:synthesis", TG_ANSWER_TYPE)
|
||||
|
||||
def test_links_to_previous(self):
|
||||
triples = agent_synthesis_triples(
|
||||
"urn:synthesis", "urn:last-finding",
|
||||
)
|
||||
ts = _triple_set(triples)
|
||||
assert ("urn:synthesis", PROV_WAS_DERIVED_FROM,
|
||||
"urn:last-finding") in ts
|
||||
|
||||
def test_includes_document_when_provided(self):
|
||||
triples = agent_synthesis_triples(
|
||||
"urn:synthesis", "urn:previous",
|
||||
document_id="urn:doc/synthesis",
|
||||
)
|
||||
values = _get_values(triples, "urn:synthesis", TG_DOCUMENT)
|
||||
assert "urn:doc/synthesis" in values
|
||||
|
||||
def test_label_is_synthesis(self):
|
||||
triples = agent_synthesis_triples(
|
||||
"urn:synthesis", "urn:previous",
|
||||
)
|
||||
labels = _get_values(triples, "urn:synthesis", RDFS_LABEL)
|
||||
assert "Synthesis" in labels
|
||||
323
tests/unit/test_base/test_async_processor_config.py
Normal file
323
tests/unit/test_base/test_async_processor_config.py
Normal file
|
|
@ -0,0 +1,323 @@
|
|||
"""
|
||||
Tests for AsyncProcessor config notify pattern:
|
||||
- register_config_handler with types filtering
|
||||
- on_config_notify version comparison and type matching
|
||||
- fetch_config with short-lived client
|
||||
- fetch_and_apply_config retry logic
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, Mock
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
|
||||
|
||||
# Patch heavy dependencies before importing AsyncProcessor
|
||||
@pytest.fixture
|
||||
def processor():
|
||||
"""Create an AsyncProcessor with mocked dependencies."""
|
||||
with patch('trustgraph.base.async_processor.get_pubsub') as mock_pubsub, \
|
||||
patch('trustgraph.base.async_processor.Consumer') as mock_consumer, \
|
||||
patch('trustgraph.base.async_processor.ProcessorMetrics') as mock_pm, \
|
||||
patch('trustgraph.base.async_processor.ConsumerMetrics') as mock_cm:
|
||||
|
||||
mock_pubsub.return_value = MagicMock()
|
||||
mock_consumer.return_value = MagicMock()
|
||||
mock_pm.return_value = MagicMock()
|
||||
mock_cm.return_value = MagicMock()
|
||||
|
||||
from trustgraph.base.async_processor import AsyncProcessor
|
||||
p = AsyncProcessor(
|
||||
id="test-processor",
|
||||
taskgroup=AsyncMock(),
|
||||
)
|
||||
return p
|
||||
|
||||
|
||||
class TestRegisterConfigHandler:
|
||||
|
||||
def test_register_without_types(self, processor):
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler)
|
||||
|
||||
assert len(processor.config_handlers) == 1
|
||||
assert processor.config_handlers[0]["handler"] is handler
|
||||
assert processor.config_handlers[0]["types"] is None
|
||||
|
||||
def test_register_with_types(self, processor):
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
assert processor.config_handlers[0]["types"] == {"prompt"}
|
||||
|
||||
def test_register_multiple_types(self, processor):
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(
|
||||
handler, types=["schema", "collection"]
|
||||
)
|
||||
|
||||
assert processor.config_handlers[0]["types"] == {
|
||||
"schema", "collection"
|
||||
}
|
||||
|
||||
def test_register_multiple_handlers(self, processor):
|
||||
h1 = AsyncMock()
|
||||
h2 = AsyncMock()
|
||||
processor.register_config_handler(h1, types=["prompt"])
|
||||
processor.register_config_handler(h2, types=["schema"])
|
||||
|
||||
assert len(processor.config_handlers) == 2
|
||||
|
||||
|
||||
class TestOnConfigNotify:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_old_version(self, processor):
|
||||
processor.config_version = 5
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=3, types=["prompt"])
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_same_version(self, processor):
|
||||
processor.config_version = 5
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=5, types=["prompt"])
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_irrelevant_types(self, processor):
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["schema"])
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
# Version should still be updated
|
||||
assert processor.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_on_relevant_type(self, processor):
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
# Mock fetch_config
|
||||
mock_config = {"prompt": {"key": "value"}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_called_once_with(mock_config, 2)
|
||||
assert processor.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_without_types_always_called(self, processor):
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler) # No types = all
|
||||
|
||||
mock_config = {"anything": {}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["whatever"])
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_called_once_with(mock_config, 2)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_handlers_type_filtering(self, processor):
|
||||
processor.config_version = 1
|
||||
|
||||
prompt_handler = AsyncMock()
|
||||
schema_handler = AsyncMock()
|
||||
all_handler = AsyncMock()
|
||||
|
||||
processor.register_config_handler(prompt_handler, types=["prompt"])
|
||||
processor.register_config_handler(schema_handler, types=["schema"])
|
||||
processor.register_config_handler(all_handler)
|
||||
|
||||
mock_config = {"prompt": {}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
prompt_handler.assert_called_once()
|
||||
schema_handler.assert_not_called()
|
||||
all_handler.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_types_invokes_all(self, processor):
|
||||
"""Empty types list (startup signal) should invoke all handlers."""
|
||||
processor.config_version = 1
|
||||
|
||||
h1 = AsyncMock()
|
||||
h2 = AsyncMock()
|
||||
processor.register_config_handler(h1, types=["prompt"])
|
||||
processor.register_config_handler(h2, types=["schema"])
|
||||
|
||||
mock_config = {}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=[])
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
h1.assert_called_once()
|
||||
h2.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_failure_handled(self, processor):
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler)
|
||||
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("Connection failed")
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
# Should not raise
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
||||
|
||||
class TestFetchConfig:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_returns_config_and_version(self, processor):
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.config = {"prompt": {"key": "val"}}
|
||||
mock_resp.version = 42
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
config, version = await processor.fetch_config()
|
||||
|
||||
assert config == {"prompt": {"key": "val"}}
|
||||
assert version == 42
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_raises_on_error_response(self, processor):
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = Mock(message="not found")
|
||||
mock_resp.config = {}
|
||||
mock_resp.version = 0
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="Config error"):
|
||||
await processor.fetch_config()
|
||||
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_stops_client_on_exception(self, processor):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.side_effect = TimeoutError("timeout")
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
with pytest.raises(TimeoutError):
|
||||
await processor.fetch_config()
|
||||
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestFetchAndApplyConfig:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_applies_config_to_all_handlers(self, processor):
|
||||
h1 = AsyncMock()
|
||||
h2 = AsyncMock()
|
||||
processor.register_config_handler(h1, types=["prompt"])
|
||||
processor.register_config_handler(h2, types=["schema"])
|
||||
|
||||
mock_config = {"prompt": {}, "schema": {}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 10)
|
||||
):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
# On startup, all handlers are invoked regardless of type
|
||||
h1.assert_called_once_with(mock_config, 10)
|
||||
h2.assert_called_once_with(mock_config, 10)
|
||||
assert processor.config_version == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_failure(self, processor):
|
||||
call_count = 0
|
||||
mock_config = {"prompt": {}}
|
||||
|
||||
async def mock_fetch():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
raise RuntimeError("not ready")
|
||||
return mock_config, 5
|
||||
|
||||
with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \
|
||||
patch('asyncio.sleep', new_callable=AsyncMock):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
assert call_count == 3
|
||||
assert processor.config_version == 5
|
||||
|
|
@ -35,7 +35,9 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
mock_async_init.assert_called_once()
|
||||
|
||||
# Verify register_config_handler was called with the correct handler
|
||||
mock_register_config.assert_called_once_with(processor.on_configure_flows)
|
||||
mock_register_config.assert_called_once_with(
|
||||
processor.on_configure_flows, types=["active-flow"]
|
||||
)
|
||||
|
||||
# Verify FlowProcessor-specific initialization
|
||||
assert hasattr(processor, 'flows')
|
||||
|
|
|
|||
|
|
@ -61,23 +61,21 @@ async def test_subscriber_deferred_acknowledgment_success():
|
|||
max_size=10,
|
||||
backpressure_strategy="block"
|
||||
)
|
||||
|
||||
# Start subscriber to initialize consumer
|
||||
await subscriber.start()
|
||||
|
||||
subscriber.consumer = mock_consumer
|
||||
|
||||
# Create queue for subscription
|
||||
queue = await subscriber.subscribe("test-queue")
|
||||
|
||||
|
||||
# Create mock message with matching queue name
|
||||
msg = create_mock_message("test-queue", {"data": "test"})
|
||||
|
||||
|
||||
# Process message
|
||||
await subscriber._process_message(msg)
|
||||
|
||||
|
||||
# Should acknowledge successful delivery
|
||||
mock_consumer.acknowledge.assert_called_once_with(msg)
|
||||
mock_consumer.negative_acknowledge.assert_not_called()
|
||||
|
||||
|
||||
# Message should be in queue
|
||||
assert not queue.empty()
|
||||
received_msg = await queue.get()
|
||||
|
|
@ -108,9 +106,7 @@ async def test_subscriber_dropped_message_still_acks():
|
|||
max_size=1, # Very small queue
|
||||
backpressure_strategy="drop_new"
|
||||
)
|
||||
|
||||
# Start subscriber to initialize consumer
|
||||
await subscriber.start()
|
||||
subscriber.consumer = mock_consumer
|
||||
|
||||
# Create queue and fill it
|
||||
queue = await subscriber.subscribe("test-queue")
|
||||
|
|
@ -151,9 +147,7 @@ async def test_subscriber_orphaned_message_acks():
|
|||
max_size=10,
|
||||
backpressure_strategy="block"
|
||||
)
|
||||
|
||||
# Start subscriber to initialize consumer
|
||||
await subscriber.start()
|
||||
subscriber.consumer = mock_consumer
|
||||
|
||||
# Don't create any queues - message will be orphaned
|
||||
# This simulates a response arriving after the waiter has unsubscribed
|
||||
|
|
@ -189,9 +183,7 @@ async def test_subscriber_backpressure_strategies():
|
|||
max_size=2,
|
||||
backpressure_strategy="drop_oldest"
|
||||
)
|
||||
|
||||
# Start subscriber to initialize consumer
|
||||
await subscriber.start()
|
||||
subscriber.consumer = mock_consumer
|
||||
|
||||
queue = await subscriber.subscribe("test-queue")
|
||||
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ class MockAsyncProcessor:
|
|||
class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Recursive chunker functionality"""
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_processor_initialization_basic(self, mock_producer, mock_consumer):
|
||||
"""Test basic processor initialization"""
|
||||
|
|
@ -51,8 +51,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
|
||||
assert len(param_specs) == 2
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with chunk-size parameter override"""
|
||||
|
|
@ -71,7 +71,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": 2000, # Override chunk size
|
||||
"chunk-overlap": None # Use default chunk overlap
|
||||
}.get(param)
|
||||
|
|
@ -85,8 +85,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 2000 # Should use overridden value
|
||||
assert chunk_overlap == 100 # Should use default value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with chunk-overlap parameter override"""
|
||||
|
|
@ -105,7 +105,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": None, # Use default chunk size
|
||||
"chunk-overlap": 200 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
|
@ -119,8 +119,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 1000 # Should use default value
|
||||
assert chunk_overlap == 200 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
|
||||
|
|
@ -139,7 +139,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": 1500, # Override chunk size
|
||||
"chunk-overlap": 150 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
|
@ -153,8 +153,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 1500 # Should use overridden value
|
||||
assert chunk_overlap == 150 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
|
||||
|
|
@ -177,7 +177,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Mock save_child_document to avoid waiting for librarian response
|
||||
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
|
||||
# Mock message with TextDocument
|
||||
mock_message = MagicMock()
|
||||
|
|
@ -196,12 +196,14 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_producer = AsyncMock()
|
||||
mock_triples_producer = AsyncMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": 1500,
|
||||
"chunk-overlap": 150,
|
||||
}.get(param)
|
||||
mock_flow.side_effect = lambda name: {
|
||||
"output": mock_producer,
|
||||
"triples": mock_triples_producer,
|
||||
}.get(param)
|
||||
}.get(name)
|
||||
|
||||
# Act
|
||||
await processor.on_message(mock_message, mock_consumer, mock_flow)
|
||||
|
|
@ -219,8 +221,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
sent_chunk = mock_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_chunk, Chunk)
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document when no parameters are overridden (flow returns None)"""
|
||||
|
|
@ -239,7 +241,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.return_value = None # No overrides
|
||||
mock_flow.parameters.get.return_value = None # No overrides
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ class MockAsyncProcessor:
|
|||
class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Token chunker functionality"""
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_processor_initialization_basic(self, mock_producer, mock_consumer):
|
||||
"""Test basic processor initialization"""
|
||||
|
|
@ -51,8 +51,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
|
||||
assert len(param_specs) == 2
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with chunk-size parameter override"""
|
||||
|
|
@ -71,7 +71,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": 400, # Override chunk size
|
||||
"chunk-overlap": None # Use default chunk overlap
|
||||
}.get(param)
|
||||
|
|
@ -85,8 +85,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 400 # Should use overridden value
|
||||
assert chunk_overlap == 15 # Should use default value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with chunk-overlap parameter override"""
|
||||
|
|
@ -105,7 +105,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": None, # Use default chunk size
|
||||
"chunk-overlap": 25 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
|
@ -119,8 +119,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 250 # Should use default value
|
||||
assert chunk_overlap == 25 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
|
||||
|
|
@ -139,7 +139,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": 350, # Override chunk size
|
||||
"chunk-overlap": 30 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
|
@ -153,8 +153,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 350 # Should use overridden value
|
||||
assert chunk_overlap == 30 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.chunking.token.chunker.TokenTextSplitter')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
|
||||
|
|
@ -177,7 +177,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Mock save_child_document to avoid librarian producer interactions
|
||||
processor.save_child_document = AsyncMock(return_value="chunk-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="chunk-id")
|
||||
|
||||
# Mock message with TextDocument
|
||||
mock_message = MagicMock()
|
||||
|
|
@ -196,12 +196,14 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_producer = AsyncMock()
|
||||
mock_triples_producer = AsyncMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": 400,
|
||||
"chunk-overlap": 40,
|
||||
}.get(param)
|
||||
mock_flow.side_effect = lambda name: {
|
||||
"output": mock_producer,
|
||||
"triples": mock_triples_producer,
|
||||
}.get(param)
|
||||
}.get(name)
|
||||
|
||||
# Act
|
||||
await processor.on_message(mock_message, mock_consumer, mock_flow)
|
||||
|
|
@ -223,8 +225,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
sent_chunk = mock_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_chunk, Chunk)
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document when no parameters are overridden (flow returns None)"""
|
||||
|
|
@ -243,7 +245,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.return_value = None # No overrides
|
||||
mock_flow.parameters.get.return_value = None # No overrides
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
@ -254,8 +256,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 250 # Should use default value
|
||||
assert chunk_overlap == 15 # Should use default value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_token_chunker_uses_different_defaults(self, mock_producer, mock_consumer):
|
||||
"""Test that token chunker has different defaults than recursive chunker"""
|
||||
|
|
|
|||
|
|
@ -21,17 +21,15 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
|
||||
# Act
|
||||
client = DocumentEmbeddingsClient(
|
||||
log_level=1,
|
||||
subscriber="test-subscriber",
|
||||
input_queue="test-input",
|
||||
output_queue="test-output",
|
||||
pulsar_host="pulsar://test:6650",
|
||||
pulsar_api_key="test-key"
|
||||
)
|
||||
|
||||
|
||||
# Assert
|
||||
mock_base_init.assert_called_once_with(
|
||||
log_level=1,
|
||||
subscriber="test-subscriber",
|
||||
input_queue="test-input",
|
||||
output_queue="test-output",
|
||||
|
|
|
|||
|
|
@ -81,9 +81,8 @@ class TestTaskGroupConcurrency:
|
|||
|
||||
# Track how many consume_from_queue calls are made
|
||||
call_count = 0
|
||||
original_running = True
|
||||
|
||||
async def mock_consume():
|
||||
async def mock_consume(backend_consumer, executor=None):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
# Wait a bit to let all tasks start, then signal stop
|
||||
|
|
@ -107,7 +106,7 @@ class TestTaskGroupConcurrency:
|
|||
consumer = _make_consumer(concurrency=1)
|
||||
call_count = 0
|
||||
|
||||
async def mock_consume():
|
||||
async def mock_consume(backend_consumer, executor=None):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -147,7 +146,7 @@ class TestRateLimitRetry:
|
|||
mock_msg = _make_msg()
|
||||
consumer.consumer = MagicMock()
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
|
||||
|
||||
assert call_count == 2
|
||||
consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
|
||||
|
|
@ -166,7 +165,7 @@ class TestRateLimitRetry:
|
|||
mock_msg = _make_msg()
|
||||
consumer.consumer = MagicMock()
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
|
||||
|
||||
consumer.consumer.negative_acknowledge.assert_called_with(mock_msg)
|
||||
consumer.consumer.acknowledge.assert_not_called()
|
||||
|
|
@ -185,7 +184,7 @@ class TestRateLimitRetry:
|
|||
mock_msg = _make_msg()
|
||||
consumer.consumer = MagicMock()
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
|
||||
|
||||
assert call_count == 1
|
||||
consumer.consumer.negative_acknowledge.assert_called_once_with(mock_msg)
|
||||
|
|
@ -197,7 +196,7 @@ class TestRateLimitRetry:
|
|||
mock_msg = _make_msg()
|
||||
consumer.consumer = MagicMock()
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
|
||||
|
||||
consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
|
||||
|
||||
|
|
@ -219,7 +218,7 @@ class TestMetricsIntegration:
|
|||
mock_metrics.record_time.return_value.__exit__ = MagicMock()
|
||||
consumer.metrics = mock_metrics
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
|
||||
|
||||
mock_metrics.process.assert_called_once_with("success")
|
||||
|
||||
|
|
@ -235,7 +234,7 @@ class TestMetricsIntegration:
|
|||
mock_metrics = MagicMock()
|
||||
consumer.metrics = mock_metrics
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
|
||||
|
||||
mock_metrics.process.assert_called_once_with("error")
|
||||
|
||||
|
|
@ -261,7 +260,7 @@ class TestMetricsIntegration:
|
|||
mock_metrics.record_time.return_value.__exit__ = MagicMock(return_value=False)
|
||||
consumer.metrics = mock_metrics
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
|
||||
|
||||
mock_metrics.rate_limit.assert_called_once()
|
||||
|
||||
|
|
@ -294,9 +293,8 @@ class TestPollTimeout:
|
|||
raise type('Timeout', (Exception,), {})("timeout")
|
||||
|
||||
mock_pulsar_consumer.receive = capture_receive
|
||||
consumer.consumer = mock_pulsar_consumer
|
||||
|
||||
await consumer.consume_from_queue()
|
||||
await consumer.consume_from_queue(mock_pulsar_consumer)
|
||||
|
||||
assert received_kwargs.get("timeout_millis") == 100
|
||||
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ class MockAsyncProcessor:
|
|||
class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
||||
"""Test Mistral OCR processor functionality"""
|
||||
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_processor_initialization_with_api_key(
|
||||
|
|
@ -51,8 +51,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
|||
assert consumer_specs[0].name == "input"
|
||||
assert consumer_specs[0].schema == Document
|
||||
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_processor_initialization_without_api_key(
|
||||
self, mock_producer, mock_consumer
|
||||
|
|
@ -66,8 +66,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
|||
with pytest.raises(RuntimeError, match="Mistral API key not specified"):
|
||||
Processor(**config)
|
||||
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_ocr_single_chunk(
|
||||
|
|
@ -131,8 +131,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
|||
)
|
||||
mock_mistral.ocr.process.assert_called_once()
|
||||
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_success(
|
||||
|
|
@ -172,7 +172,7 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
|||
]
|
||||
|
||||
# Mock save_child_document
|
||||
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
|
||||
with patch.object(processor, 'ocr', return_value=ocr_result):
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
|
|
|||
|
|
@ -24,12 +24,10 @@ class MockAsyncProcessor:
|
|||
class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
||||
"""Test PDF decoder processor functionality"""
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_processor_initialization(self, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
|
||||
async def test_processor_initialization(self, mock_producer, mock_consumer):
|
||||
"""Test PDF decoder processor initialization"""
|
||||
config = {
|
||||
'id': 'test-pdf-decoder',
|
||||
|
|
@ -44,13 +42,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
|||
assert consumer_specs[0].name == "input"
|
||||
assert consumer_specs[0].schema == Document
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
|
||||
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
||||
"""Test successful PDF processing"""
|
||||
# Mock PDF content
|
||||
pdf_content = b"fake pdf content"
|
||||
|
|
@ -85,7 +81,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Mock save_child_document to avoid waiting for librarian response
|
||||
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
|
|
@ -94,13 +90,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
|||
# Verify triples were sent for each page (provenance)
|
||||
assert mock_triples_flow.send.call_count == 2
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
|
||||
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
||||
"""Test handling of empty PDF"""
|
||||
pdf_content = b"fake pdf content"
|
||||
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
|
||||
|
|
@ -128,13 +122,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
|||
|
||||
mock_output_flow.send.assert_not_called()
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
|
||||
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
||||
"""Test handling of unicode content in PDF"""
|
||||
pdf_content = b"fake pdf content"
|
||||
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
|
||||
|
|
@ -165,7 +157,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Mock save_child_document to avoid waiting for librarian response
|
||||
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
|
|
|
|||
|
|
@ -142,8 +142,8 @@ class TestPageBasedFormats:
|
|||
class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
||||
"""Test universal decoder processor."""
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_processor_initialization(
|
||||
self, mock_producer, mock_consumer
|
||||
|
|
@ -169,8 +169,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
assert consumer_specs[0].name == "input"
|
||||
assert consumer_specs[0].schema == Document
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_processor_custom_strategy(
|
||||
self, mock_producer, mock_consumer
|
||||
|
|
@ -188,8 +188,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
assert processor.partition_strategy == "hi_res"
|
||||
assert processor.section_strategy_name == "heading"
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_group_by_page(self, mock_producer, mock_consumer):
|
||||
"""Test page grouping of elements."""
|
||||
|
|
@ -214,8 +214,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
assert result[1][0] == 2
|
||||
assert len(result[1][1]) == 1
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.universal.processor.partition')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_inline_non_page(
|
||||
|
|
@ -255,7 +255,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
}.get(name))
|
||||
|
||||
# Mock save_child_document and magic
|
||||
processor.save_child_document = AsyncMock(return_value="mock-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
|
||||
|
||||
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
|
||||
mock_magic.from_buffer.return_value = "text/markdown"
|
||||
|
|
@ -271,8 +271,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
assert call_args.document_id.startswith("urn:section:")
|
||||
assert call_args.text == b""
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.universal.processor.partition')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_page_based(
|
||||
|
|
@ -310,7 +310,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
"triples": mock_triples_flow,
|
||||
}.get(name))
|
||||
|
||||
processor.save_child_document = AsyncMock(return_value="mock-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
|
||||
|
||||
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
|
||||
mock_magic.from_buffer.return_value = "application/pdf"
|
||||
|
|
@ -323,8 +323,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
call_args = mock_output_flow.send.call_args_list[0][0][0]
|
||||
assert call_args.document_id.startswith("urn:page:")
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.universal.processor.partition')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_images_stored_not_emitted(
|
||||
|
|
@ -361,7 +361,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
"triples": mock_triples_flow,
|
||||
}.get(name))
|
||||
|
||||
processor.save_child_document = AsyncMock(return_value="mock-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
|
||||
|
||||
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
|
||||
mock_magic.from_buffer.return_value = "application/pdf"
|
||||
|
|
@ -374,7 +374,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
assert mock_triples_flow.send.call_count == 2
|
||||
|
||||
# save_child_document called twice (page + image)
|
||||
assert processor.save_child_document.call_count == 2
|
||||
assert processor.librarian.save_child_document.call_count == 2
|
||||
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
|
||||
def test_add_args(self, mock_parent_add_args):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ Tests for Gateway Config Receiver
|
|||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import Mock, patch, Mock, MagicMock
|
||||
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
||||
import uuid
|
||||
|
||||
from trustgraph.gateway.config.receiver import ConfigReceiver
|
||||
|
|
@ -23,174 +23,237 @@ class TestConfigReceiver:
|
|||
def test_config_receiver_initialization(self):
|
||||
"""Test ConfigReceiver initialization"""
|
||||
mock_backend = Mock()
|
||||
|
||||
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
|
||||
assert config_receiver.backend == mock_backend
|
||||
assert config_receiver.flow_handlers == []
|
||||
assert config_receiver.flows == {}
|
||||
assert config_receiver.config_version == 0
|
||||
|
||||
def test_add_handler(self):
|
||||
"""Test adding flow handlers"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
|
||||
handler1 = Mock()
|
||||
handler2 = Mock()
|
||||
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
|
||||
assert len(config_receiver.flow_handlers) == 2
|
||||
assert handler1 in config_receiver.flow_handlers
|
||||
assert handler2 in config_receiver.flow_handlers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_with_new_flows(self):
|
||||
"""Test on_config method with new flows"""
|
||||
async def test_on_config_notify_new_version(self):
|
||||
"""Test on_config_notify triggers fetch for newer version"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Track calls manually instead of using AsyncMock
|
||||
start_flow_calls = []
|
||||
|
||||
async def mock_start_flow(*args):
|
||||
start_flow_calls.append(args)
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
# Create mock message with flows
|
||||
config_receiver.config_version = 1
|
||||
|
||||
# Mock fetch_and_apply
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with newer version
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1", "steps": []}',
|
||||
"flow2": '{"name": "test_flow_2", "steps": []}'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify flows were added
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert config_receiver.flows["flow1"] == {"name": "test_flow_1", "steps": []}
|
||||
assert config_receiver.flows["flow2"] == {"name": "test_flow_2", "steps": []}
|
||||
|
||||
# Verify start_flow was called for each new flow
|
||||
assert len(start_flow_calls) == 2
|
||||
assert ("flow1", {"name": "test_flow_1", "steps": []}) in start_flow_calls
|
||||
assert ("flow2", {"name": "test_flow_2", "steps": []}) in start_flow_calls
|
||||
mock_msg.value.return_value = Mock(version=2, types=["flow"])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
assert len(fetch_calls) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_with_removed_flows(self):
|
||||
"""Test on_config method with removed flows"""
|
||||
async def test_on_config_notify_old_version_ignored(self):
|
||||
"""Test on_config_notify ignores older versions"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1", "steps": []},
|
||||
"flow2": {"name": "test_flow_2", "steps": []}
|
||||
}
|
||||
|
||||
# Track calls manually instead of using AsyncMock
|
||||
stop_flow_calls = []
|
||||
|
||||
async def mock_stop_flow(*args):
|
||||
stop_flow_calls.append(args)
|
||||
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
# Create mock message with only flow1 (flow2 removed)
|
||||
config_receiver.config_version = 5
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with older version
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1", "steps": []}'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify flow2 was removed
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" not in config_receiver.flows
|
||||
|
||||
# Verify stop_flow was called for removed flow
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0] == ("flow2", {"name": "test_flow_2", "steps": []})
|
||||
mock_msg.value.return_value = Mock(version=3, types=["flow"])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
assert len(fetch_calls) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_with_no_flows(self):
|
||||
"""Test on_config method with no flows in config"""
|
||||
async def test_on_config_notify_irrelevant_types_ignored(self):
|
||||
"""Test on_config_notify ignores types the gateway doesn't care about"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock the start_flow and stop_flow methods with async functions
|
||||
async def mock_start_flow(*args):
|
||||
pass
|
||||
async def mock_stop_flow(*args):
|
||||
pass
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
# Create mock message without flows
|
||||
config_receiver.config_version = 1
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with non-flow type
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify no flows were added
|
||||
assert config_receiver.flows == {}
|
||||
|
||||
# Since no flows were in the config, the flow methods shouldn't be called
|
||||
# (We can't easily assert this with simple async functions, but the test
|
||||
# passes if no exceptions are thrown)
|
||||
mock_msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
# Version should be updated but no fetch
|
||||
assert len(fetch_calls) == 0
|
||||
assert config_receiver.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_exception_handling(self):
|
||||
"""Test on_config method handles exceptions gracefully"""
|
||||
async def test_on_config_notify_flow_type_triggers_fetch(self):
|
||||
"""Test on_config_notify fetches for flow-related types"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Create mock message that will cause an exception
|
||||
config_receiver.config_version = 1
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
for type_name in ["flow", "active-flow"]:
|
||||
fetch_calls.clear()
|
||||
config_receiver.config_version = 1
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=2, types=[type_name])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_exception_handling(self):
|
||||
"""Test on_config_notify handles exceptions gracefully"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Create notify message that causes an exception
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.side_effect = Exception("Test exception")
|
||||
|
||||
# This should not raise an exception
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify flows remain empty
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_new_flows(self):
|
||||
"""Test fetch_and_apply starts new flows"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock _create_config_client to return a mock client
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1"}',
|
||||
"flow2": '{"name": "test_flow_2"}'
|
||||
}
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
start_flow_calls = []
|
||||
async def mock_start_flow(id, flow):
|
||||
start_flow_calls.append((id, flow))
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
|
||||
assert config_receiver.config_version == 5
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert len(start_flow_calls) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_removed_flows(self):
|
||||
"""Test fetch_and_apply stops removed flows"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"}
|
||||
}
|
||||
|
||||
# Config now only has flow1
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1"}'
|
||||
}
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
stop_flow_calls = []
|
||||
async def mock_stop_flow(id, flow):
|
||||
stop_flow_calls.append((id, flow))
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" not in config_receiver.flows
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0][0] == "flow2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_no_flows(self):
|
||||
"""Test fetch_and_apply with empty config"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 1
|
||||
mock_resp.config = {}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
|
||||
assert config_receiver.flows == {}
|
||||
assert config_receiver.config_version == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_with_handlers(self):
|
||||
"""Test start_flow method with multiple handlers"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Add mock handlers
|
||||
|
||||
handler1 = Mock()
|
||||
handler1.start_flow = Mock()
|
||||
handler2 = Mock()
|
||||
handler2.start_flow = Mock()
|
||||
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
|
||||
# Verify all handlers were called
|
||||
|
||||
handler1.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
|
|
@ -199,19 +262,17 @@ class TestConfigReceiver:
|
|||
"""Test start_flow method handles handler exceptions"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Add mock handler that raises exception
|
||||
|
||||
handler = Mock()
|
||||
handler.start_flow = Mock(side_effect=Exception("Handler error"))
|
||||
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# This should not raise an exception
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
|
||||
# Verify handler was called
|
||||
|
||||
handler.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -219,21 +280,19 @@ class TestConfigReceiver:
|
|||
"""Test stop_flow method with multiple handlers"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Add mock handlers
|
||||
|
||||
handler1 = Mock()
|
||||
handler1.stop_flow = Mock()
|
||||
handler2 = Mock()
|
||||
handler2.stop_flow = Mock()
|
||||
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
|
||||
# Verify all handlers were called
|
||||
|
||||
handler1.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
|
|
@ -242,167 +301,77 @@ class TestConfigReceiver:
|
|||
"""Test stop_flow method handles handler exceptions"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Add mock handler that raises exception
|
||||
|
||||
handler = Mock()
|
||||
handler.stop_flow = Mock(side_effect=Exception("Handler error"))
|
||||
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# This should not raise an exception
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
|
||||
# Verify handler was called
|
||||
|
||||
handler.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_loader_creates_consumer(self):
|
||||
"""Test config_loader method creates Pulsar consumer"""
|
||||
mock_backend = Mock()
|
||||
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
# Temporarily restore the real config_loader for this test
|
||||
config_receiver.config_loader = _real_config_loader.__get__(config_receiver)
|
||||
|
||||
# Mock Consumer class
|
||||
with patch('trustgraph.gateway.config.receiver.Consumer') as mock_consumer_class, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
mock_consumer = Mock()
|
||||
async def mock_start():
|
||||
pass
|
||||
mock_consumer.start = mock_start
|
||||
mock_consumer_class.return_value = mock_consumer
|
||||
|
||||
# Create a task that will complete quickly
|
||||
async def quick_task():
|
||||
await config_receiver.config_loader()
|
||||
|
||||
# Run the task with a timeout to prevent hanging
|
||||
try:
|
||||
await asyncio.wait_for(quick_task(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
# This is expected since the method runs indefinitely
|
||||
pass
|
||||
|
||||
# Verify Consumer was created with correct parameters
|
||||
mock_consumer_class.assert_called_once()
|
||||
call_args = mock_consumer_class.call_args
|
||||
|
||||
assert call_args[1]['backend'] == mock_backend
|
||||
assert call_args[1]['subscriber'] == "gateway-test-uuid"
|
||||
assert call_args[1]['handler'] == config_receiver.on_config
|
||||
assert call_args[1]['start_of_messages'] is True
|
||||
|
||||
@patch('asyncio.create_task')
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_creates_config_loader_task(self, mock_create_task):
|
||||
"""Test start method creates config loader task"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock create_task to avoid actually creating tasks with real coroutines
|
||||
|
||||
mock_task = Mock()
|
||||
mock_create_task.return_value = mock_task
|
||||
|
||||
|
||||
await config_receiver.start()
|
||||
|
||||
# Verify task was created
|
||||
|
||||
mock_create_task.assert_called_once()
|
||||
|
||||
# Verify the argument passed to create_task is a coroutine
|
||||
call_args = mock_create_task.call_args[0]
|
||||
assert len(call_args) == 1 # Should have one argument (the coroutine)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_mixed_flow_operations(self):
|
||||
"""Test on_config with mixed add/remove operations"""
|
||||
async def test_fetch_and_apply_mixed_flow_operations(self):
|
||||
"""Test fetch_and_apply with mixed add/remove operations"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
|
||||
# Pre-populate
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1", "steps": []},
|
||||
"flow2": {"name": "test_flow_2", "steps": []}
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"}
|
||||
}
|
||||
|
||||
# Track calls manually instead of using Mock
|
||||
start_flow_calls = []
|
||||
stop_flow_calls = []
|
||||
|
||||
async def mock_start_flow(*args):
|
||||
start_flow_calls.append(args)
|
||||
|
||||
async def mock_stop_flow(*args):
|
||||
stop_flow_calls.append(args)
|
||||
|
||||
# Directly assign to avoid patch.object detecting async methods
|
||||
original_start_flow = config_receiver.start_flow
|
||||
original_stop_flow = config_receiver.stop_flow
|
||||
|
||||
# Config removes flow1, keeps flow2, adds flow3
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow2": '{"name": "test_flow_2"}',
|
||||
"flow3": '{"name": "test_flow_3"}'
|
||||
}
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
start_calls = []
|
||||
stop_calls = []
|
||||
|
||||
async def mock_start_flow(id, flow):
|
||||
start_calls.append((id, flow))
|
||||
async def mock_stop_flow(id, flow):
|
||||
stop_calls.append((id, flow))
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
try:
|
||||
|
||||
# Create mock message with flow1 removed and flow3 added
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flow": {
|
||||
"flow2": '{"name": "test_flow_2", "steps": []}',
|
||||
"flow3": '{"name": "test_flow_3", "steps": []}'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify final state
|
||||
assert "flow1" not in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow3" in config_receiver.flows
|
||||
|
||||
# Verify operations
|
||||
assert len(start_flow_calls) == 1
|
||||
assert start_flow_calls[0] == ("flow3", {"name": "test_flow_3", "steps": []})
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0] == ("flow1", {"name": "test_flow_1", "steps": []})
|
||||
|
||||
finally:
|
||||
# Restore original methods
|
||||
config_receiver.start_flow = original_start_flow
|
||||
config_receiver.stop_flow = original_stop_flow
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_invalid_json_flow_data(self):
|
||||
"""Test on_config handles invalid JSON in flow data"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock the start_flow method with an async function
|
||||
async def mock_start_flow(*args):
|
||||
pass
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
# Create mock message with invalid JSON
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flow": {
|
||||
"flow1": '{"invalid": json}', # Invalid JSON
|
||||
"flow2": '{"name": "valid_flow", "steps": []}' # Valid JSON
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# This should handle the exception gracefully
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# The entire operation should fail due to JSON parsing error
|
||||
# So no flows should be added
|
||||
assert config_receiver.flows == {}
|
||||
await config_receiver.fetch_and_apply()
|
||||
|
||||
assert "flow1" not in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow3" in config_receiver.flows
|
||||
assert len(start_calls) == 1
|
||||
assert start_calls[0][0] == "flow3"
|
||||
assert len(stop_calls) == 1
|
||||
assert stop_calls[0][0] == "flow1"
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class TestConfigRequestor:
|
|||
mock_translator_registry.get_response_translator.return_value = Mock()
|
||||
|
||||
# Setup translator response
|
||||
mock_request_translator.to_pulsar.return_value = "translated_request"
|
||||
mock_request_translator.decode.return_value = "translated_request"
|
||||
|
||||
# Patch ServiceRequestor async methods with regular mocks (not AsyncMock)
|
||||
with patch.object(ServiceRequestor, 'start', return_value=None), \
|
||||
|
|
@ -64,7 +64,7 @@ class TestConfigRequestor:
|
|||
result = requestor.to_request({"test": "body"})
|
||||
|
||||
# Verify translator was called correctly
|
||||
mock_request_translator.to_pulsar.assert_called_once_with({"test": "body"})
|
||||
mock_request_translator.decode.assert_called_once_with({"test": "body"})
|
||||
assert result == "translated_request"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.config.TranslatorRegistry')
|
||||
|
|
@ -76,7 +76,7 @@ class TestConfigRequestor:
|
|||
mock_translator_registry.get_response_translator.return_value = mock_response_translator
|
||||
|
||||
# Setup translator response
|
||||
mock_response_translator.from_response_with_completion.return_value = "translated_response"
|
||||
mock_response_translator.encode_with_completion.return_value = "translated_response"
|
||||
|
||||
requestor = ConfigRequestor(
|
||||
backend=Mock(),
|
||||
|
|
@ -89,5 +89,5 @@ class TestConfigRequestor:
|
|||
result = requestor.from_response(mock_message)
|
||||
|
||||
# Verify translator was called correctly
|
||||
mock_response_translator.from_response_with_completion.assert_called_once_with(mock_message)
|
||||
mock_response_translator.encode_with_completion.assert_called_once_with(mock_message)
|
||||
assert result == "translated_response"
|
||||
359
tests/unit/test_gateway/test_explain_triples.py
Normal file
359
tests/unit/test_gateway/test_explain_triples.py
Normal file
|
|
@ -0,0 +1,359 @@
|
|||
"""
|
||||
Tests for inline explainability triples in response translators
|
||||
and ProvenanceEvent parsing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from trustgraph.schema import (
|
||||
GraphRagResponse, DocumentRagResponse, AgentResponse,
|
||||
Term, Triple, IRI, LITERAL, Error,
|
||||
)
|
||||
from trustgraph.messaging.translators.retrieval import (
|
||||
GraphRagResponseTranslator,
|
||||
DocumentRagResponseTranslator,
|
||||
)
|
||||
from trustgraph.messaging.translators.agent import (
|
||||
AgentResponseTranslator,
|
||||
)
|
||||
from trustgraph.api.types import ProvenanceEvent
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
def make_triple(s_iri, p_iri, o_value, o_type=LITERAL):
|
||||
"""Create a Triple with IRI subject/predicate and typed object."""
|
||||
o = Term(type=IRI, iri=o_value) if o_type == IRI else Term(type=LITERAL, value=o_value)
|
||||
return Triple(
|
||||
s=Term(type=IRI, iri=s_iri),
|
||||
p=Term(type=IRI, iri=p_iri),
|
||||
o=o,
|
||||
)
|
||||
|
||||
|
||||
def sample_triples():
|
||||
"""A few provenance triples for a question entity."""
|
||||
return [
|
||||
make_triple(
|
||||
"urn:trustgraph:question:abc123",
|
||||
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
|
||||
"https://trustgraph.ai/ns/GraphRagQuestion",
|
||||
o_type=IRI,
|
||||
),
|
||||
make_triple(
|
||||
"urn:trustgraph:question:abc123",
|
||||
"https://trustgraph.ai/ns/query",
|
||||
"What is the internet?",
|
||||
),
|
||||
make_triple(
|
||||
"urn:trustgraph:question:abc123",
|
||||
"http://www.w3.org/ns/prov#startedAtTime",
|
||||
"2026-04-07T09:00:00Z",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# --- GraphRag Translator ---
|
||||
|
||||
class TestGraphRagExplainTriples:
|
||||
|
||||
def test_explain_triples_encoded(self):
|
||||
translator = GraphRagResponseTranslator()
|
||||
triples = sample_triples()
|
||||
|
||||
response = GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=triples,
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
|
||||
assert "explain_triples" in result
|
||||
assert len(result["explain_triples"]) == 3
|
||||
|
||||
# Check first triple is properly encoded
|
||||
t = result["explain_triples"][0]
|
||||
assert t["s"]["t"] == "i"
|
||||
assert t["s"]["i"] == "urn:trustgraph:question:abc123"
|
||||
assert t["p"]["t"] == "i"
|
||||
|
||||
def test_explain_triples_empty_not_included(self):
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
||||
response = GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="Some answer text",
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
|
||||
assert "explain_triples" not in result
|
||||
|
||||
def test_explain_with_completion_returns_not_final(self):
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
||||
response = GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
explain_triples=sample_triples(),
|
||||
end_of_session=False,
|
||||
)
|
||||
|
||||
result, is_final = translator.encode_with_completion(response)
|
||||
assert is_final is False
|
||||
|
||||
def test_explain_id_and_graph_included(self):
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
||||
response = GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=sample_triples(),
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
assert result["explain_id"] == "urn:trustgraph:question:abc123"
|
||||
assert result["explain_graph"] == "urn:graph:retrieval"
|
||||
|
||||
|
||||
# --- DocumentRag Translator ---
|
||||
|
||||
class TestDocumentRagExplainTriples:
|
||||
|
||||
def test_explain_triples_encoded(self):
|
||||
translator = DocumentRagResponseTranslator()
|
||||
|
||||
response = DocumentRagResponse(
|
||||
response=None,
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:docrag:abc123",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=sample_triples(),
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
|
||||
assert "explain_triples" in result
|
||||
assert len(result["explain_triples"]) == 3
|
||||
|
||||
def test_explain_triples_empty_not_included(self):
|
||||
translator = DocumentRagResponseTranslator()
|
||||
|
||||
response = DocumentRagResponse(
|
||||
response="Answer text",
|
||||
message_type="chunk",
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
assert "explain_triples" not in result
|
||||
|
||||
|
||||
# --- Agent Translator ---
|
||||
|
||||
class TestAgentExplainTriples:
|
||||
|
||||
def test_explain_triples_encoded(self):
|
||||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="explain",
|
||||
content="",
|
||||
explain_id="urn:trustgraph:agent:session:abc123",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=sample_triples(),
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
|
||||
assert "explain_triples" in result
|
||||
assert len(result["explain_triples"]) == 3
|
||||
|
||||
t = result["explain_triples"][1]
|
||||
assert t["p"]["i"] == "https://trustgraph.ai/ns/query"
|
||||
assert t["o"]["t"] == "l"
|
||||
assert t["o"]["v"] == "What is the internet?"
|
||||
|
||||
def test_explain_triples_empty_not_included(self):
|
||||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="thought",
|
||||
content="I need to think...",
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
assert "explain_triples" not in result
|
||||
|
||||
def test_explain_with_completion_not_final(self):
|
||||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="explain",
|
||||
explain_id="urn:trustgraph:agent:session:abc123",
|
||||
explain_triples=sample_triples(),
|
||||
end_of_dialog=False,
|
||||
)
|
||||
|
||||
result, is_final = translator.encode_with_completion(response)
|
||||
assert is_final is False
|
||||
|
||||
def test_explain_with_completion_final(self):
|
||||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="answer",
|
||||
content="The answer is...",
|
||||
end_of_dialog=True,
|
||||
)
|
||||
|
||||
result, is_final = translator.encode_with_completion(response)
|
||||
assert is_final is True
|
||||
|
||||
|
||||
# --- ProvenanceEvent ---
|
||||
|
||||
class TestProvenanceEvent:
|
||||
|
||||
def test_question_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
)
|
||||
assert event.event_type == "question"
|
||||
|
||||
def test_exploration_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:exploration:abc123",
|
||||
)
|
||||
assert event.event_type == "exploration"
|
||||
|
||||
def test_focus_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:focus:abc123",
|
||||
)
|
||||
assert event.event_type == "focus"
|
||||
|
||||
def test_synthesis_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:synthesis:abc123",
|
||||
)
|
||||
assert event.event_type == "synthesis"
|
||||
|
||||
def test_grounding_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:grounding:abc123",
|
||||
)
|
||||
assert event.event_type == "grounding"
|
||||
|
||||
def test_session_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:session:abc123",
|
||||
)
|
||||
assert event.event_type == "session"
|
||||
|
||||
def test_iteration_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:iteration:abc123:1",
|
||||
)
|
||||
assert event.event_type == "iteration"
|
||||
|
||||
def test_observation_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:observation:abc123:1",
|
||||
)
|
||||
assert event.event_type == "observation"
|
||||
|
||||
def test_conclusion_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:conclusion:abc123",
|
||||
)
|
||||
assert event.event_type == "conclusion"
|
||||
|
||||
def test_decomposition_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:decomposition:abc123",
|
||||
)
|
||||
assert event.event_type == "decomposition"
|
||||
|
||||
def test_finding_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:finding:abc123:0",
|
||||
)
|
||||
assert event.event_type == "finding"
|
||||
|
||||
def test_plan_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:plan:abc123",
|
||||
)
|
||||
assert event.event_type == "plan"
|
||||
|
||||
def test_step_result_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:step-result:abc123:0",
|
||||
)
|
||||
assert event.event_type == "step-result"
|
||||
|
||||
def test_defaults(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
)
|
||||
assert event.entity is None
|
||||
assert event.triples == []
|
||||
assert event.explain_graph == ""
|
||||
|
||||
def test_with_triples(self):
|
||||
raw = [{"s": {"t": "i", "i": "urn:x"}, "p": {"t": "i", "i": "urn:y"}, "o": {"t": "l", "v": "z"}}]
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
triples=raw,
|
||||
)
|
||||
assert len(event.triples) == 1
|
||||
|
||||
|
||||
# --- Build ProvenanceEvent with entity parsing ---
|
||||
|
||||
class TestBuildProvenanceEvent:
|
||||
|
||||
def _make_client(self):
|
||||
"""Create a minimal WebSocketClient-like object with _build_provenance_event."""
|
||||
from trustgraph.api.socket_client import WebSocketClient
|
||||
# We can't instantiate WebSocketClient easily, so test the method logic directly
|
||||
return None
|
||||
|
||||
def test_entity_parsed_from_wire_triples(self):
|
||||
"""Test that wire-format triples are parsed into an ExplainEntity."""
|
||||
from trustgraph.api.explainability import ExplainEntity
|
||||
|
||||
wire_triples = [
|
||||
{
|
||||
"s": {"t": "i", "i": "urn:trustgraph:question:abc123"},
|
||||
"p": {"t": "i", "i": "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"},
|
||||
"o": {"t": "i", "i": "https://trustgraph.ai/ns/GraphRagQuestion"},
|
||||
},
|
||||
{
|
||||
"s": {"t": "i", "i": "urn:trustgraph:question:abc123"},
|
||||
"p": {"t": "i", "i": "https://trustgraph.ai/ns/query"},
|
||||
"o": {"t": "l", "v": "What is the internet?"},
|
||||
},
|
||||
]
|
||||
|
||||
# Parse triples the same way _build_provenance_event does
|
||||
parsed = []
|
||||
for t in wire_triples:
|
||||
s = t.get("s", {}).get("i", "")
|
||||
p = t.get("p", {}).get("i", "")
|
||||
o_term = t.get("o", {})
|
||||
if o_term.get("t") == "i":
|
||||
o = o_term.get("i", "")
|
||||
else:
|
||||
o = o_term.get("v", "")
|
||||
parsed.append((s, p, o))
|
||||
|
||||
entity = ExplainEntity.from_triples(
|
||||
"urn:trustgraph:question:abc123", parsed
|
||||
)
|
||||
|
||||
assert entity.entity_type == "question"
|
||||
assert entity.query == "What is the internet?"
|
||||
assert entity.question_type == "graph-rag"
|
||||
|
|
@ -25,7 +25,7 @@ from trustgraph.schema import (
|
|||
class TestGraphRagResponseTranslator:
|
||||
"""Test GraphRagResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_with_empty_response(self):
|
||||
def test_encode_with_empty_response(self):
|
||||
"""Test that empty response strings are preserved"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
|
@ -36,14 +36,14 @@ class TestGraphRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert - Empty string should be included in result
|
||||
assert "response" in result
|
||||
assert result["response"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_with_non_empty_response(self):
|
||||
def test_encode_with_non_empty_response(self):
|
||||
"""Test that non-empty responses work correctly"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
|
@ -54,13 +54,13 @@ class TestGraphRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert result["response"] == "Some text"
|
||||
assert result["end_of_stream"] is False
|
||||
|
||||
def test_from_pulsar_with_none_response(self):
|
||||
def test_encode_with_none_response(self):
|
||||
"""Test that None response is handled correctly"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
|
@ -71,14 +71,14 @@ class TestGraphRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert - None should not be included
|
||||
assert "response" not in result
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_response_with_completion_returns_correct_flag(self):
|
||||
"""Test that from_response_with_completion returns correct is_final flag"""
|
||||
def test_encode_with_completion_returns_correct_flag(self):
|
||||
"""Test that encode_with_completion returns correct is_final flag"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
||||
|
|
@ -90,7 +90,7 @@ class TestGraphRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result, is_final = translator.from_response_with_completion(response_chunk)
|
||||
result, is_final = translator.encode_with_completion(response_chunk)
|
||||
|
||||
# Assert
|
||||
assert is_final is False
|
||||
|
|
@ -105,7 +105,7 @@ class TestGraphRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result, is_final = translator.from_response_with_completion(final_response)
|
||||
result, is_final = translator.encode_with_completion(final_response)
|
||||
|
||||
# Assert - is_final is based on end_of_session, not end_of_stream
|
||||
assert is_final is True
|
||||
|
|
@ -116,7 +116,7 @@ class TestGraphRagResponseTranslator:
|
|||
class TestDocumentRagResponseTranslator:
|
||||
"""Test DocumentRagResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_with_empty_response(self):
|
||||
def test_encode_with_empty_response(self):
|
||||
"""Test that empty response strings are preserved"""
|
||||
# Arrange
|
||||
translator = DocumentRagResponseTranslator()
|
||||
|
|
@ -127,14 +127,14 @@ class TestDocumentRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert "response" in result
|
||||
assert result["response"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_with_non_empty_response(self):
|
||||
def test_encode_with_non_empty_response(self):
|
||||
"""Test that non-empty responses work correctly"""
|
||||
# Arrange
|
||||
translator = DocumentRagResponseTranslator()
|
||||
|
|
@ -145,7 +145,7 @@ class TestDocumentRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert result["response"] == "Document content"
|
||||
|
|
@ -155,7 +155,7 @@ class TestDocumentRagResponseTranslator:
|
|||
class TestPromptResponseTranslator:
|
||||
"""Test PromptResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_with_empty_text(self):
|
||||
def test_encode_with_empty_text(self):
|
||||
"""Test that empty text strings are preserved"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
|
|
@ -167,14 +167,14 @@ class TestPromptResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert "text" in result
|
||||
assert result["text"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_with_non_empty_text(self):
|
||||
def test_encode_with_non_empty_text(self):
|
||||
"""Test that non-empty text works correctly"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
|
|
@ -186,13 +186,13 @@ class TestPromptResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert result["text"] == "Some prompt response"
|
||||
assert result["end_of_stream"] is False
|
||||
|
||||
def test_from_pulsar_with_none_text(self):
|
||||
def test_encode_with_none_text(self):
|
||||
"""Test that None text is handled correctly"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
|
|
@ -204,14 +204,14 @@ class TestPromptResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert "text" not in result
|
||||
assert "object" in result
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_includes_end_of_stream(self):
|
||||
def test_encode_includes_end_of_stream(self):
|
||||
"""Test that end_of_stream flag is always included"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
|
|
@ -225,7 +225,7 @@ class TestPromptResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert "end_of_stream" in result
|
||||
|
|
@ -235,7 +235,7 @@ class TestPromptResponseTranslator:
|
|||
class TestTextCompletionResponseTranslator:
|
||||
"""Test TextCompletionResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_always_includes_response(self):
|
||||
def test_encode_always_includes_response(self):
|
||||
"""Test that response field is always included, even if empty"""
|
||||
# Arrange
|
||||
translator = TextCompletionResponseTranslator()
|
||||
|
|
@ -249,13 +249,13 @@ class TestTextCompletionResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert - Response should always be present
|
||||
assert "response" in result
|
||||
assert result["response"] == ""
|
||||
|
||||
def test_from_response_with_completion_with_empty_final(self):
|
||||
def test_encode_with_completion_with_empty_final(self):
|
||||
"""Test that empty final response is handled correctly"""
|
||||
# Arrange
|
||||
translator = TextCompletionResponseTranslator()
|
||||
|
|
@ -269,7 +269,7 @@ class TestTextCompletionResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result, is_final = translator.from_response_with_completion(response)
|
||||
result, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True
|
||||
|
|
@ -297,7 +297,7 @@ class TestStreamingProtocolCompliance:
|
|||
response = response_class(**kwargs)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert field_name in result, f"{translator_class.__name__} should include '{field_name}' field even when empty"
|
||||
|
|
@ -320,7 +320,7 @@ class TestStreamingProtocolCompliance:
|
|||
response = response_class(**kwargs)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert "end_of_stream" in result, f"{translator_class.__name__} should include 'end_of_stream' flag"
|
||||
|
|
|
|||
54
tests/unit/test_gateway/test_text_document_translator.py
Normal file
54
tests/unit/test_gateway/test_text_document_translator.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
"""
|
||||
Unit tests for text document gateway translation compatibility.
|
||||
"""
|
||||
|
||||
import base64
|
||||
|
||||
from trustgraph.messaging.translators.document_loading import TextDocumentTranslator
|
||||
|
||||
|
||||
class TestTextDocumentTranslator:
|
||||
def test_decode_decodes_base64_text(self):
|
||||
translator = TextDocumentTranslator()
|
||||
payload = "Cancer survival: 2.74× higher hazard ratio"
|
||||
|
||||
msg = translator.decode(
|
||||
{
|
||||
"id": "doc-1",
|
||||
"user": "alice",
|
||||
"collection": "research",
|
||||
"charset": "utf-8",
|
||||
"text": base64.b64encode(payload.encode("utf-8")).decode("ascii"),
|
||||
}
|
||||
)
|
||||
|
||||
assert msg.metadata.id == "doc-1"
|
||||
assert msg.metadata.user == "alice"
|
||||
assert msg.metadata.collection == "research"
|
||||
assert msg.text == payload.encode("utf-8")
|
||||
|
||||
def test_decode_accepts_raw_utf8_text(self):
|
||||
translator = TextDocumentTranslator()
|
||||
payload = "Cancer survival: 2.74× higher hazard ratio"
|
||||
|
||||
msg = translator.decode(
|
||||
{
|
||||
"charset": "utf-8",
|
||||
"text": payload,
|
||||
}
|
||||
)
|
||||
|
||||
assert msg.text == payload.encode("utf-8")
|
||||
|
||||
def test_decode_falls_back_to_raw_non_base64_ascii(self):
|
||||
translator = TextDocumentTranslator()
|
||||
payload = "plain-text payload"
|
||||
|
||||
msg = translator.decode(
|
||||
{
|
||||
"charset": "utf-8",
|
||||
"text": payload,
|
||||
}
|
||||
)
|
||||
|
||||
assert msg.text == payload.encode("utf-8")
|
||||
|
|
@ -10,16 +10,19 @@ from trustgraph.schema import Triple, Term, IRI, LITERAL
|
|||
from trustgraph.provenance.agent import (
|
||||
agent_session_triples,
|
||||
agent_iteration_triples,
|
||||
agent_observation_triples,
|
||||
agent_final_triples,
|
||||
agent_synthesis_triples,
|
||||
)
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME,
|
||||
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
|
||||
PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
PROV_STARTED_AT_TIME,
|
||||
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS,
|
||||
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
|
||||
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_TOOL_USE, TG_SYNTHESIS,
|
||||
TG_AGENT_QUESTION,
|
||||
)
|
||||
|
||||
|
|
@ -63,7 +66,7 @@ class TestAgentSessionTriples:
|
|||
triples = agent_session_triples(
|
||||
self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z"
|
||||
)
|
||||
assert has_type(triples, self.SESSION_URI, PROV_ACTIVITY)
|
||||
assert has_type(triples, self.SESSION_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.SESSION_URI, TG_QUESTION)
|
||||
assert has_type(triples, self.SESSION_URI, TG_AGENT_QUESTION)
|
||||
|
||||
|
|
@ -103,6 +106,25 @@ class TestAgentSessionTriples:
|
|||
)
|
||||
assert len(triples) == 6
|
||||
|
||||
def test_session_parent_uri(self):
|
||||
"""Subagent sessions derive from a parent entity (e.g. Decomposition)."""
|
||||
parent = "urn:trustgraph:agent:parent/decompose"
|
||||
triples = agent_session_triples(
|
||||
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z",
|
||||
parent_uri=parent,
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == parent
|
||||
|
||||
def test_session_no_parent_uri(self):
|
||||
"""Top-level sessions have no wasDerivedFrom."""
|
||||
triples = agent_session_triples(
|
||||
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI)
|
||||
assert derived is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# agent_iteration_triples
|
||||
|
|
@ -121,19 +143,17 @@ class TestAgentIterationTriples:
|
|||
)
|
||||
assert has_type(triples, self.ITER_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.ITER_URI, TG_ANALYSIS)
|
||||
assert has_type(triples, self.ITER_URI, TG_TOOL_USE)
|
||||
|
||||
def test_first_iteration_generated_by_question(self):
|
||||
"""First iteration uses wasGeneratedBy to link to question activity."""
|
||||
def test_first_iteration_derived_from_question(self):
|
||||
"""First iteration uses wasDerivedFrom to link to question entity."""
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
)
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.SESSION_URI
|
||||
# Should NOT have wasDerivedFrom
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
|
||||
assert derived is None
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.SESSION_URI
|
||||
|
||||
def test_subsequent_iteration_derived_from_previous(self):
|
||||
"""Subsequent iterations use wasDerivedFrom to link to previous iteration."""
|
||||
|
|
@ -144,9 +164,6 @@ class TestAgentIterationTriples:
|
|||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.PREV_URI
|
||||
# Should NOT have wasGeneratedBy
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
|
||||
assert gen is None
|
||||
|
||||
def test_iteration_label_includes_action(self):
|
||||
triples = agent_iteration_triples(
|
||||
|
|
@ -174,40 +191,24 @@ class TestAgentIterationTriples:
|
|||
# Thought has correct types
|
||||
assert has_type(triples, thought_uri, TG_REFLECTION_TYPE)
|
||||
assert has_type(triples, thought_uri, TG_THOUGHT_TYPE)
|
||||
# Thought was generated by iteration
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, thought_uri)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.ITER_URI
|
||||
# Thought was derived from iteration
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, thought_uri)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.ITER_URI
|
||||
# Thought has document reference
|
||||
doc = find_triple(triples, TG_DOCUMENT, thought_uri)
|
||||
assert doc is not None
|
||||
assert doc.o.iri == thought_doc
|
||||
|
||||
def test_iteration_observation_sub_entity(self):
|
||||
"""Observation is a sub-entity with Reflection and Observation types."""
|
||||
obs_uri = "urn:trustgraph:agent:test-session/i1/observation"
|
||||
obs_doc = "urn:doc:obs-1"
|
||||
def test_iteration_no_observation_sub_entity(self):
|
||||
"""Iteration no longer embeds observation — it's a separate entity."""
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
observation_uri=obs_uri,
|
||||
observation_document_id=obs_doc,
|
||||
)
|
||||
# Iteration links to observation sub-entity
|
||||
obs_link = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
|
||||
assert obs_link is not None
|
||||
assert obs_link.o.iri == obs_uri
|
||||
# Observation has correct types
|
||||
assert has_type(triples, obs_uri, TG_REFLECTION_TYPE)
|
||||
assert has_type(triples, obs_uri, TG_OBSERVATION_TYPE)
|
||||
# Observation was generated by iteration
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, obs_uri)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.ITER_URI
|
||||
# Observation has document reference
|
||||
doc = find_triple(triples, TG_DOCUMENT, obs_uri)
|
||||
assert doc is not None
|
||||
assert doc.o.iri == obs_doc
|
||||
# No TG_OBSERVATION predicate on the iteration
|
||||
for t in triples:
|
||||
assert "observation" not in t.p.iri.lower() or "Observation" not in t.p.iri
|
||||
|
||||
def test_iteration_action_recorded(self):
|
||||
triples = agent_iteration_triples(
|
||||
|
|
@ -240,19 +241,17 @@ class TestAgentIterationTriples:
|
|||
parsed = json.loads(arguments.o.value)
|
||||
assert parsed == {}
|
||||
|
||||
def test_iteration_no_thought_or_observation(self):
|
||||
"""Minimal iteration with just action — no thought or observation triples."""
|
||||
def test_iteration_no_thought(self):
|
||||
"""Minimal iteration with just action — no thought triples."""
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="noop",
|
||||
)
|
||||
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
|
||||
obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
|
||||
assert thought is None
|
||||
assert obs is None
|
||||
|
||||
def test_iteration_chaining(self):
|
||||
"""First iteration uses wasGeneratedBy, second uses wasDerivedFrom."""
|
||||
"""Both first and second iterations use wasDerivedFrom."""
|
||||
iter1_uri = "urn:trustgraph:agent:sess/i1"
|
||||
iter2_uri = "urn:trustgraph:agent:sess/i2"
|
||||
|
||||
|
|
@ -263,13 +262,62 @@ class TestAgentIterationTriples:
|
|||
iter2_uri, previous_uri=iter1_uri, action="step2",
|
||||
)
|
||||
|
||||
gen1 = find_triple(triples1, PROV_WAS_GENERATED_BY, iter1_uri)
|
||||
assert gen1.o.iri == self.SESSION_URI
|
||||
derived1 = find_triple(triples1, PROV_WAS_DERIVED_FROM, iter1_uri)
|
||||
assert derived1.o.iri == self.SESSION_URI
|
||||
|
||||
derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri)
|
||||
assert derived2.o.iri == iter1_uri
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# agent_observation_triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAgentObservationTriples:
|
||||
|
||||
OBS_URI = "urn:trustgraph:agent:test-session/i1/observation"
|
||||
ITER_URI = "urn:trustgraph:agent:test-session/i1"
|
||||
|
||||
def test_observation_types(self):
|
||||
triples = agent_observation_triples(
|
||||
self.OBS_URI, self.ITER_URI,
|
||||
)
|
||||
assert has_type(triples, self.OBS_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.OBS_URI, TG_OBSERVATION_TYPE)
|
||||
|
||||
def test_observation_derived_from_iteration(self):
|
||||
triples = agent_observation_triples(
|
||||
self.OBS_URI, self.ITER_URI,
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.OBS_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.ITER_URI
|
||||
|
||||
def test_observation_label(self):
|
||||
triples = agent_observation_triples(
|
||||
self.OBS_URI, self.ITER_URI,
|
||||
)
|
||||
label = find_triple(triples, RDFS_LABEL, self.OBS_URI)
|
||||
assert label is not None
|
||||
assert label.o.value == "Observation"
|
||||
|
||||
def test_observation_document(self):
|
||||
doc_id = "urn:doc:obs-1"
|
||||
triples = agent_observation_triples(
|
||||
self.OBS_URI, self.ITER_URI, document_id=doc_id,
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.OBS_URI)
|
||||
assert doc is not None
|
||||
assert doc.o.iri == doc_id
|
||||
|
||||
def test_observation_no_document(self):
|
||||
triples = agent_observation_triples(
|
||||
self.OBS_URI, self.ITER_URI,
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.OBS_URI)
|
||||
assert doc is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# agent_final_triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -296,19 +344,15 @@ class TestAgentFinalTriples:
|
|||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.PREV_URI
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
|
||||
assert gen is None
|
||||
|
||||
def test_final_generated_by_question_when_no_iterations(self):
|
||||
"""When agent answers immediately, final uses wasGeneratedBy."""
|
||||
def test_final_derived_from_question_when_no_iterations(self):
|
||||
"""When agent answers immediately, final uses wasDerivedFrom to question."""
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, question_uri=self.SESSION_URI,
|
||||
)
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.SESSION_URI
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
|
||||
assert derived is None
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.SESSION_URI
|
||||
|
||||
def test_final_label(self):
|
||||
triples = agent_final_triples(
|
||||
|
|
@ -334,3 +378,59 @@ class TestAgentFinalTriples:
|
|||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
|
||||
assert doc is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# agent_synthesis_triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAgentSynthesisTriples:
|
||||
|
||||
SYNTH_URI = "urn:trustgraph:agent:test-session/synthesis"
|
||||
FINDING_0 = "urn:trustgraph:agent:test-session/finding/0"
|
||||
FINDING_1 = "urn:trustgraph:agent:test-session/finding/1"
|
||||
FINDING_2 = "urn:trustgraph:agent:test-session/finding/2"
|
||||
|
||||
def test_synthesis_types(self):
|
||||
triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
|
||||
assert has_type(triples, self.SYNTH_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.SYNTH_URI, TG_SYNTHESIS)
|
||||
assert has_type(triples, self.SYNTH_URI, TG_ANSWER_TYPE)
|
||||
|
||||
def test_synthesis_single_parent_string(self):
|
||||
"""Single parent passed as string."""
|
||||
triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
|
||||
derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)
|
||||
assert len(derived) == 1
|
||||
assert derived[0].o.iri == self.FINDING_0
|
||||
|
||||
def test_synthesis_multiple_parents(self):
|
||||
"""Multiple parents for supervisor fan-in."""
|
||||
parents = [self.FINDING_0, self.FINDING_1, self.FINDING_2]
|
||||
triples = agent_synthesis_triples(self.SYNTH_URI, parents)
|
||||
derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)
|
||||
assert len(derived) == 3
|
||||
derived_uris = {t.o.iri for t in derived}
|
||||
assert derived_uris == set(parents)
|
||||
|
||||
def test_synthesis_single_parent_as_list(self):
|
||||
"""Single parent passed as list."""
|
||||
triples = agent_synthesis_triples(self.SYNTH_URI, [self.FINDING_0])
|
||||
derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)
|
||||
assert len(derived) == 1
|
||||
assert derived[0].o.iri == self.FINDING_0
|
||||
|
||||
def test_synthesis_document(self):
|
||||
triples = agent_synthesis_triples(
|
||||
self.SYNTH_URI, self.FINDING_0,
|
||||
document_id="urn:doc:synth",
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.SYNTH_URI)
|
||||
assert doc is not None
|
||||
assert doc.o.iri == "urn:doc:synth"
|
||||
|
||||
def test_synthesis_label(self):
|
||||
triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
|
||||
label = find_triple(triples, RDFS_LABEL, self.SYNTH_URI)
|
||||
assert label is not None
|
||||
assert label.o.value == "Synthesis"
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from trustgraph.api.explainability import (
|
|||
Synthesis,
|
||||
Reflection,
|
||||
Analysis,
|
||||
Observation,
|
||||
Conclusion,
|
||||
parse_edge_selection_triples,
|
||||
extract_term_value,
|
||||
|
|
@ -23,12 +24,12 @@ from trustgraph.api.explainability import (
|
|||
ExplainabilityClient,
|
||||
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_DOCUMENT, TG_CHUNK_COUNT, TG_CONCEPT, TG_ENTITY,
|
||||
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
|
||||
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS,
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANALYSIS, TG_CONCLUSION,
|
||||
TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
|
||||
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
|
||||
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM,
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
)
|
||||
|
||||
|
|
@ -180,14 +181,30 @@ class TestExplainEntityFromTriples:
|
|||
("urn:ana:1", TG_ACTION, "graph-rag-query"),
|
||||
("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'),
|
||||
("urn:ana:1", TG_THOUGHT, "urn:ref:thought-1"),
|
||||
("urn:ana:1", TG_OBSERVATION, "urn:ref:obs-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:ana:1", triples)
|
||||
assert isinstance(entity, Analysis)
|
||||
assert entity.action == "graph-rag-query"
|
||||
assert entity.arguments == '{"query": "test"}'
|
||||
assert entity.thought == "urn:ref:thought-1"
|
||||
assert entity.observation == "urn:ref:obs-1"
|
||||
|
||||
def test_observation(self):
|
||||
triples = [
|
||||
("urn:obs:1", RDF_TYPE, TG_OBSERVATION_TYPE),
|
||||
("urn:obs:1", TG_DOCUMENT, "urn:doc:obs-content"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:obs:1", triples)
|
||||
assert isinstance(entity, Observation)
|
||||
assert entity.document == "urn:doc:obs-content"
|
||||
assert entity.entity_type == "observation"
|
||||
|
||||
def test_observation_no_document(self):
|
||||
triples = [
|
||||
("urn:obs:2", RDF_TYPE, TG_OBSERVATION_TYPE),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:obs:2", triples)
|
||||
assert isinstance(entity, Observation)
|
||||
assert entity.document == ""
|
||||
|
||||
def test_conclusion_with_document(self):
|
||||
triples = [
|
||||
|
|
@ -541,3 +558,96 @@ class TestExplainabilityClientDetectSessionType:
|
|||
mock_flow = MagicMock()
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
assert client.detect_session_type("urn:trustgraph:docrag:abc") == "docrag"
|
||||
|
||||
|
||||
class TestChainWalkerFollowsSubTraceTerminal:
|
||||
"""Test that _follow_provenance_chain continues from a sub-trace's
|
||||
Synthesis to find downstream entities like Observation."""
|
||||
|
||||
def test_observation_found_via_subtrace_synthesis(self):
|
||||
"""
|
||||
DAG: Question -> Analysis -> GraphRAG Question -> Synthesis -> Observation
|
||||
The walker should find Analysis, the sub-trace, then follow from
|
||||
Synthesis to discover Observation.
|
||||
"""
|
||||
# Entity triples (s, p, o)
|
||||
entity_data = {
|
||||
"urn:agent:q": [
|
||||
("urn:agent:q", RDF_TYPE, TG_AGENT_QUESTION),
|
||||
("urn:agent:q", TG_QUERY, "test"),
|
||||
],
|
||||
"urn:agent:analysis": [
|
||||
("urn:agent:analysis", RDF_TYPE, TG_ANALYSIS),
|
||||
("urn:agent:analysis", PROV_WAS_DERIVED_FROM, "urn:agent:q"),
|
||||
],
|
||||
"urn:graphrag:q": [
|
||||
("urn:graphrag:q", RDF_TYPE, TG_QUESTION),
|
||||
("urn:graphrag:q", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
|
||||
("urn:graphrag:q", TG_QUERY, "test"),
|
||||
("urn:graphrag:q", PROV_WAS_DERIVED_FROM, "urn:agent:analysis"),
|
||||
],
|
||||
"urn:graphrag:synth": [
|
||||
("urn:graphrag:synth", RDF_TYPE, TG_SYNTHESIS),
|
||||
("urn:graphrag:synth", PROV_WAS_DERIVED_FROM, "urn:graphrag:q"),
|
||||
],
|
||||
"urn:agent:obs": [
|
||||
("urn:agent:obs", RDF_TYPE, TG_OBSERVATION_TYPE),
|
||||
("urn:agent:obs", PROV_WAS_DERIVED_FROM, "urn:graphrag:synth"),
|
||||
],
|
||||
"urn:agent:conclusion": [
|
||||
("urn:agent:conclusion", RDF_TYPE, TG_CONCLUSION),
|
||||
("urn:agent:conclusion", PROV_WAS_DERIVED_FROM, "urn:agent:obs"),
|
||||
],
|
||||
}
|
||||
|
||||
# Build a mock flow that answers triples queries
|
||||
# Query by s= returns that entity's triples
|
||||
# Query by p=wasDerivedFrom, o=X returns entities derived from X
|
||||
def mock_triples_query(s=None, p=None, o=None, **kwargs):
|
||||
if s and not p:
|
||||
# Fetch entity triples
|
||||
tuples = entity_data.get(s, [])
|
||||
return _make_wire_triples(tuples)
|
||||
elif p == PROV_WAS_DERIVED_FROM and o:
|
||||
# Find entities derived from o
|
||||
results = []
|
||||
for uri, tuples in entity_data.items():
|
||||
for _, pred, obj in tuples:
|
||||
if pred == PROV_WAS_DERIVED_FROM and obj == o:
|
||||
results.append((uri, pred, obj))
|
||||
return _make_wire_triples(results)
|
||||
return []
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.triples_query.side_effect = mock_triples_query
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0, max_retries=2)
|
||||
|
||||
# Mock fetch_graphrag_trace to return a trace with a synthesis
|
||||
synth_entity = Synthesis(uri="urn:graphrag:synth", entity_type="synthesis")
|
||||
client.fetch_graphrag_trace = MagicMock(return_value={
|
||||
"question": Question(uri="urn:graphrag:q", entity_type="question",
|
||||
question_type="graph-rag"),
|
||||
"synthesis": synth_entity,
|
||||
})
|
||||
|
||||
trace = client.fetch_agent_trace(
|
||||
"urn:agent:q",
|
||||
graph="urn:graph:retrieval",
|
||||
)
|
||||
|
||||
# Should have found all steps
|
||||
step_types = [
|
||||
type(s).__name__ if not isinstance(s, dict) else s.get("type")
|
||||
for s in trace["steps"]
|
||||
]
|
||||
|
||||
assert "Analysis" in step_types, f"Missing Analysis in {step_types}"
|
||||
assert "sub-trace" in step_types, f"Missing sub-trace in {step_types}"
|
||||
assert "Observation" in step_types, f"Missing Observation in {step_types}"
|
||||
assert "Conclusion" in step_types, f"Missing Conclusion in {step_types}"
|
||||
|
||||
# Observation should come after the sub-trace
|
||||
subtrace_idx = step_types.index("sub-trace")
|
||||
obs_idx = step_types.index("Observation")
|
||||
assert obs_idx > subtrace_idx, "Observation should appear after sub-trace"
|
||||
|
|
|
|||
295
tests/unit/test_provenance/test_graph_rag_chain.py
Normal file
295
tests/unit/test_provenance/test_graph_rag_chain.py
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
"""
|
||||
Structural test for the graph-rag provenance chain.
|
||||
|
||||
Verifies that a complete graph-rag query produces the expected
|
||||
provenance chain:
|
||||
|
||||
question → grounding → exploration → focus → synthesis
|
||||
|
||||
Each step must:
|
||||
- Have the correct rdf:type
|
||||
- Link to its predecessor via prov:wasDerivedFrom
|
||||
- Carry expected domain-specific data
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.provenance.triples import (
|
||||
question_triples,
|
||||
grounding_triples,
|
||||
exploration_triples,
|
||||
focus_triples,
|
||||
synthesis_triples,
|
||||
)
|
||||
from trustgraph.provenance.uris import (
|
||||
question_uri,
|
||||
grounding_uri,
|
||||
exploration_uri,
|
||||
focus_uri,
|
||||
synthesis_uri,
|
||||
)
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_GRAPH_RAG_QUESTION, TG_ANSWER_TYPE,
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_DOCUMENT,
|
||||
PROV_STARTED_AT_TIME,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SESSION_ID = "test-session-1234"
|
||||
|
||||
|
||||
def find_triple(triples, predicate, subject=None):
|
||||
"""Find first triple matching predicate (and optionally subject)."""
|
||||
for t in triples:
|
||||
if t.p.iri == predicate:
|
||||
if subject is None or t.s.iri == subject:
|
||||
return t
|
||||
return None
|
||||
|
||||
|
||||
def find_triples(triples, predicate, subject=None):
|
||||
"""Find all triples matching predicate (and optionally subject)."""
|
||||
return [
|
||||
t for t in triples
|
||||
if t.p.iri == predicate
|
||||
and (subject is None or t.s.iri == subject)
|
||||
]
|
||||
|
||||
|
||||
def has_type(triples, subject, rdf_type):
|
||||
"""Check if subject has the given rdf:type."""
|
||||
return any(
|
||||
t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
|
||||
for t in triples
|
||||
)
|
||||
|
||||
|
||||
def derived_from(triples, subject):
|
||||
"""Get the wasDerivedFrom target URI for a subject."""
|
||||
t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
|
||||
return t.o.iri if t else None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build the full chain
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def chain():
|
||||
"""Build all provenance triples for a complete graph-rag query."""
|
||||
q_uri = question_uri(SESSION_ID)
|
||||
gnd_uri = grounding_uri(SESSION_ID)
|
||||
exp_uri = exploration_uri(SESSION_ID)
|
||||
foc_uri = focus_uri(SESSION_ID)
|
||||
syn_uri = synthesis_uri(SESSION_ID)
|
||||
|
||||
q = question_triples(q_uri, "What is quantum computing?", "2026-01-01T00:00:00Z")
|
||||
gnd = grounding_triples(gnd_uri, q_uri, ["quantum", "computing"])
|
||||
exp = exploration_triples(
|
||||
exp_uri, gnd_uri, edge_count=42,
|
||||
entities=["urn:entity:1", "urn:entity:2"],
|
||||
)
|
||||
foc = focus_triples(
|
||||
foc_uri, exp_uri,
|
||||
selected_edges_with_reasoning=[
|
||||
{
|
||||
"edge": (
|
||||
"http://example.com/QuantumComputing",
|
||||
"http://schema.org/relatedTo",
|
||||
"http://example.com/Physics",
|
||||
),
|
||||
"reasoning": "Directly relevant to the query",
|
||||
},
|
||||
{
|
||||
"edge": (
|
||||
"http://example.com/QuantumComputing",
|
||||
"http://schema.org/name",
|
||||
"Quantum Computing",
|
||||
),
|
||||
"reasoning": "Provides the entity label",
|
||||
},
|
||||
],
|
||||
session_id=SESSION_ID,
|
||||
)
|
||||
syn = synthesis_triples(syn_uri, foc_uri, document_id="urn:doc:answer-1")
|
||||
|
||||
return {
|
||||
"uris": {
|
||||
"question": q_uri,
|
||||
"grounding": gnd_uri,
|
||||
"exploration": exp_uri,
|
||||
"focus": foc_uri,
|
||||
"synthesis": syn_uri,
|
||||
},
|
||||
"triples": {
|
||||
"question": q,
|
||||
"grounding": gnd,
|
||||
"exploration": exp,
|
||||
"focus": foc,
|
||||
"synthesis": syn,
|
||||
},
|
||||
"all": q + gnd + exp + foc + syn,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Chain structure tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagProvenanceChain:
|
||||
"""Verify the full question → grounding → exploration → focus → synthesis chain."""
|
||||
|
||||
def test_chain_has_five_stages(self, chain):
|
||||
"""Each stage should produce at least some triples."""
|
||||
for stage in ["question", "grounding", "exploration", "focus", "synthesis"]:
|
||||
assert len(chain["triples"][stage]) > 0, f"{stage} produced no triples"
|
||||
|
||||
def test_derivation_chain(self, chain):
|
||||
"""
|
||||
The wasDerivedFrom links must form:
|
||||
grounding → question, exploration → grounding,
|
||||
focus → exploration, synthesis → focus.
|
||||
"""
|
||||
uris = chain["uris"]
|
||||
all_triples = chain["all"]
|
||||
|
||||
assert derived_from(all_triples, uris["grounding"]) == uris["question"]
|
||||
assert derived_from(all_triples, uris["exploration"]) == uris["grounding"]
|
||||
assert derived_from(all_triples, uris["focus"]) == uris["exploration"]
|
||||
assert derived_from(all_triples, uris["synthesis"]) == uris["focus"]
|
||||
|
||||
def test_question_has_no_parent(self, chain):
|
||||
"""The root question should not derive from anything (no parent_uri)."""
|
||||
uris = chain["uris"]
|
||||
all_triples = chain["all"]
|
||||
assert derived_from(all_triples, uris["question"]) is None
|
||||
|
||||
def test_question_with_parent(self):
|
||||
"""When a parent_uri is given, question should derive from it."""
|
||||
q_uri = question_uri("child-session")
|
||||
parent = "urn:trustgraph:agent:iteration:parent"
|
||||
q = question_triples(q_uri, "sub-query", "2026-01-01T00:00:00Z",
|
||||
parent_uri=parent)
|
||||
assert derived_from(q, q_uri) == parent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type annotation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagProvenanceTypes:
|
||||
"""Each stage must have the correct rdf:type annotations."""
|
||||
|
||||
def test_question_types(self, chain):
|
||||
uris = chain["uris"]
|
||||
triples = chain["triples"]["question"]
|
||||
assert has_type(triples, uris["question"], PROV_ENTITY)
|
||||
assert has_type(triples, uris["question"], TG_GRAPH_RAG_QUESTION)
|
||||
|
||||
def test_grounding_types(self, chain):
|
||||
uris = chain["uris"]
|
||||
triples = chain["triples"]["grounding"]
|
||||
assert has_type(triples, uris["grounding"], PROV_ENTITY)
|
||||
assert has_type(triples, uris["grounding"], TG_GROUNDING)
|
||||
|
||||
def test_exploration_types(self, chain):
|
||||
uris = chain["uris"]
|
||||
triples = chain["triples"]["exploration"]
|
||||
assert has_type(triples, uris["exploration"], PROV_ENTITY)
|
||||
assert has_type(triples, uris["exploration"], TG_EXPLORATION)
|
||||
|
||||
def test_focus_types(self, chain):
|
||||
uris = chain["uris"]
|
||||
triples = chain["triples"]["focus"]
|
||||
assert has_type(triples, uris["focus"], PROV_ENTITY)
|
||||
assert has_type(triples, uris["focus"], TG_FOCUS)
|
||||
|
||||
def test_synthesis_types(self, chain):
|
||||
uris = chain["uris"]
|
||||
triples = chain["triples"]["synthesis"]
|
||||
assert has_type(triples, uris["synthesis"], PROV_ENTITY)
|
||||
assert has_type(triples, uris["synthesis"], TG_SYNTHESIS)
|
||||
assert has_type(triples, uris["synthesis"], TG_ANSWER_TYPE)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Domain-specific content tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagProvenanceContent:
|
||||
"""Each stage should carry the expected domain data."""
|
||||
|
||||
def test_question_has_query_text(self, chain):
|
||||
uris = chain["uris"]
|
||||
t = find_triple(chain["triples"]["question"], TG_QUERY, uris["question"])
|
||||
assert t is not None
|
||||
assert t.o.value == "What is quantum computing?"
|
||||
|
||||
def test_question_has_timestamp(self, chain):
|
||||
uris = chain["uris"]
|
||||
t = find_triple(chain["triples"]["question"], PROV_STARTED_AT_TIME, uris["question"])
|
||||
assert t is not None
|
||||
assert t.o.value == "2026-01-01T00:00:00Z"
|
||||
|
||||
def test_grounding_has_concepts(self, chain):
|
||||
uris = chain["uris"]
|
||||
concepts = find_triples(chain["triples"]["grounding"], TG_CONCEPT, uris["grounding"])
|
||||
concept_values = {t.o.value for t in concepts}
|
||||
assert concept_values == {"quantum", "computing"}
|
||||
|
||||
def test_exploration_has_edge_count(self, chain):
|
||||
uris = chain["uris"]
|
||||
t = find_triple(chain["triples"]["exploration"], TG_EDGE_COUNT, uris["exploration"])
|
||||
assert t is not None
|
||||
assert t.o.value == "42"
|
||||
|
||||
def test_exploration_has_entities(self, chain):
|
||||
uris = chain["uris"]
|
||||
entities = find_triples(chain["triples"]["exploration"], TG_ENTITY, uris["exploration"])
|
||||
entity_iris = {t.o.iri for t in entities}
|
||||
assert entity_iris == {"urn:entity:1", "urn:entity:2"}
|
||||
|
||||
def test_focus_has_selected_edges(self, chain):
|
||||
uris = chain["uris"]
|
||||
edges = find_triples(chain["triples"]["focus"], TG_SELECTED_EDGE, uris["focus"])
|
||||
assert len(edges) == 2
|
||||
|
||||
def test_focus_edges_have_quoted_triples(self, chain):
|
||||
"""Each edge selection entity should have a tg:edge with a quoted triple."""
|
||||
focus = chain["triples"]["focus"]
|
||||
edge_triples = find_triples(focus, TG_EDGE)
|
||||
assert len(edge_triples) == 2
|
||||
|
||||
# Each should have a quoted triple as the object
|
||||
for t in edge_triples:
|
||||
assert t.o.triple is not None, "tg:edge object should be a quoted triple"
|
||||
|
||||
def test_focus_edges_have_reasoning(self, chain):
|
||||
"""Each edge selection entity should have tg:reasoning."""
|
||||
focus = chain["triples"]["focus"]
|
||||
reasoning = find_triples(focus, TG_REASONING)
|
||||
assert len(reasoning) == 2
|
||||
reasoning_texts = {t.o.value for t in reasoning}
|
||||
assert "Directly relevant to the query" in reasoning_texts
|
||||
assert "Provides the entity label" in reasoning_texts
|
||||
|
||||
def test_synthesis_has_document_ref(self, chain):
|
||||
uris = chain["uris"]
|
||||
t = find_triple(chain["triples"]["synthesis"], TG_DOCUMENT, uris["synthesis"])
|
||||
assert t is not None
|
||||
assert t.o.iri == "urn:doc:answer-1"
|
||||
|
||||
def test_synthesis_has_labels(self, chain):
|
||||
uris = chain["uris"]
|
||||
t = find_triple(chain["triples"]["synthesis"], RDFS_LABEL, uris["synthesis"])
|
||||
assert t is not None
|
||||
assert t.o.value == "Synthesis"
|
||||
|
|
@ -500,7 +500,7 @@ class TestQuestionTriples:
|
|||
|
||||
def test_question_types(self):
|
||||
triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z")
|
||||
assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
|
||||
assert has_type(triples, self.Q_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.Q_URI, TG_QUESTION)
|
||||
assert has_type(triples, self.Q_URI, TG_GRAPH_RAG_QUESTION)
|
||||
|
||||
|
|
@ -543,11 +543,11 @@ class TestGroundingTriples:
|
|||
assert has_type(triples, self.GND_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.GND_URI, TG_GROUNDING)
|
||||
|
||||
def test_grounding_generated_by_question(self):
|
||||
def test_grounding_derived_from_question(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI"])
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.GND_URI)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.Q_URI
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.GND_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.Q_URI
|
||||
|
||||
def test_grounding_concepts(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML", "robots"])
|
||||
|
|
@ -730,7 +730,7 @@ class TestDocRagQuestionTriples:
|
|||
|
||||
def test_docrag_question_types(self):
|
||||
triples = docrag_question_triples(self.Q_URI, "Find info", "2024-01-01T00:00:00Z")
|
||||
assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
|
||||
assert has_type(triples, self.Q_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.Q_URI, TG_QUESTION)
|
||||
assert has_type(triples, self.Q_URI, TG_DOC_RAG_QUESTION)
|
||||
|
||||
|
|
|
|||
164
tests/unit/test_pubsub/test_queue_naming.py
Normal file
164
tests/unit/test_pubsub/test_queue_naming.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
"""
|
||||
Tests for queue naming and topic mapping.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import argparse
|
||||
|
||||
from trustgraph.schema.core.topic import queue
|
||||
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
|
||||
from trustgraph.base.pulsar_backend import PulsarBackend
|
||||
|
||||
|
||||
class TestQueueFunction:
|
||||
|
||||
def test_flow_default(self):
|
||||
assert queue('text-completion-request') == 'flow:tg:text-completion-request'
|
||||
|
||||
def test_request_class(self):
|
||||
assert queue('config', cls='request') == 'request:tg:config'
|
||||
|
||||
def test_response_class(self):
|
||||
assert queue('config', cls='response') == 'response:tg:config'
|
||||
|
||||
def test_notify_class(self):
|
||||
assert queue('config', cls='notify') == 'notify:tg:config'
|
||||
|
||||
def test_custom_topicspace(self):
|
||||
assert queue('config', cls='request', topicspace='prod') == 'request:prod:config'
|
||||
|
||||
def test_default_class_is_flow(self):
|
||||
result = queue('something')
|
||||
assert result.startswith('flow:')
|
||||
|
||||
|
||||
class TestPulsarMapTopic:
|
||||
|
||||
@pytest.fixture
|
||||
def backend(self):
|
||||
"""Create a PulsarBackend without connecting."""
|
||||
b = object.__new__(PulsarBackend)
|
||||
return b
|
||||
|
||||
def test_flow_maps_to_persistent(self, backend):
|
||||
assert backend.map_topic('flow:tg:text-completion-request') == \
|
||||
'persistent://tg/flow/text-completion-request'
|
||||
|
||||
def test_notify_maps_to_non_persistent(self, backend):
|
||||
assert backend.map_topic('notify:tg:config') == \
|
||||
'non-persistent://tg/notify/config'
|
||||
|
||||
def test_request_maps_to_non_persistent(self, backend):
|
||||
assert backend.map_topic('request:tg:config') == \
|
||||
'non-persistent://tg/request/config'
|
||||
|
||||
def test_response_maps_to_non_persistent(self, backend):
|
||||
assert backend.map_topic('response:tg:librarian') == \
|
||||
'non-persistent://tg/response/librarian'
|
||||
|
||||
def test_passthrough_pulsar_uri(self, backend):
|
||||
uri = 'persistent://tg/flow/something'
|
||||
assert backend.map_topic(uri) == uri
|
||||
|
||||
def test_invalid_format_raises(self, backend):
|
||||
with pytest.raises(ValueError, match="Invalid queue format"):
|
||||
backend.map_topic('bad-format')
|
||||
|
||||
def test_invalid_class_raises(self, backend):
|
||||
with pytest.raises(ValueError, match="Invalid queue class"):
|
||||
backend.map_topic('unknown:tg:topic')
|
||||
|
||||
def test_custom_topicspace(self, backend):
|
||||
assert backend.map_topic('flow:prod:my-queue') == \
|
||||
'persistent://prod/flow/my-queue'
|
||||
|
||||
|
||||
class TestGetPubsubDispatch:
|
||||
|
||||
def test_unknown_backend_raises(self):
|
||||
with pytest.raises(ValueError, match="Unknown pub/sub backend"):
|
||||
get_pubsub(pubsub_backend='redis')
|
||||
|
||||
|
||||
class TestAddPubsubArgs:
|
||||
|
||||
def test_standalone_defaults_to_localhost(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser, standalone=True)
|
||||
args = parser.parse_args([])
|
||||
assert args.pulsar_host == 'pulsar://localhost:6650'
|
||||
assert args.pulsar_listener == 'localhost'
|
||||
|
||||
def test_non_standalone_defaults_to_container(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser, standalone=False)
|
||||
args = parser.parse_args([])
|
||||
assert 'pulsar:6650' in args.pulsar_host
|
||||
assert args.pulsar_listener is None
|
||||
|
||||
def test_cli_override_respected(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser, standalone=True)
|
||||
args = parser.parse_args(['--pulsar-host', 'pulsar://custom:6650'])
|
||||
assert args.pulsar_host == 'pulsar://custom:6650'
|
||||
|
||||
def test_pubsub_backend_default(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser)
|
||||
args = parser.parse_args([])
|
||||
assert args.pubsub_backend == 'pulsar'
|
||||
|
||||
|
||||
class TestAddPubsubArgsRabbitMQ:
|
||||
|
||||
def test_rabbitmq_args_present(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser)
|
||||
args = parser.parse_args([
|
||||
'--pubsub-backend', 'rabbitmq',
|
||||
'--rabbitmq-host', 'myhost',
|
||||
'--rabbitmq-port', '5673',
|
||||
])
|
||||
assert args.pubsub_backend == 'rabbitmq'
|
||||
assert args.rabbitmq_host == 'myhost'
|
||||
assert args.rabbitmq_port == 5673
|
||||
|
||||
def test_rabbitmq_defaults_container(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser)
|
||||
args = parser.parse_args([])
|
||||
assert args.rabbitmq_host == 'rabbitmq'
|
||||
assert args.rabbitmq_port == 5672
|
||||
assert args.rabbitmq_username == 'guest'
|
||||
assert args.rabbitmq_password == 'guest'
|
||||
assert args.rabbitmq_vhost == '/'
|
||||
|
||||
def test_rabbitmq_standalone_defaults_to_localhost(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser, standalone=True)
|
||||
args = parser.parse_args([])
|
||||
assert args.rabbitmq_host == 'localhost'
|
||||
|
||||
|
||||
class TestQueueDefinitions:
|
||||
"""Verify the actual queue constants produce correct names."""
|
||||
|
||||
def test_config_request(self):
|
||||
from trustgraph.schema.services.config import config_request_queue
|
||||
assert config_request_queue == 'request:tg:config'
|
||||
|
||||
def test_config_response(self):
|
||||
from trustgraph.schema.services.config import config_response_queue
|
||||
assert config_response_queue == 'response:tg:config'
|
||||
|
||||
def test_config_push(self):
|
||||
from trustgraph.schema.services.config import config_push_queue
|
||||
assert config_push_queue == 'notify:tg:config'
|
||||
|
||||
def test_librarian_request(self):
|
||||
from trustgraph.schema.services.library import librarian_request_queue
|
||||
assert librarian_request_queue == 'request:tg:librarian'
|
||||
|
||||
def test_knowledge_request(self):
|
||||
from trustgraph.schema.knowledge.knowledge import knowledge_request_queue
|
||||
assert knowledge_request_queue == 'request:tg:knowledge'
|
||||
107
tests/unit/test_pubsub/test_rabbitmq_backend.py
Normal file
107
tests/unit/test_pubsub/test_rabbitmq_backend.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
"""
|
||||
Unit tests for RabbitMQ backend — queue name mapping and factory dispatch.
|
||||
Does not require a running RabbitMQ instance.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import argparse
|
||||
|
||||
pika = pytest.importorskip("pika", reason="pika not installed")
|
||||
|
||||
from trustgraph.base.rabbitmq_backend import RabbitMQBackend
|
||||
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
|
||||
|
||||
|
||||
class TestRabbitMQMapQueueName:
|
||||
|
||||
@pytest.fixture
|
||||
def backend(self):
|
||||
b = object.__new__(RabbitMQBackend)
|
||||
return b
|
||||
|
||||
def test_flow_is_durable(self, backend):
|
||||
name, durable = backend.map_queue_name('flow:tg:text-completion-request')
|
||||
assert durable is True
|
||||
assert name == 'tg.flow.text-completion-request'
|
||||
|
||||
def test_notify_is_not_durable(self, backend):
|
||||
name, durable = backend.map_queue_name('notify:tg:config')
|
||||
assert durable is False
|
||||
assert name == 'tg.notify.config'
|
||||
|
||||
def test_request_is_not_durable(self, backend):
|
||||
name, durable = backend.map_queue_name('request:tg:config')
|
||||
assert durable is False
|
||||
assert name == 'tg.request.config'
|
||||
|
||||
def test_response_is_not_durable(self, backend):
|
||||
name, durable = backend.map_queue_name('response:tg:librarian')
|
||||
assert durable is False
|
||||
assert name == 'tg.response.librarian'
|
||||
|
||||
def test_custom_topicspace(self, backend):
|
||||
name, durable = backend.map_queue_name('flow:prod:my-queue')
|
||||
assert name == 'prod.flow.my-queue'
|
||||
assert durable is True
|
||||
|
||||
def test_no_colon_defaults_to_flow(self, backend):
|
||||
name, durable = backend.map_queue_name('simple-queue')
|
||||
assert name == 'tg.simple-queue'
|
||||
assert durable is False
|
||||
|
||||
def test_invalid_class_raises(self, backend):
|
||||
with pytest.raises(ValueError, match="Invalid queue class"):
|
||||
backend.map_queue_name('unknown:tg:topic')
|
||||
|
||||
def test_flow_with_flow_suffix(self, backend):
|
||||
"""Queue names with flow suffix (e.g. :default) are preserved."""
|
||||
name, durable = backend.map_queue_name('request:tg:prompt:default')
|
||||
assert name == 'tg.request.prompt:default'
|
||||
|
||||
|
||||
class TestGetPubsubRabbitMQ:
|
||||
|
||||
def test_factory_creates_rabbitmq_backend(self):
|
||||
backend = get_pubsub(pubsub_backend='rabbitmq')
|
||||
assert isinstance(backend, RabbitMQBackend)
|
||||
|
||||
def test_factory_passes_config(self):
|
||||
backend = get_pubsub(
|
||||
pubsub_backend='rabbitmq',
|
||||
rabbitmq_host='myhost',
|
||||
rabbitmq_port=5673,
|
||||
rabbitmq_username='user',
|
||||
rabbitmq_password='pass',
|
||||
rabbitmq_vhost='/test',
|
||||
)
|
||||
assert isinstance(backend, RabbitMQBackend)
|
||||
# Verify connection params were set
|
||||
params = backend._connection_params
|
||||
assert params.host == 'myhost'
|
||||
assert params.port == 5673
|
||||
assert params.virtual_host == '/test'
|
||||
|
||||
|
||||
class TestAddPubsubArgsRabbitMQ:
|
||||
|
||||
def test_rabbitmq_args_present(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser)
|
||||
args = parser.parse_args([
|
||||
'--pubsub-backend', 'rabbitmq',
|
||||
'--rabbitmq-host', 'myhost',
|
||||
'--rabbitmq-port', '5673',
|
||||
])
|
||||
assert args.pubsub_backend == 'rabbitmq'
|
||||
assert args.rabbitmq_host == 'myhost'
|
||||
assert args.rabbitmq_port == 5673
|
||||
|
||||
def test_rabbitmq_defaults_container(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser)
|
||||
args = parser.parse_args([])
|
||||
assert args.rabbitmq_host == 'rabbitmq'
|
||||
assert args.rabbitmq_port == 5672
|
||||
assert args.rabbitmq_username == 'guest'
|
||||
assert args.rabbitmq_password == 'guest'
|
||||
assert args.rabbitmq_vhost == '/'
|
||||
424
tests/unit/test_query/test_sparql_expressions.py
Normal file
424
tests/unit/test_query/test_sparql_expressions.py
Normal file
|
|
@ -0,0 +1,424 @@
|
|||
"""
|
||||
Tests for SPARQL FILTER expression evaluator.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from trustgraph.schema import Term, IRI, LITERAL, BLANK
|
||||
from trustgraph.query.sparql.expressions import (
|
||||
evaluate_expression, _effective_boolean, _to_string, _to_numeric,
|
||||
_comparable_value,
|
||||
)
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
def iri(v):
|
||||
return Term(type=IRI, iri=v)
|
||||
|
||||
def lit(v, datatype="", language=""):
|
||||
return Term(type=LITERAL, value=v, datatype=datatype, language=language)
|
||||
|
||||
def blank(v):
|
||||
return Term(type=BLANK, id=v)
|
||||
|
||||
XSD = "http://www.w3.org/2001/XMLSchema#"
|
||||
|
||||
|
||||
class TestEvaluateExpression:
|
||||
"""Test expression evaluation with rdflib algebra nodes."""
|
||||
|
||||
def test_variable_bound(self):
|
||||
from rdflib.term import Variable
|
||||
result = evaluate_expression(Variable("x"), {"x": lit("hello")})
|
||||
assert result.value == "hello"
|
||||
|
||||
def test_variable_unbound(self):
|
||||
from rdflib.term import Variable
|
||||
result = evaluate_expression(Variable("x"), {})
|
||||
assert result is None
|
||||
|
||||
def test_uriref_constant(self):
|
||||
from rdflib import URIRef
|
||||
result = evaluate_expression(
|
||||
URIRef("http://example.com/a"), {}
|
||||
)
|
||||
assert result.type == IRI
|
||||
assert result.iri == "http://example.com/a"
|
||||
|
||||
def test_literal_constant(self):
|
||||
from rdflib import Literal
|
||||
result = evaluate_expression(Literal("hello"), {})
|
||||
assert result.type == LITERAL
|
||||
assert result.value == "hello"
|
||||
|
||||
def test_boolean_constant(self):
|
||||
assert evaluate_expression(True, {}) is True
|
||||
assert evaluate_expression(False, {}) is False
|
||||
|
||||
def test_numeric_constant(self):
|
||||
assert evaluate_expression(42, {}) == 42
|
||||
assert evaluate_expression(3.14, {}) == 3.14
|
||||
|
||||
def test_none_returns_true(self):
|
||||
assert evaluate_expression(None, {}) is True
|
||||
|
||||
|
||||
class TestRelationalExpressions:
|
||||
"""Test comparison operators via CompValue nodes."""
|
||||
|
||||
def _make_relational(self, left, op, right):
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
return CompValue("RelationalExpression",
|
||||
expr=left, op=op, other=right)
|
||||
|
||||
def test_equal_literals(self):
|
||||
from rdflib import Literal
|
||||
expr = self._make_relational(Literal("a"), "=", Literal("a"))
|
||||
assert evaluate_expression(expr, {}) is True
|
||||
|
||||
def test_not_equal_literals(self):
|
||||
from rdflib import Literal
|
||||
expr = self._make_relational(Literal("a"), "!=", Literal("b"))
|
||||
assert evaluate_expression(expr, {}) is True
|
||||
|
||||
def test_less_than(self):
|
||||
from rdflib import Literal
|
||||
expr = self._make_relational(Literal("a"), "<", Literal("b"))
|
||||
assert evaluate_expression(expr, {}) is True
|
||||
|
||||
def test_greater_than(self):
|
||||
from rdflib import Literal
|
||||
expr = self._make_relational(Literal("b"), ">", Literal("a"))
|
||||
assert evaluate_expression(expr, {}) is True
|
||||
|
||||
def test_equal_with_variables(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_relational(Variable("x"), "=", Variable("y"))
|
||||
sol = {"x": lit("same"), "y": lit("same")}
|
||||
assert evaluate_expression(expr, sol) is True
|
||||
|
||||
def test_unequal_with_variables(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_relational(Variable("x"), "=", Variable("y"))
|
||||
sol = {"x": lit("one"), "y": lit("two")}
|
||||
assert evaluate_expression(expr, sol) is False
|
||||
|
||||
def test_none_operand_returns_false(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_relational(Variable("x"), "=", Literal("a"))
|
||||
assert evaluate_expression(expr, {}) is False
|
||||
|
||||
|
||||
class TestLogicalExpressions:
|
||||
|
||||
def _make_and(self, exprs):
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
return CompValue("ConditionalAndExpression",
|
||||
expr=exprs[0], other=exprs[1:])
|
||||
|
||||
def _make_or(self, exprs):
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
return CompValue("ConditionalOrExpression",
|
||||
expr=exprs[0], other=exprs[1:])
|
||||
|
||||
def _make_not(self, expr):
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
return CompValue("UnaryNot", expr=expr)
|
||||
|
||||
def test_and_true_true(self):
|
||||
result = evaluate_expression(self._make_and([True, True]), {})
|
||||
assert result is True
|
||||
|
||||
def test_and_true_false(self):
|
||||
result = evaluate_expression(self._make_and([True, False]), {})
|
||||
assert result is False
|
||||
|
||||
def test_or_false_true(self):
|
||||
result = evaluate_expression(self._make_or([False, True]), {})
|
||||
assert result is True
|
||||
|
||||
def test_or_false_false(self):
|
||||
result = evaluate_expression(self._make_or([False, False]), {})
|
||||
assert result is False
|
||||
|
||||
def test_not_true(self):
|
||||
result = evaluate_expression(self._make_not(True), {})
|
||||
assert result is False
|
||||
|
||||
def test_not_false(self):
|
||||
result = evaluate_expression(self._make_not(False), {})
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestBuiltinFunctions:
|
||||
|
||||
def _make_builtin(self, name, **kwargs):
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
return CompValue(f"Builtin_{name}", **kwargs)
|
||||
|
||||
def test_bound_true(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("BOUND", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("hi")}) is True
|
||||
|
||||
def test_bound_false(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("BOUND", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {}) is False
|
||||
|
||||
def test_isiri_true(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("isIRI", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": iri("http://x")}) is True
|
||||
|
||||
def test_isiri_false(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("isIRI", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is False
|
||||
|
||||
def test_isliteral_true(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("isLITERAL", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is True
|
||||
|
||||
def test_isliteral_false(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("isLITERAL", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": iri("http://x")}) is False
|
||||
|
||||
def test_isblank_true(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("isBLANK", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": blank("b1")}) is True
|
||||
|
||||
def test_isblank_false(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("isBLANK", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": iri("http://x")}) is False
|
||||
|
||||
def test_str(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("STR", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": iri("http://example.com/a")})
|
||||
assert result.type == LITERAL
|
||||
assert result.value == "http://example.com/a"
|
||||
|
||||
def test_lang(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("LANG", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("hello", language="en")}
|
||||
)
|
||||
assert result.value == "en"
|
||||
|
||||
def test_lang_no_tag(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("LANG", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||
assert result.value == ""
|
||||
|
||||
def test_datatype(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("DATATYPE", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("42", datatype=XSD + "integer")}
|
||||
)
|
||||
assert result.type == IRI
|
||||
assert result.iri == XSD + "integer"
|
||||
|
||||
def test_strlen(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("STRLEN", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||
assert result == 5
|
||||
|
||||
def test_ucase(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("UCASE", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||
assert result.value == "HELLO"
|
||||
|
||||
def test_lcase(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("LCASE", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": lit("HELLO")})
|
||||
assert result.value == "hello"
|
||||
|
||||
def test_contains_true(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("CONTAINS",
|
||||
arg1=Variable("x"), arg2=Literal("ell"))
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is True
|
||||
|
||||
def test_contains_false(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("CONTAINS",
|
||||
arg1=Variable("x"), arg2=Literal("xyz"))
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is False
|
||||
|
||||
def test_strstarts_true(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("STRSTARTS",
|
||||
arg1=Variable("x"), arg2=Literal("hel"))
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is True
|
||||
|
||||
def test_strends_true(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("STRENDS",
|
||||
arg1=Variable("x"), arg2=Literal("llo"))
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is True
|
||||
|
||||
def test_regex_match(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("REGEX",
|
||||
text=Variable("x"),
|
||||
pattern=Literal("^hel"),
|
||||
flags=None)
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is True
|
||||
|
||||
def test_regex_case_insensitive(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("REGEX",
|
||||
text=Variable("x"),
|
||||
pattern=Literal("HELLO"),
|
||||
flags=Literal("i"))
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is True
|
||||
|
||||
def test_regex_no_match(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("REGEX",
|
||||
text=Variable("x"),
|
||||
pattern=Literal("^world"),
|
||||
flags=None)
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is False
|
||||
|
||||
|
||||
class TestEffectiveBoolean:
|
||||
|
||||
def test_true(self):
|
||||
assert _effective_boolean(True) is True
|
||||
|
||||
def test_false(self):
|
||||
assert _effective_boolean(False) is False
|
||||
|
||||
def test_none(self):
|
||||
assert _effective_boolean(None) is False
|
||||
|
||||
def test_nonzero_int(self):
|
||||
assert _effective_boolean(42) is True
|
||||
|
||||
def test_zero_int(self):
|
||||
assert _effective_boolean(0) is False
|
||||
|
||||
def test_nonempty_string(self):
|
||||
assert _effective_boolean("hello") is True
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _effective_boolean("") is False
|
||||
|
||||
def test_iri_term(self):
|
||||
assert _effective_boolean(iri("http://x")) is True
|
||||
|
||||
def test_nonempty_literal(self):
|
||||
assert _effective_boolean(lit("hello")) is True
|
||||
|
||||
def test_empty_literal(self):
|
||||
assert _effective_boolean(lit("")) is False
|
||||
|
||||
def test_boolean_literal_true(self):
|
||||
assert _effective_boolean(
|
||||
lit("true", datatype=XSD + "boolean")
|
||||
) is True
|
||||
|
||||
def test_boolean_literal_false(self):
|
||||
assert _effective_boolean(
|
||||
lit("false", datatype=XSD + "boolean")
|
||||
) is False
|
||||
|
||||
def test_numeric_literal_nonzero(self):
|
||||
assert _effective_boolean(
|
||||
lit("42", datatype=XSD + "integer")
|
||||
) is True
|
||||
|
||||
def test_numeric_literal_zero(self):
|
||||
assert _effective_boolean(
|
||||
lit("0", datatype=XSD + "integer")
|
||||
) is False
|
||||
|
||||
|
||||
class TestToString:
|
||||
|
||||
def test_none(self):
|
||||
assert _to_string(None) == ""
|
||||
|
||||
def test_string(self):
|
||||
assert _to_string("hello") == "hello"
|
||||
|
||||
def test_iri_term(self):
|
||||
assert _to_string(iri("http://example.com")) == "http://example.com"
|
||||
|
||||
def test_literal_term(self):
|
||||
assert _to_string(lit("hello")) == "hello"
|
||||
|
||||
def test_blank_term(self):
|
||||
assert _to_string(blank("b1")) == "b1"
|
||||
|
||||
|
||||
class TestToNumeric:
|
||||
|
||||
def test_none(self):
|
||||
assert _to_numeric(None) is None
|
||||
|
||||
def test_int(self):
|
||||
assert _to_numeric(42) == 42
|
||||
|
||||
def test_float(self):
|
||||
assert _to_numeric(3.14) == 3.14
|
||||
|
||||
def test_integer_literal(self):
|
||||
assert _to_numeric(lit("42")) == 42
|
||||
|
||||
def test_decimal_literal(self):
|
||||
assert _to_numeric(lit("3.14")) == 3.14
|
||||
|
||||
def test_non_numeric_literal(self):
|
||||
assert _to_numeric(lit("hello")) is None
|
||||
|
||||
def test_numeric_string(self):
|
||||
assert _to_numeric("42") == 42
|
||||
|
||||
def test_non_numeric_string(self):
|
||||
assert _to_numeric("abc") is None
|
||||
|
||||
|
||||
class TestComparableValue:
|
||||
|
||||
def test_none(self):
|
||||
assert _comparable_value(None) == (0, "")
|
||||
|
||||
def test_int(self):
|
||||
assert _comparable_value(42) == (2, 42)
|
||||
|
||||
def test_iri(self):
|
||||
assert _comparable_value(iri("http://x")) == (4, "http://x")
|
||||
|
||||
def test_literal(self):
|
||||
assert _comparable_value(lit("hello")) == (3, "hello")
|
||||
|
||||
def test_numeric_literal(self):
|
||||
assert _comparable_value(lit("42")) == (2, 42)
|
||||
|
||||
def test_ordering(self):
|
||||
vals = [lit("b"), lit("a"), lit("c")]
|
||||
sorted_vals = sorted(vals, key=_comparable_value)
|
||||
assert sorted_vals[0].value == "a"
|
||||
assert sorted_vals[1].value == "b"
|
||||
assert sorted_vals[2].value == "c"
|
||||
205
tests/unit/test_query/test_sparql_parser.py
Normal file
205
tests/unit/test_query/test_sparql_parser.py
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
"""
|
||||
Tests for the SPARQL parser module.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from trustgraph.query.sparql.parser import (
|
||||
parse_sparql, ParseError, rdflib_term_to_term, term_to_rdflib,
|
||||
)
|
||||
from trustgraph.schema import Term, IRI, LITERAL, BLANK
|
||||
|
||||
|
||||
class TestParseSparql:
|
||||
"""Tests for parse_sparql function."""
|
||||
|
||||
def test_select_query_type(self):
|
||||
parsed = parse_sparql("SELECT ?s ?p ?o WHERE { ?s ?p ?o }")
|
||||
assert parsed.query_type == "select"
|
||||
|
||||
def test_select_variables(self):
|
||||
parsed = parse_sparql("SELECT ?s ?p ?o WHERE { ?s ?p ?o }")
|
||||
assert parsed.variables == ["s", "p", "o"]
|
||||
|
||||
def test_select_subset_variables(self):
|
||||
parsed = parse_sparql("SELECT ?s ?o WHERE { ?s ?p ?o }")
|
||||
assert parsed.variables == ["s", "o"]
|
||||
|
||||
def test_ask_query_type(self):
|
||||
parsed = parse_sparql(
|
||||
"ASK { <http://example.com/a> ?p ?o }"
|
||||
)
|
||||
assert parsed.query_type == "ask"
|
||||
|
||||
def test_ask_no_variables(self):
|
||||
parsed = parse_sparql(
|
||||
"ASK { <http://example.com/a> ?p ?o }"
|
||||
)
|
||||
assert parsed.variables == []
|
||||
|
||||
def test_construct_query_type(self):
|
||||
parsed = parse_sparql(
|
||||
"CONSTRUCT { ?s <http://example.com/knows> ?o } "
|
||||
"WHERE { ?s <http://example.com/friendOf> ?o }"
|
||||
)
|
||||
assert parsed.query_type == "construct"
|
||||
|
||||
def test_describe_query_type(self):
|
||||
parsed = parse_sparql(
|
||||
"DESCRIBE <http://example.com/alice>"
|
||||
)
|
||||
assert parsed.query_type == "describe"
|
||||
|
||||
def test_select_with_limit(self):
|
||||
parsed = parse_sparql(
|
||||
"SELECT ?s WHERE { ?s ?p ?o } LIMIT 10"
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
assert parsed.variables == ["s"]
|
||||
|
||||
def test_select_with_distinct(self):
|
||||
parsed = parse_sparql(
|
||||
"SELECT DISTINCT ?s WHERE { ?s ?p ?o }"
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
assert parsed.variables == ["s"]
|
||||
|
||||
def test_select_with_filter(self):
|
||||
parsed = parse_sparql(
|
||||
'SELECT ?s ?label WHERE { '
|
||||
' ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label . '
|
||||
' FILTER(CONTAINS(STR(?label), "test")) '
|
||||
'}'
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
assert parsed.variables == ["s", "label"]
|
||||
|
||||
def test_select_with_optional(self):
|
||||
parsed = parse_sparql(
|
||||
"SELECT ?s ?p ?o ?label WHERE { "
|
||||
" ?s ?p ?o . "
|
||||
" OPTIONAL { ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label } "
|
||||
"}"
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
assert set(parsed.variables) == {"s", "p", "o", "label"}
|
||||
|
||||
def test_select_with_union(self):
|
||||
parsed = parse_sparql(
|
||||
"SELECT ?s ?label WHERE { "
|
||||
" { ?s <http://example.com/name> ?label } "
|
||||
" UNION "
|
||||
" { ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label } "
|
||||
"}"
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
|
||||
def test_select_with_order_by(self):
|
||||
parsed = parse_sparql(
|
||||
"SELECT ?s ?label WHERE { ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label } "
|
||||
"ORDER BY ?label"
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
|
||||
def test_select_with_group_by(self):
|
||||
parsed = parse_sparql(
|
||||
"SELECT ?p (COUNT(?o) AS ?count) WHERE { ?s ?p ?o } "
|
||||
"GROUP BY ?p ORDER BY DESC(?count)"
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
|
||||
def test_select_with_prefixes(self):
|
||||
parsed = parse_sparql(
|
||||
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> "
|
||||
"SELECT ?s ?label WHERE { ?s rdfs:label ?label }"
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
assert parsed.variables == ["s", "label"]
|
||||
|
||||
def test_algebra_not_none(self):
|
||||
parsed = parse_sparql("SELECT ?s WHERE { ?s ?p ?o }")
|
||||
assert parsed.algebra is not None
|
||||
|
||||
def test_parse_error_invalid_sparql(self):
|
||||
with pytest.raises(ParseError):
|
||||
parse_sparql("NOT VALID SPARQL AT ALL")
|
||||
|
||||
def test_parse_error_incomplete_query(self):
|
||||
with pytest.raises(ParseError):
|
||||
parse_sparql("SELECT ?s WHERE {")
|
||||
|
||||
def test_parse_error_message(self):
|
||||
with pytest.raises(ParseError, match="SPARQL parse error"):
|
||||
parse_sparql("GIBBERISH")
|
||||
|
||||
|
||||
class TestRdflibTermToTerm:
|
||||
"""Tests for rdflib-to-Term conversion."""
|
||||
|
||||
def test_uriref_to_term(self):
|
||||
from rdflib import URIRef
|
||||
term = rdflib_term_to_term(URIRef("http://example.com/alice"))
|
||||
assert term.type == IRI
|
||||
assert term.iri == "http://example.com/alice"
|
||||
|
||||
def test_literal_to_term(self):
|
||||
from rdflib import Literal
|
||||
term = rdflib_term_to_term(Literal("hello"))
|
||||
assert term.type == LITERAL
|
||||
assert term.value == "hello"
|
||||
|
||||
def test_typed_literal_to_term(self):
|
||||
from rdflib import Literal, URIRef
|
||||
term = rdflib_term_to_term(
|
||||
Literal("42", datatype=URIRef("http://www.w3.org/2001/XMLSchema#integer"))
|
||||
)
|
||||
assert term.type == LITERAL
|
||||
assert term.value == "42"
|
||||
assert term.datatype == "http://www.w3.org/2001/XMLSchema#integer"
|
||||
|
||||
def test_lang_literal_to_term(self):
|
||||
from rdflib import Literal
|
||||
term = rdflib_term_to_term(Literal("hello", lang="en"))
|
||||
assert term.type == LITERAL
|
||||
assert term.value == "hello"
|
||||
assert term.language == "en"
|
||||
|
||||
def test_bnode_to_term(self):
|
||||
from rdflib import BNode
|
||||
term = rdflib_term_to_term(BNode("b1"))
|
||||
assert term.type == BLANK
|
||||
assert term.id == "b1"
|
||||
|
||||
|
||||
class TestTermToRdflib:
|
||||
"""Tests for Term-to-rdflib conversion."""
|
||||
|
||||
def test_iri_term_to_uriref(self):
|
||||
from rdflib import URIRef
|
||||
result = term_to_rdflib(Term(type=IRI, iri="http://example.com/x"))
|
||||
assert isinstance(result, URIRef)
|
||||
assert str(result) == "http://example.com/x"
|
||||
|
||||
def test_literal_term_to_literal(self):
|
||||
from rdflib import Literal
|
||||
result = term_to_rdflib(Term(type=LITERAL, value="hello"))
|
||||
assert isinstance(result, Literal)
|
||||
assert str(result) == "hello"
|
||||
|
||||
def test_typed_literal_roundtrip(self):
|
||||
from rdflib import URIRef
|
||||
original = Term(
|
||||
type=LITERAL, value="42",
|
||||
datatype="http://www.w3.org/2001/XMLSchema#integer"
|
||||
)
|
||||
rdflib_term = term_to_rdflib(original)
|
||||
assert rdflib_term.datatype == URIRef("http://www.w3.org/2001/XMLSchema#integer")
|
||||
|
||||
def test_lang_literal_roundtrip(self):
|
||||
original = Term(type=LITERAL, value="bonjour", language="fr")
|
||||
rdflib_term = term_to_rdflib(original)
|
||||
assert rdflib_term.language == "fr"
|
||||
|
||||
def test_blank_term_to_bnode(self):
|
||||
from rdflib import BNode
|
||||
result = term_to_rdflib(Term(type=BLANK, id="b1"))
|
||||
assert isinstance(result, BNode)
|
||||
345
tests/unit/test_query/test_sparql_solutions.py
Normal file
345
tests/unit/test_query/test_sparql_solutions.py
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
"""
|
||||
Tests for SPARQL solution sequence operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
from trustgraph.query.sparql.solutions import (
|
||||
hash_join, left_join, union, project, distinct,
|
||||
order_by, slice_solutions, _terms_equal, _compatible,
|
||||
)
|
||||
|
||||
|
||||
# --- Test helpers ---
|
||||
|
||||
def iri(v):
|
||||
return Term(type=IRI, iri=v)
|
||||
|
||||
def lit(v):
|
||||
return Term(type=LITERAL, value=v)
|
||||
|
||||
|
||||
# --- Fixtures ---
|
||||
|
||||
@pytest.fixture
|
||||
def alice():
|
||||
return iri("http://example.com/alice")
|
||||
|
||||
@pytest.fixture
|
||||
def bob():
|
||||
return iri("http://example.com/bob")
|
||||
|
||||
@pytest.fixture
|
||||
def carol():
|
||||
return iri("http://example.com/carol")
|
||||
|
||||
@pytest.fixture
|
||||
def knows():
|
||||
return iri("http://example.com/knows")
|
||||
|
||||
@pytest.fixture
|
||||
def name_alice():
|
||||
return lit("Alice")
|
||||
|
||||
@pytest.fixture
|
||||
def name_bob():
|
||||
return lit("Bob")
|
||||
|
||||
|
||||
class TestTermsEqual:
|
||||
|
||||
def test_equal_iris(self):
|
||||
assert _terms_equal(iri("http://x.com/a"), iri("http://x.com/a"))
|
||||
|
||||
def test_unequal_iris(self):
|
||||
assert not _terms_equal(iri("http://x.com/a"), iri("http://x.com/b"))
|
||||
|
||||
def test_equal_literals(self):
|
||||
assert _terms_equal(lit("hello"), lit("hello"))
|
||||
|
||||
def test_unequal_literals(self):
|
||||
assert not _terms_equal(lit("hello"), lit("world"))
|
||||
|
||||
def test_iri_vs_literal(self):
|
||||
assert not _terms_equal(iri("hello"), lit("hello"))
|
||||
|
||||
def test_none_none(self):
|
||||
assert _terms_equal(None, None)
|
||||
|
||||
def test_none_vs_term(self):
|
||||
assert not _terms_equal(None, iri("http://x.com/a"))
|
||||
|
||||
|
||||
class TestCompatible:
|
||||
|
||||
def test_no_shared_variables(self):
|
||||
assert _compatible({"a": iri("http://x")}, {"b": iri("http://y")})
|
||||
|
||||
def test_shared_variable_same_value(self, alice):
|
||||
assert _compatible({"s": alice, "x": lit("1")}, {"s": alice, "y": lit("2")})
|
||||
|
||||
def test_shared_variable_different_value(self, alice, bob):
|
||||
assert not _compatible({"s": alice}, {"s": bob})
|
||||
|
||||
def test_empty_solutions(self):
|
||||
assert _compatible({}, {})
|
||||
|
||||
def test_empty_vs_nonempty(self, alice):
|
||||
assert _compatible({}, {"s": alice})
|
||||
|
||||
|
||||
class TestHashJoin:
|
||||
|
||||
def test_join_on_shared_variable(self, alice, bob, name_alice, name_bob):
|
||||
left = [
|
||||
{"s": alice, "p": iri("http://example.com/knows"), "o": bob},
|
||||
{"s": bob, "p": iri("http://example.com/knows"), "o": alice},
|
||||
]
|
||||
right = [
|
||||
{"s": alice, "label": name_alice},
|
||||
{"s": bob, "label": name_bob},
|
||||
]
|
||||
result = hash_join(left, right)
|
||||
assert len(result) == 2
|
||||
# Check that joined solutions have all variables
|
||||
for sol in result:
|
||||
assert "s" in sol
|
||||
assert "p" in sol
|
||||
assert "o" in sol
|
||||
assert "label" in sol
|
||||
|
||||
def test_join_no_shared_variables_cross_product(self, alice, bob):
|
||||
left = [{"a": alice}]
|
||||
right = [{"b": bob}, {"b": alice}]
|
||||
result = hash_join(left, right)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_join_no_matches(self, alice, bob):
|
||||
left = [{"s": alice}]
|
||||
right = [{"s": bob}]
|
||||
result = hash_join(left, right)
|
||||
assert len(result) == 0
|
||||
|
||||
def test_join_empty_left(self, alice):
|
||||
result = hash_join([], [{"s": alice}])
|
||||
assert len(result) == 0
|
||||
|
||||
def test_join_empty_right(self, alice):
|
||||
result = hash_join([{"s": alice}], [])
|
||||
assert len(result) == 0
|
||||
|
||||
def test_join_multiple_matches(self, alice, name_alice):
|
||||
left = [
|
||||
{"s": alice, "p": iri("http://e.com/a")},
|
||||
{"s": alice, "p": iri("http://e.com/b")},
|
||||
]
|
||||
right = [{"s": alice, "label": name_alice}]
|
||||
result = hash_join(left, right)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_join_preserves_values(self, alice, name_alice):
|
||||
left = [{"s": alice, "x": lit("1")}]
|
||||
right = [{"s": alice, "y": lit("2")}]
|
||||
result = hash_join(left, right)
|
||||
assert len(result) == 1
|
||||
assert result[0]["x"].value == "1"
|
||||
assert result[0]["y"].value == "2"
|
||||
|
||||
|
||||
class TestLeftJoin:
|
||||
|
||||
def test_left_join_with_matches(self, alice, bob, name_alice):
|
||||
left = [{"s": alice}, {"s": bob}]
|
||||
right = [{"s": alice, "label": name_alice}]
|
||||
result = left_join(left, right)
|
||||
assert len(result) == 2
|
||||
# Alice has label
|
||||
alice_sols = [s for s in result if s["s"].iri == "http://example.com/alice"]
|
||||
assert len(alice_sols) == 1
|
||||
assert "label" in alice_sols[0]
|
||||
# Bob preserved without label
|
||||
bob_sols = [s for s in result if s["s"].iri == "http://example.com/bob"]
|
||||
assert len(bob_sols) == 1
|
||||
assert "label" not in bob_sols[0]
|
||||
|
||||
def test_left_join_no_matches(self, alice, bob):
|
||||
left = [{"s": alice}]
|
||||
right = [{"s": bob, "label": lit("Bob")}]
|
||||
result = left_join(left, right)
|
||||
assert len(result) == 1
|
||||
assert result[0]["s"].iri == "http://example.com/alice"
|
||||
assert "label" not in result[0]
|
||||
|
||||
def test_left_join_empty_right(self, alice):
|
||||
left = [{"s": alice}]
|
||||
result = left_join(left, [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_left_join_empty_left(self):
|
||||
result = left_join([], [{"s": iri("http://x")}])
|
||||
assert len(result) == 0
|
||||
|
||||
def test_left_join_with_filter(self, alice, bob):
|
||||
left = [{"s": alice}, {"s": bob}]
|
||||
right = [
|
||||
{"s": alice, "val": lit("yes")},
|
||||
{"s": bob, "val": lit("no")},
|
||||
]
|
||||
# Filter: only keep joins where val == "yes"
|
||||
result = left_join(
|
||||
left, right,
|
||||
filter_fn=lambda sol: sol.get("val") and sol["val"].value == "yes"
|
||||
)
|
||||
assert len(result) == 2
|
||||
# Alice matches filter
|
||||
alice_sols = [s for s in result if s["s"].iri == "http://example.com/alice"]
|
||||
assert "val" in alice_sols[0]
|
||||
assert alice_sols[0]["val"].value == "yes"
|
||||
# Bob doesn't match filter, preserved without val
|
||||
bob_sols = [s for s in result if s["s"].iri == "http://example.com/bob"]
|
||||
assert "val" not in bob_sols[0]
|
||||
|
||||
|
||||
class TestUnion:
|
||||
|
||||
def test_union_concatenates(self, alice, bob):
|
||||
left = [{"s": alice}]
|
||||
right = [{"s": bob}]
|
||||
result = union(left, right)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_union_preserves_order(self, alice, bob):
|
||||
left = [{"s": alice}]
|
||||
right = [{"s": bob}]
|
||||
result = union(left, right)
|
||||
assert result[0]["s"].iri == "http://example.com/alice"
|
||||
assert result[1]["s"].iri == "http://example.com/bob"
|
||||
|
||||
def test_union_empty_left(self, alice):
|
||||
result = union([], [{"s": alice}])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_union_both_empty(self):
|
||||
result = union([], [])
|
||||
assert len(result) == 0
|
||||
|
||||
def test_union_allows_duplicates(self, alice):
|
||||
result = union([{"s": alice}], [{"s": alice}])
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestProject:
|
||||
|
||||
def test_project_keeps_selected(self, alice, name_alice):
|
||||
solutions = [{"s": alice, "label": name_alice, "extra": lit("x")}]
|
||||
result = project(solutions, ["s", "label"])
|
||||
assert len(result) == 1
|
||||
assert "s" in result[0]
|
||||
assert "label" in result[0]
|
||||
assert "extra" not in result[0]
|
||||
|
||||
def test_project_missing_variable(self, alice):
|
||||
solutions = [{"s": alice}]
|
||||
result = project(solutions, ["s", "missing"])
|
||||
assert len(result) == 1
|
||||
assert "s" in result[0]
|
||||
assert "missing" not in result[0]
|
||||
|
||||
def test_project_empty(self):
|
||||
result = project([], ["s"])
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
class TestDistinct:
|
||||
|
||||
def test_removes_duplicates(self, alice):
|
||||
solutions = [{"s": alice}, {"s": alice}, {"s": alice}]
|
||||
result = distinct(solutions)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_keeps_different(self, alice, bob):
|
||||
solutions = [{"s": alice}, {"s": bob}]
|
||||
result = distinct(solutions)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_empty(self):
|
||||
result = distinct([])
|
||||
assert len(result) == 0
|
||||
|
||||
def test_multi_variable_distinct(self, alice, bob):
|
||||
solutions = [
|
||||
{"s": alice, "o": bob},
|
||||
{"s": alice, "o": bob},
|
||||
{"s": alice, "o": alice},
|
||||
]
|
||||
result = distinct(solutions)
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestOrderBy:
|
||||
|
||||
def test_order_by_ascending(self):
|
||||
solutions = [
|
||||
{"label": lit("Charlie")},
|
||||
{"label": lit("Alice")},
|
||||
{"label": lit("Bob")},
|
||||
]
|
||||
key_fns = [(lambda sol: sol.get("label"), True)]
|
||||
result = order_by(solutions, key_fns)
|
||||
assert result[0]["label"].value == "Alice"
|
||||
assert result[1]["label"].value == "Bob"
|
||||
assert result[2]["label"].value == "Charlie"
|
||||
|
||||
def test_order_by_descending(self):
|
||||
solutions = [
|
||||
{"label": lit("Alice")},
|
||||
{"label": lit("Charlie")},
|
||||
{"label": lit("Bob")},
|
||||
]
|
||||
key_fns = [(lambda sol: sol.get("label"), False)]
|
||||
result = order_by(solutions, key_fns)
|
||||
assert result[0]["label"].value == "Charlie"
|
||||
assert result[1]["label"].value == "Bob"
|
||||
assert result[2]["label"].value == "Alice"
|
||||
|
||||
def test_order_by_empty(self):
|
||||
result = order_by([], [(lambda sol: sol.get("x"), True)])
|
||||
assert len(result) == 0
|
||||
|
||||
def test_order_by_no_keys(self, alice):
|
||||
solutions = [{"s": alice}]
|
||||
result = order_by(solutions, [])
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
class TestSlice:
|
||||
|
||||
def test_limit(self, alice, bob, carol):
|
||||
solutions = [{"s": alice}, {"s": bob}, {"s": carol}]
|
||||
result = slice_solutions(solutions, limit=2)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_offset(self, alice, bob, carol):
|
||||
solutions = [{"s": alice}, {"s": bob}, {"s": carol}]
|
||||
result = slice_solutions(solutions, offset=1)
|
||||
assert len(result) == 2
|
||||
assert result[0]["s"].iri == "http://example.com/bob"
|
||||
|
||||
def test_offset_and_limit(self, alice, bob, carol):
|
||||
solutions = [{"s": alice}, {"s": bob}, {"s": carol}]
|
||||
result = slice_solutions(solutions, offset=1, limit=1)
|
||||
assert len(result) == 1
|
||||
assert result[0]["s"].iri == "http://example.com/bob"
|
||||
|
||||
def test_limit_zero(self, alice):
|
||||
result = slice_solutions([{"s": alice}], limit=0)
|
||||
assert len(result) == 0
|
||||
|
||||
def test_offset_beyond_length(self, alice):
|
||||
result = slice_solutions([{"s": alice}], offset=10)
|
||||
assert len(result) == 0
|
||||
|
||||
def test_no_slice(self, alice, bob):
|
||||
solutions = [{"s": alice}, {"s": bob}]
|
||||
result = slice_solutions(solutions)
|
||||
assert len(result) == 2
|
||||
|
|
@ -28,21 +28,21 @@ def triple_tx():
|
|||
|
||||
class TestTermTranslatorIri:
|
||||
|
||||
def test_iri_to_pulsar(self, term_tx):
|
||||
def test_iri_decode(self, term_tx):
|
||||
data = {"t": "i", "i": "http://example.org/Alice"}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.type == IRI
|
||||
assert term.iri == "http://example.org/Alice"
|
||||
|
||||
def test_iri_from_pulsar(self, term_tx):
|
||||
def test_iri_encode(self, term_tx):
|
||||
term = Term(type=IRI, iri="http://example.org/Bob")
|
||||
wire = term_tx.from_pulsar(term)
|
||||
wire = term_tx.encode(term)
|
||||
assert wire == {"t": "i", "i": "http://example.org/Bob"}
|
||||
|
||||
def test_iri_round_trip(self, term_tx):
|
||||
original = Term(type=IRI, iri="http://example.org/round")
|
||||
wire = term_tx.from_pulsar(original)
|
||||
restored = term_tx.to_pulsar(wire)
|
||||
wire = term_tx.encode(original)
|
||||
restored = term_tx.decode(wire)
|
||||
assert restored == original
|
||||
|
||||
|
||||
|
|
@ -52,21 +52,21 @@ class TestTermTranslatorIri:
|
|||
|
||||
class TestTermTranslatorBlank:
|
||||
|
||||
def test_blank_to_pulsar(self, term_tx):
|
||||
def test_blank_decode(self, term_tx):
|
||||
data = {"t": "b", "d": "_:b42"}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.type == BLANK
|
||||
assert term.id == "_:b42"
|
||||
|
||||
def test_blank_from_pulsar(self, term_tx):
|
||||
def test_blank_encode(self, term_tx):
|
||||
term = Term(type=BLANK, id="_:node1")
|
||||
wire = term_tx.from_pulsar(term)
|
||||
wire = term_tx.encode(term)
|
||||
assert wire == {"t": "b", "d": "_:node1"}
|
||||
|
||||
def test_blank_round_trip(self, term_tx):
|
||||
original = Term(type=BLANK, id="_:x")
|
||||
wire = term_tx.from_pulsar(original)
|
||||
restored = term_tx.to_pulsar(wire)
|
||||
wire = term_tx.encode(original)
|
||||
restored = term_tx.decode(wire)
|
||||
assert restored == original
|
||||
|
||||
|
||||
|
|
@ -76,29 +76,29 @@ class TestTermTranslatorBlank:
|
|||
|
||||
class TestTermTranslatorTypedLiteral:
|
||||
|
||||
def test_plain_literal_to_pulsar(self, term_tx):
|
||||
def test_plain_literal_decode(self, term_tx):
|
||||
data = {"t": "l", "v": "hello"}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.type == LITERAL
|
||||
assert term.value == "hello"
|
||||
assert term.datatype == ""
|
||||
assert term.language == ""
|
||||
|
||||
def test_xsd_integer_to_pulsar(self, term_tx):
|
||||
def test_xsd_integer_decode(self, term_tx):
|
||||
data = {
|
||||
"t": "l", "v": "42",
|
||||
"dt": "http://www.w3.org/2001/XMLSchema#integer",
|
||||
}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.value == "42"
|
||||
assert term.datatype.endswith("#integer")
|
||||
|
||||
def test_typed_literal_from_pulsar(self, term_tx):
|
||||
def test_typed_literal_encode(self, term_tx):
|
||||
term = Term(
|
||||
type=LITERAL, value="3.14",
|
||||
datatype="http://www.w3.org/2001/XMLSchema#double",
|
||||
)
|
||||
wire = term_tx.from_pulsar(term)
|
||||
wire = term_tx.encode(term)
|
||||
assert wire["t"] == "l"
|
||||
assert wire["v"] == "3.14"
|
||||
assert wire["dt"] == "http://www.w3.org/2001/XMLSchema#double"
|
||||
|
|
@ -109,13 +109,13 @@ class TestTermTranslatorTypedLiteral:
|
|||
type=LITERAL, value="true",
|
||||
datatype="http://www.w3.org/2001/XMLSchema#boolean",
|
||||
)
|
||||
wire = term_tx.from_pulsar(original)
|
||||
restored = term_tx.to_pulsar(wire)
|
||||
wire = term_tx.encode(original)
|
||||
restored = term_tx.decode(wire)
|
||||
assert restored == original
|
||||
|
||||
def test_plain_literal_omits_dt_and_ln(self, term_tx):
|
||||
term = Term(type=LITERAL, value="x")
|
||||
wire = term_tx.from_pulsar(term)
|
||||
wire = term_tx.encode(term)
|
||||
assert "dt" not in wire
|
||||
assert "ln" not in wire
|
||||
|
||||
|
|
@ -126,22 +126,22 @@ class TestTermTranslatorTypedLiteral:
|
|||
|
||||
class TestTermTranslatorLangLiteral:
|
||||
|
||||
def test_language_tag_to_pulsar(self, term_tx):
|
||||
def test_language_tag_decode(self, term_tx):
|
||||
data = {"t": "l", "v": "bonjour", "ln": "fr"}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.value == "bonjour"
|
||||
assert term.language == "fr"
|
||||
|
||||
def test_language_tag_from_pulsar(self, term_tx):
|
||||
def test_language_tag_encode(self, term_tx):
|
||||
term = Term(type=LITERAL, value="colour", language="en-GB")
|
||||
wire = term_tx.from_pulsar(term)
|
||||
wire = term_tx.encode(term)
|
||||
assert wire["ln"] == "en-GB"
|
||||
assert "dt" not in wire # No datatype
|
||||
|
||||
def test_language_tag_round_trip(self, term_tx):
|
||||
original = Term(type=LITERAL, value="hola", language="es")
|
||||
wire = term_tx.from_pulsar(original)
|
||||
restored = term_tx.to_pulsar(wire)
|
||||
wire = term_tx.encode(original)
|
||||
restored = term_tx.decode(wire)
|
||||
assert restored == original
|
||||
|
||||
|
||||
|
|
@ -151,7 +151,7 @@ class TestTermTranslatorLangLiteral:
|
|||
|
||||
class TestTermTranslatorQuotedTriple:
|
||||
|
||||
def test_quoted_triple_to_pulsar(self, term_tx):
|
||||
def test_quoted_triple_decode(self, term_tx):
|
||||
data = {
|
||||
"t": "t",
|
||||
"tr": {
|
||||
|
|
@ -160,20 +160,20 @@ class TestTermTranslatorQuotedTriple:
|
|||
"o": {"t": "i", "i": "http://example.org/Bob"},
|
||||
},
|
||||
}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.type == TRIPLE
|
||||
assert term.triple is not None
|
||||
assert term.triple.s.iri == "http://example.org/Alice"
|
||||
assert term.triple.o.iri == "http://example.org/Bob"
|
||||
|
||||
def test_quoted_triple_from_pulsar(self, term_tx):
|
||||
def test_quoted_triple_encode(self, term_tx):
|
||||
inner = Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/s"),
|
||||
p=Term(type=IRI, iri="http://example.org/p"),
|
||||
o=Term(type=LITERAL, value="val"),
|
||||
)
|
||||
term = Term(type=TRIPLE, triple=inner)
|
||||
wire = term_tx.from_pulsar(term)
|
||||
wire = term_tx.encode(term)
|
||||
assert wire["t"] == "t"
|
||||
assert "tr" in wire
|
||||
assert wire["tr"]["s"]["i"] == "http://example.org/s"
|
||||
|
|
@ -186,18 +186,18 @@ class TestTermTranslatorQuotedTriple:
|
|||
o=Term(type=LITERAL, value="C", language="en"),
|
||||
)
|
||||
original = Term(type=TRIPLE, triple=inner)
|
||||
wire = term_tx.from_pulsar(original)
|
||||
restored = term_tx.to_pulsar(wire)
|
||||
wire = term_tx.encode(original)
|
||||
restored = term_tx.decode(wire)
|
||||
assert restored.type == TRIPLE
|
||||
assert restored.triple.s == original.triple.s
|
||||
assert restored.triple.o == original.triple.o
|
||||
|
||||
def test_quoted_triple_none_triple(self, term_tx):
|
||||
term = Term(type=TRIPLE, triple=None)
|
||||
wire = term_tx.from_pulsar(term)
|
||||
wire = term_tx.encode(term)
|
||||
assert wire == {"t": "t"}
|
||||
# And back
|
||||
restored = term_tx.to_pulsar(wire)
|
||||
restored = term_tx.decode(wire)
|
||||
assert restored.type == TRIPLE
|
||||
assert restored.triple is None
|
||||
|
||||
|
|
@ -210,7 +210,7 @@ class TestTermTranslatorQuotedTriple:
|
|||
"o": {"t": "l", "v": "A feeling of expectation"},
|
||||
},
|
||||
}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.triple.o.type == LITERAL
|
||||
assert term.triple.o.value == "A feeling of expectation"
|
||||
|
||||
|
|
@ -223,22 +223,22 @@ class TestTermTranslatorEdgeCases:
|
|||
|
||||
def test_unknown_type(self, term_tx):
|
||||
data = {"t": "z"}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.type == "z"
|
||||
|
||||
def test_empty_type(self, term_tx):
|
||||
data = {}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.type == ""
|
||||
|
||||
def test_missing_iri_field(self, term_tx):
|
||||
data = {"t": "i"}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.iri == ""
|
||||
|
||||
def test_missing_literal_fields(self, term_tx):
|
||||
data = {"t": "l"}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.value == ""
|
||||
assert term.datatype == ""
|
||||
assert term.language == ""
|
||||
|
|
@ -250,24 +250,24 @@ class TestTermTranslatorEdgeCases:
|
|||
|
||||
class TestTripleTranslator:
|
||||
|
||||
def test_triple_to_pulsar(self, triple_tx):
|
||||
def test_triple_decode(self, triple_tx):
|
||||
data = {
|
||||
"s": {"t": "i", "i": "http://example.org/s"},
|
||||
"p": {"t": "i", "i": "http://example.org/p"},
|
||||
"o": {"t": "l", "v": "object"},
|
||||
}
|
||||
triple = triple_tx.to_pulsar(data)
|
||||
triple = triple_tx.decode(data)
|
||||
assert triple.s.iri == "http://example.org/s"
|
||||
assert triple.o.value == "object"
|
||||
assert triple.g is None
|
||||
|
||||
def test_triple_from_pulsar(self, triple_tx):
|
||||
def test_triple_encode(self, triple_tx):
|
||||
triple = Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/A"),
|
||||
p=Term(type=IRI, iri="http://example.org/B"),
|
||||
o=Term(type=LITERAL, value="C"),
|
||||
)
|
||||
wire = triple_tx.from_pulsar(triple)
|
||||
wire = triple_tx.encode(triple)
|
||||
assert wire["s"]["t"] == "i"
|
||||
assert wire["o"]["v"] == "C"
|
||||
assert "g" not in wire
|
||||
|
|
@ -279,17 +279,17 @@ class TestTripleTranslator:
|
|||
"o": {"t": "l", "v": "val"},
|
||||
"g": "urn:graph:source",
|
||||
}
|
||||
quad = triple_tx.to_pulsar(data)
|
||||
quad = triple_tx.decode(data)
|
||||
assert quad.g == "urn:graph:source"
|
||||
|
||||
def test_quad_from_pulsar_includes_graph(self, triple_tx):
|
||||
def test_quad_encode_includes_graph(self, triple_tx):
|
||||
quad = Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/s"),
|
||||
p=Term(type=IRI, iri="http://example.org/p"),
|
||||
o=Term(type=LITERAL, value="v"),
|
||||
g="urn:graph:retrieval",
|
||||
)
|
||||
wire = triple_tx.from_pulsar(quad)
|
||||
wire = triple_tx.encode(quad)
|
||||
assert wire["g"] == "urn:graph:retrieval"
|
||||
|
||||
def test_quad_round_trip(self, triple_tx):
|
||||
|
|
@ -299,8 +299,8 @@ class TestTripleTranslator:
|
|||
o=Term(type=LITERAL, value="v"),
|
||||
g="urn:graph:source",
|
||||
)
|
||||
wire = triple_tx.from_pulsar(original)
|
||||
restored = triple_tx.to_pulsar(wire)
|
||||
wire = triple_tx.encode(original)
|
||||
restored = triple_tx.decode(wire)
|
||||
assert restored == original
|
||||
|
||||
def test_none_graph_omitted_from_wire(self, triple_tx):
|
||||
|
|
@ -310,12 +310,12 @@ class TestTripleTranslator:
|
|||
o=Term(type=LITERAL, value="v"),
|
||||
g=None,
|
||||
)
|
||||
wire = triple_tx.from_pulsar(triple)
|
||||
wire = triple_tx.encode(triple)
|
||||
assert "g" not in wire
|
||||
|
||||
def test_missing_terms_handled(self, triple_tx):
|
||||
data = {}
|
||||
triple = triple_tx.to_pulsar(data)
|
||||
triple = triple_tx.decode(data)
|
||||
assert triple.s is None
|
||||
assert triple.p is None
|
||||
assert triple.o is None
|
||||
|
|
@ -342,16 +342,16 @@ class TestSubgraphTranslator:
|
|||
g="urn:graph:source",
|
||||
),
|
||||
]
|
||||
wire_list = tx.from_pulsar(triples)
|
||||
wire_list = tx.encode(triples)
|
||||
assert len(wire_list) == 2
|
||||
assert wire_list[1]["g"] == "urn:graph:source"
|
||||
|
||||
restored = tx.to_pulsar(wire_list)
|
||||
restored = tx.decode(wire_list)
|
||||
assert len(restored) == 2
|
||||
assert restored[0] == triples[0]
|
||||
assert restored[1] == triples[1]
|
||||
|
||||
def test_empty_subgraph(self):
|
||||
tx = SubgraphTranslator()
|
||||
assert tx.to_pulsar([]) == []
|
||||
assert tx.from_pulsar([]) == []
|
||||
assert tx.decode([]) == []
|
||||
assert tx.encode([]) == []
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ class TestDocumentMetadataTranslator:
|
|||
"parent-id": "doc-100",
|
||||
"document-type": "page",
|
||||
}
|
||||
obj = self.tx.to_pulsar(data)
|
||||
obj = self.tx.decode(data)
|
||||
assert obj.id == "doc-123"
|
||||
assert obj.time == 1710000000
|
||||
assert obj.kind == "application/pdf"
|
||||
|
|
@ -45,14 +45,14 @@ class TestDocumentMetadataTranslator:
|
|||
assert obj.parent_id == "doc-100"
|
||||
assert obj.document_type == "page"
|
||||
|
||||
wire = self.tx.from_pulsar(obj)
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["id"] == "doc-123"
|
||||
assert wire["user"] == "alice"
|
||||
assert wire["parent-id"] == "doc-100"
|
||||
assert wire["document-type"] == "page"
|
||||
|
||||
def test_defaults_for_missing_fields(self):
|
||||
obj = self.tx.to_pulsar({})
|
||||
obj = self.tx.decode({})
|
||||
assert obj.parent_id == ""
|
||||
assert obj.document_type == "source"
|
||||
|
||||
|
|
@ -63,25 +63,25 @@ class TestDocumentMetadataTranslator:
|
|||
"o": {"t": "i", "i": "http://example.org/o"},
|
||||
}]
|
||||
data = {"metadata": triple_wire}
|
||||
obj = self.tx.to_pulsar(data)
|
||||
obj = self.tx.decode(data)
|
||||
assert len(obj.metadata) == 1
|
||||
assert obj.metadata[0].s.iri == "http://example.org/s"
|
||||
|
||||
def test_none_metadata_handled(self):
|
||||
data = {"metadata": None}
|
||||
obj = self.tx.to_pulsar(data)
|
||||
obj = self.tx.decode(data)
|
||||
assert obj.metadata == []
|
||||
|
||||
def test_empty_tags_preserved(self):
|
||||
data = {"tags": []}
|
||||
obj = self.tx.to_pulsar(data)
|
||||
wire = self.tx.from_pulsar(obj)
|
||||
obj = self.tx.decode(data)
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["tags"] == []
|
||||
|
||||
def test_falsy_fields_omitted_from_wire(self):
|
||||
"""Empty string fields should be omitted from wire format."""
|
||||
obj = DocumentMetadata(id="", time=0, user="")
|
||||
wire = self.tx.from_pulsar(obj)
|
||||
wire = self.tx.encode(obj)
|
||||
assert "id" not in wire
|
||||
assert "user" not in wire
|
||||
|
||||
|
|
@ -105,7 +105,7 @@ class TestProcessingMetadataTranslator:
|
|||
"collection": "my-collection",
|
||||
"tags": ["tag1"],
|
||||
}
|
||||
obj = self.tx.to_pulsar(data)
|
||||
obj = self.tx.decode(data)
|
||||
assert obj.id == "proc-1"
|
||||
assert obj.document_id == "doc-123"
|
||||
assert obj.flow == "default"
|
||||
|
|
@ -113,32 +113,32 @@ class TestProcessingMetadataTranslator:
|
|||
assert obj.collection == "my-collection"
|
||||
assert obj.tags == ["tag1"]
|
||||
|
||||
wire = self.tx.from_pulsar(obj)
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["id"] == "proc-1"
|
||||
assert wire["document-id"] == "doc-123"
|
||||
assert wire["user"] == "alice"
|
||||
assert wire["collection"] == "my-collection"
|
||||
|
||||
def test_missing_fields_use_defaults(self):
|
||||
obj = self.tx.to_pulsar({})
|
||||
obj = self.tx.decode({})
|
||||
assert obj.id is None
|
||||
assert obj.user is None
|
||||
assert obj.collection is None
|
||||
|
||||
def test_tags_none_omitted(self):
|
||||
obj = ProcessingMetadata(tags=None)
|
||||
wire = self.tx.from_pulsar(obj)
|
||||
wire = self.tx.encode(obj)
|
||||
assert "tags" not in wire
|
||||
|
||||
def test_tags_empty_list_preserved(self):
|
||||
obj = ProcessingMetadata(tags=[])
|
||||
wire = self.tx.from_pulsar(obj)
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["tags"] == []
|
||||
|
||||
def test_user_and_collection_preserved(self):
|
||||
"""Core pipeline routing fields must survive round-trip."""
|
||||
data = {"user": "bob", "collection": "research"}
|
||||
obj = self.tx.to_pulsar(data)
|
||||
wire = self.tx.from_pulsar(obj)
|
||||
obj = self.tx.decode(data)
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["user"] == "bob"
|
||||
assert wire["collection"] == "research"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,380 @@
|
|||
"""
|
||||
Integration test: run a full DocumentRag.query() with mocked subsidiary
|
||||
clients and verify the explain_callback receives the complete provenance
|
||||
chain in the correct order with correct structure.
|
||||
|
||||
Document-RAG provenance chain (4 stages):
|
||||
question → grounding → exploration → synthesis
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_SYNTHESIS, TG_ANSWER_TYPE,
|
||||
TG_QUERY, TG_CONCEPT,
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_triple(triples, predicate, subject=None):
|
||||
for t in triples:
|
||||
if t.p.iri == predicate:
|
||||
if subject is None or t.s.iri == subject:
|
||||
return t
|
||||
return None
|
||||
|
||||
|
||||
def find_triples(triples, predicate, subject=None):
|
||||
return [
|
||||
t for t in triples
|
||||
if t.p.iri == predicate
|
||||
and (subject is None or t.s.iri == subject)
|
||||
]
|
||||
|
||||
|
||||
def has_type(triples, subject, rdf_type):
|
||||
return any(
|
||||
t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
|
||||
for t in triples
|
||||
)
|
||||
|
||||
|
||||
def derived_from(triples, subject):
|
||||
t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
|
||||
return t.o.iri if t else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkMatch:
|
||||
"""Mimics the result from doc_embeddings_client.query()."""
|
||||
chunk_id: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CHUNK_A = "urn:chunk:policy-doc-1:chunk-0"
|
||||
CHUNK_B = "urn:chunk:policy-doc-1:chunk-1"
|
||||
CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase."
|
||||
CHUNK_B_CONTENT = "Refunds are processed to the original payment method."
|
||||
|
||||
|
||||
def build_mock_clients():
|
||||
"""
|
||||
Build mock clients for a document-rag query.
|
||||
|
||||
Client call sequence during query():
|
||||
1. prompt_client.prompt("extract-concepts", ...) -> concepts
|
||||
2. embeddings_client.embed(concepts) -> vectors
|
||||
3. doc_embeddings_client.query(vector, ...) -> chunk matches
|
||||
4. fetch_chunk(chunk_id, user) -> chunk content
|
||||
5. prompt_client.document_prompt(query, documents) -> answer
|
||||
"""
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
doc_embeddings_client = AsyncMock()
|
||||
fetch_chunk = AsyncMock()
|
||||
|
||||
# 1. Concept extraction
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return "return policy\nrefund"
|
||||
return ""
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
# 2. Embedding vectors
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
|
||||
|
||||
# 3. Chunk matching
|
||||
doc_embeddings_client.query.return_value = [
|
||||
ChunkMatch(chunk_id=CHUNK_A),
|
||||
ChunkMatch(chunk_id=CHUNK_B),
|
||||
]
|
||||
|
||||
# 4. Chunk content
|
||||
async def mock_fetch(chunk_id, user):
|
||||
return {
|
||||
CHUNK_A: CHUNK_A_CONTENT,
|
||||
CHUNK_B: CHUNK_B_CONTENT,
|
||||
}[chunk_id]
|
||||
|
||||
fetch_chunk.side_effect = mock_fetch
|
||||
|
||||
# 5. Synthesis
|
||||
prompt_client.document_prompt.return_value = (
|
||||
"Items can be returned within 30 days for a full refund."
|
||||
)
|
||||
|
||||
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDocumentRagQueryProvenance:
|
||||
"""
|
||||
Run a real DocumentRag.query() and verify the provenance chain emitted
|
||||
via explain_callback.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_callback_receives_four_events(self):
|
||||
"""query() should emit exactly 4 explain events."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
assert len(events) == 4, (
|
||||
f"Expected 4 explain events (question, grounding, exploration, "
|
||||
f"synthesis), got {len(events)}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_events_have_correct_types_in_order(self):
|
||||
"""
|
||||
Events should arrive as:
|
||||
question, grounding, exploration, synthesis.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
expected_types = [
|
||||
TG_DOC_RAG_QUESTION,
|
||||
TG_GROUNDING,
|
||||
TG_EXPLORATION,
|
||||
TG_SYNTHESIS,
|
||||
]
|
||||
|
||||
for i, expected_type in enumerate(expected_types):
|
||||
uri = events[i]["explain_id"]
|
||||
triples = events[i]["triples"]
|
||||
assert has_type(triples, uri, expected_type), (
|
||||
f"Event {i} (uri={uri}) should have type {expected_type}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_derivation_chain_links_correctly(self):
|
||||
"""
|
||||
Each event's URI should link to the previous via wasDerivedFrom:
|
||||
question → (none)
|
||||
grounding → question
|
||||
exploration → grounding
|
||||
synthesis → exploration
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
uris = [e["explain_id"] for e in events]
|
||||
all_triples = []
|
||||
for e in events:
|
||||
all_triples.extend(e["triples"])
|
||||
|
||||
# question has no parent
|
||||
assert derived_from(all_triples, uris[0]) is None
|
||||
|
||||
# grounding → question
|
||||
assert derived_from(all_triples, uris[1]) == uris[0]
|
||||
|
||||
# exploration → grounding
|
||||
assert derived_from(all_triples, uris[2]) == uris[1]
|
||||
|
||||
# synthesis → exploration
|
||||
assert derived_from(all_triples, uris[3]) == uris[2]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_question_carries_query_text(self):
|
||||
"""The question event should contain the original query string."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
q_uri = events[0]["explain_id"]
|
||||
q_triples = events[0]["triples"]
|
||||
t = find_triple(q_triples, TG_QUERY, q_uri)
|
||||
assert t is not None
|
||||
assert t.o.value == "What is the return policy?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grounding_carries_concepts(self):
|
||||
"""The grounding event should list extracted concepts."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
gnd_uri = events[1]["explain_id"]
|
||||
gnd_triples = events[1]["triples"]
|
||||
concepts = find_triples(gnd_triples, TG_CONCEPT, gnd_uri)
|
||||
concept_values = {t.o.value for t in concepts}
|
||||
assert "return policy" in concept_values
|
||||
assert "refund" in concept_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exploration_has_chunk_count(self):
|
||||
"""The exploration event should report the number of chunks retrieved."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
exp_uri = events[2]["explain_id"]
|
||||
exp_triples = events[2]["triples"]
|
||||
t = find_triple(exp_triples, TG_CHUNK_COUNT, exp_uri)
|
||||
assert t is not None
|
||||
assert int(t.o.value) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exploration_has_selected_chunks(self):
|
||||
"""The exploration event should list the chunk IDs that were fetched."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
exp_uri = events[2]["explain_id"]
|
||||
exp_triples = events[2]["triples"]
|
||||
chunks = find_triples(exp_triples, TG_SELECTED_CHUNK, exp_uri)
|
||||
chunk_iris = {t.o.iri for t in chunks}
|
||||
assert CHUNK_A in chunk_iris
|
||||
assert CHUNK_B in chunk_iris
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_is_answer_type(self):
|
||||
"""The synthesis event should have tg:Synthesis and tg:Answer types."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
syn_uri = events[3]["explain_id"]
|
||||
syn_triples = events[3]["triples"]
|
||||
assert has_type(syn_triples, syn_uri, TG_SYNTHESIS)
|
||||
assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_returns_answer_text(self):
|
||||
"""query() should return the synthesised answer."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
result = await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=AsyncMock(),
|
||||
)
|
||||
|
||||
assert result == "Items can be returned within 30 days for a full refund."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_explain_callback_still_works(self):
|
||||
"""query() without explain_callback should return answer normally."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
result = await rag.query(query="What is the return policy?")
|
||||
assert result == "Items can be returned within 30 days for a full refund."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_in_retrieval_graph(self):
|
||||
"""All emitted triples should be in the urn:graph:retrieval graph."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
for event in events:
|
||||
for t in event["triples"]:
|
||||
assert t.g == "urn:graph:retrieval", (
|
||||
f"Triple {t.s.iri} {t.p.iri} should be in "
|
||||
f"urn:graph:retrieval, got {t.g}"
|
||||
)
|
||||
|
|
@ -465,12 +465,15 @@ class TestQuery:
|
|||
return_value=(["entity1", "entity2"], ["concept1"])
|
||||
)
|
||||
|
||||
query.follow_edges_batch = AsyncMock(return_value={
|
||||
("entity1", "predicate1", "object1"),
|
||||
("entity2", "predicate2", "object2")
|
||||
})
|
||||
query.follow_edges_batch = AsyncMock(return_value=(
|
||||
{
|
||||
("entity1", "predicate1", "object1"),
|
||||
("entity2", "predicate2", "object2")
|
||||
},
|
||||
{}
|
||||
))
|
||||
|
||||
subgraph, entities, concepts = await query.get_subgraph("test query")
|
||||
subgraph, term_map, entities, concepts = await query.get_subgraph("test query")
|
||||
|
||||
query.get_entities.assert_called_once_with("test query")
|
||||
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
|
||||
|
|
@ -503,7 +506,7 @@ class TestQuery:
|
|||
test_entities = ["entity1", "entity3"]
|
||||
test_concepts = ["concept1"]
|
||||
query.get_subgraph = AsyncMock(
|
||||
return_value=(test_subgraph, test_entities, test_concepts)
|
||||
return_value=(test_subgraph, {}, test_entities, test_concepts)
|
||||
)
|
||||
|
||||
async def mock_maybe_label(entity):
|
||||
|
|
|
|||
358
tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py
Normal file
358
tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py
Normal file
|
|
@ -0,0 +1,358 @@
|
|||
"""
|
||||
Tests that explain_triples are forwarded correctly through the graph-rag
|
||||
service and client layers.
|
||||
|
||||
Covers:
|
||||
- Service: explain messages include triples from the provenance callback
|
||||
- Client: explain_callback receives explain_triples from the response
|
||||
- End-to-end: triples survive the full service → client → callback chain
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.schema import (
|
||||
GraphRagQuery, GraphRagResponse,
|
||||
Triple, Term, IRI, LITERAL,
|
||||
)
|
||||
from trustgraph.base.graph_rag_client import GraphRagClient
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_triple(s_iri, p_iri, o_value, o_type=IRI):
|
||||
"""Create a Triple with IRI subject/predicate and typed object."""
|
||||
o = (
|
||||
Term(type=IRI, iri=o_value) if o_type == IRI
|
||||
else Term(type=LITERAL, value=o_value)
|
||||
)
|
||||
return Triple(
|
||||
s=Term(type=IRI, iri=s_iri),
|
||||
p=Term(type=IRI, iri=p_iri),
|
||||
o=o,
|
||||
)
|
||||
|
||||
|
||||
def sample_focus_triples():
|
||||
"""Focus-style triples with a quoted triple (edge selection)."""
|
||||
return [
|
||||
make_triple(
|
||||
"urn:trustgraph:focus:abc",
|
||||
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
|
||||
"https://trustgraph.ai/ns/Focus",
|
||||
),
|
||||
make_triple(
|
||||
"urn:trustgraph:focus:abc",
|
||||
"http://www.w3.org/ns/prov#wasDerivedFrom",
|
||||
"urn:trustgraph:exploration:abc",
|
||||
),
|
||||
make_triple(
|
||||
"urn:trustgraph:focus:abc",
|
||||
"https://trustgraph.ai/ns/selectedEdge",
|
||||
"urn:trustgraph:edge-sel:abc:0",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def sample_question_triples():
|
||||
"""Question-style triples."""
|
||||
return [
|
||||
make_triple(
|
||||
"urn:trustgraph:question:abc",
|
||||
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
|
||||
"https://trustgraph.ai/ns/GraphRagQuestion",
|
||||
),
|
||||
make_triple(
|
||||
"urn:trustgraph:question:abc",
|
||||
"https://trustgraph.ai/ns/query",
|
||||
"What is quantum computing?",
|
||||
o_type=LITERAL,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service-level: explain messages carry triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagServiceExplainTriples:
|
||||
"""Test that the graph-rag service includes explain_triples in messages."""
|
||||
|
||||
@patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_messages_include_triples(self, mock_graph_rag_class):
|
||||
"""
|
||||
When the provenance callback is invoked with triples, the service
|
||||
should include them in the explain response message.
|
||||
"""
|
||||
from trustgraph.retrieval.graph_rag.rag import Processor
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id="test-processor",
|
||||
entity_limit=50,
|
||||
triple_limit=30,
|
||||
max_subgraph_size=150,
|
||||
max_path_length=2,
|
||||
)
|
||||
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_graph_rag_class.return_value = mock_rag_instance
|
||||
|
||||
question_triples = sample_question_triples()
|
||||
focus_triples = sample_focus_triples()
|
||||
|
||||
async def mock_query(**kwargs):
|
||||
explain_callback = kwargs.get('explain_callback')
|
||||
if explain_callback:
|
||||
await explain_callback(
|
||||
question_triples, "urn:trustgraph:question:abc"
|
||||
)
|
||||
await explain_callback(
|
||||
focus_triples, "urn:trustgraph:focus:abc"
|
||||
)
|
||||
return "The answer."
|
||||
|
||||
mock_rag_instance.query.side_effect = mock_query
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = GraphRagQuery(
|
||||
query="What is quantum computing?",
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
mock_response = AsyncMock()
|
||||
mock_provenance = AsyncMock()
|
||||
|
||||
def flow_router(name):
|
||||
if name == "response":
|
||||
return mock_response
|
||||
if name == "explainability":
|
||||
return mock_provenance
|
||||
return AsyncMock()
|
||||
|
||||
flow.side_effect = flow_router
|
||||
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Find the explain messages
|
||||
explain_msgs = [
|
||||
call[0][0]
|
||||
for call in mock_response.send.call_args_list
|
||||
if call[0][0].message_type == "explain"
|
||||
]
|
||||
|
||||
assert len(explain_msgs) == 2
|
||||
|
||||
# First explain message should carry question triples
|
||||
assert explain_msgs[0].explain_id == "urn:trustgraph:question:abc"
|
||||
assert explain_msgs[0].explain_triples == question_triples
|
||||
|
||||
# Second explain message should carry focus triples
|
||||
assert explain_msgs[1].explain_id == "urn:trustgraph:focus:abc"
|
||||
assert explain_msgs[1].explain_triples == focus_triples
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Client-level: explain_callback receives triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagClientExplainForwarding:
|
||||
"""Test that GraphRagClient.rag() forwards explain_triples to callback."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_callback_receives_triples(self):
|
||||
"""
|
||||
The explain_callback should receive (explain_id, explain_graph,
|
||||
explain_triples) — not just (explain_id, explain_graph).
|
||||
"""
|
||||
focus_triples = sample_focus_triples()
|
||||
|
||||
# Simulate the response sequence the client would receive
|
||||
responses = [
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:focus:abc",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=focus_triples,
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="The answer.",
|
||||
end_of_stream=True,
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="",
|
||||
end_of_session=True,
|
||||
),
|
||||
]
|
||||
|
||||
# Capture what the explain_callback receives
|
||||
received_calls = []
|
||||
|
||||
async def explain_callback(explain_id, explain_graph, explain_triples):
|
||||
received_calls.append({
|
||||
"explain_id": explain_id,
|
||||
"explain_graph": explain_graph,
|
||||
"explain_triples": explain_triples,
|
||||
})
|
||||
|
||||
# Patch self.request to feed responses to the recipient
|
||||
client = GraphRagClient.__new__(GraphRagClient)
|
||||
|
||||
async def mock_request(req, timeout=600, recipient=None):
|
||||
for resp in responses:
|
||||
done = await recipient(resp)
|
||||
if done:
|
||||
return resp
|
||||
|
||||
client.request = mock_request
|
||||
|
||||
result = await client.rag(
|
||||
query="test",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
assert result == "The answer."
|
||||
assert len(received_calls) == 1
|
||||
assert received_calls[0]["explain_id"] == "urn:trustgraph:focus:abc"
|
||||
assert received_calls[0]["explain_graph"] == "urn:graph:retrieval"
|
||||
assert received_calls[0]["explain_triples"] == focus_triples
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_callback_receives_empty_triples(self):
|
||||
"""
|
||||
When an explain event has no triples, the callback should still
|
||||
receive an empty list (not None or missing).
|
||||
"""
|
||||
responses = [
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=[],
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="Answer.",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
),
|
||||
]
|
||||
|
||||
received_calls = []
|
||||
|
||||
async def explain_callback(explain_id, explain_graph, explain_triples):
|
||||
received_calls.append(explain_triples)
|
||||
|
||||
client = GraphRagClient.__new__(GraphRagClient)
|
||||
|
||||
async def mock_request(req, timeout=600, recipient=None):
|
||||
for resp in responses:
|
||||
done = await recipient(resp)
|
||||
if done:
|
||||
return resp
|
||||
|
||||
client.request = mock_request
|
||||
|
||||
await client.rag(query="test", explain_callback=explain_callback)
|
||||
|
||||
assert len(received_calls) == 1
|
||||
assert received_calls[0] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_explain_events_all_forward_triples(self):
|
||||
"""
|
||||
Each explain event in a session should forward its own triples.
|
||||
"""
|
||||
q_triples = sample_question_triples()
|
||||
f_triples = sample_focus_triples()
|
||||
|
||||
responses = [
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=q_triples,
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:focus:abc",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=f_triples,
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="Answer.",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
),
|
||||
]
|
||||
|
||||
received_calls = []
|
||||
|
||||
async def explain_callback(explain_id, explain_graph, explain_triples):
|
||||
received_calls.append({
|
||||
"explain_id": explain_id,
|
||||
"explain_triples": explain_triples,
|
||||
})
|
||||
|
||||
client = GraphRagClient.__new__(GraphRagClient)
|
||||
|
||||
async def mock_request(req, timeout=600, recipient=None):
|
||||
for resp in responses:
|
||||
done = await recipient(resp)
|
||||
if done:
|
||||
return resp
|
||||
|
||||
client.request = mock_request
|
||||
|
||||
await client.rag(query="test", explain_callback=explain_callback)
|
||||
|
||||
assert len(received_calls) == 2
|
||||
assert received_calls[0]["explain_id"] == "urn:trustgraph:question:abc"
|
||||
assert received_calls[0]["explain_triples"] == q_triples
|
||||
assert received_calls[1]["explain_id"] == "urn:trustgraph:focus:abc"
|
||||
assert received_calls[1]["explain_triples"] == f_triples
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_explain_callback_does_not_error(self):
|
||||
"""
|
||||
When no explain_callback is provided, explain events should be
|
||||
silently skipped without errors.
|
||||
"""
|
||||
responses = [
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=sample_question_triples(),
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="Answer.",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
),
|
||||
]
|
||||
|
||||
client = GraphRagClient.__new__(GraphRagClient)
|
||||
|
||||
async def mock_request(req, timeout=600, recipient=None):
|
||||
for resp in responses:
|
||||
done = await recipient(resp)
|
||||
if done:
|
||||
return resp
|
||||
|
||||
client.request = mock_request
|
||||
|
||||
result = await client.rag(query="test")
|
||||
assert result == "Answer."
|
||||
|
|
@ -0,0 +1,482 @@
|
|||
"""
|
||||
Integration test: run a full GraphRag.query() with mocked subsidiary clients
|
||||
and verify the explain_callback receives the complete provenance chain
|
||||
in the correct order with correct structure.
|
||||
|
||||
This tests the real query() method end-to-end, not just the triple builders.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
|
||||
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE,
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT,
|
||||
TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_triple(triples, predicate, subject=None):
|
||||
for t in triples:
|
||||
if t.p.iri == predicate:
|
||||
if subject is None or t.s.iri == subject:
|
||||
return t
|
||||
return None
|
||||
|
||||
|
||||
def find_triples(triples, predicate, subject=None):
|
||||
return [
|
||||
t for t in triples
|
||||
if t.p.iri == predicate
|
||||
and (subject is None or t.s.iri == subject)
|
||||
]
|
||||
|
||||
|
||||
def has_type(triples, subject, rdf_type):
|
||||
return any(
|
||||
t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
|
||||
for t in triples
|
||||
)
|
||||
|
||||
|
||||
def derived_from(triples, subject):
|
||||
t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
|
||||
return t.o.iri if t else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingMatch:
|
||||
"""Mimics the result from graph_embeddings_client.query()."""
|
||||
entity: Term
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# A tiny knowledge graph: 2 entities, 3 edges
|
||||
ENTITY_A = "http://example.com/QuantumComputing"
|
||||
ENTITY_B = "http://example.com/Physics"
|
||||
EDGE_1 = (ENTITY_A, "http://schema.org/relatedTo", ENTITY_B)
|
||||
EDGE_2 = (ENTITY_A, "http://schema.org/name", "Quantum Computing")
|
||||
EDGE_3 = (ENTITY_B, "http://schema.org/name", "Physics")
|
||||
|
||||
|
||||
def make_schema_triple(s, p, o):
|
||||
"""Create a SchemaTriple from string values."""
|
||||
return SchemaTriple(
|
||||
s=Term(type=IRI, iri=s),
|
||||
p=Term(type=IRI, iri=p),
|
||||
o=Term(type=IRI, iri=o) if o.startswith("http") else Term(type=LITERAL, value=o),
|
||||
)
|
||||
|
||||
|
||||
def build_mock_clients():
|
||||
"""
|
||||
Build mock clients that simulate a small knowledge graph query.
|
||||
|
||||
Client call sequence during query():
|
||||
1. prompt_client.prompt("extract-concepts", ...) -> concepts
|
||||
2. embeddings_client.embed(concepts) -> vectors
|
||||
3. graph_embeddings_client.query(vector, ...) -> entity matches
|
||||
4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch)
|
||||
5. triples_client.query(s, LABEL, ...) -> labels (maybe_label)
|
||||
6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges
|
||||
7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning
|
||||
8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
|
||||
9. prompt_client.prompt("kg-synthesis", ...) -> final answer
|
||||
"""
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
graph_embeddings_client = AsyncMock()
|
||||
triples_client = AsyncMock()
|
||||
|
||||
# 1. Concept extraction
|
||||
prompt_responses = {}
|
||||
prompt_responses["extract-concepts"] = "quantum computing\nphysics"
|
||||
|
||||
# 2. Embedding vectors (simple fake vectors)
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
|
||||
|
||||
# 3. Entity lookup - return our two entities
|
||||
graph_embeddings_client.query.return_value = [
|
||||
EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_A)),
|
||||
EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)),
|
||||
]
|
||||
|
||||
# 4. Triple queries (follow_edges_batch) - return our edges
|
||||
kg_triples = [
|
||||
make_schema_triple(*EDGE_1),
|
||||
make_schema_triple(*EDGE_2),
|
||||
make_schema_triple(*EDGE_3),
|
||||
]
|
||||
triples_client.query_stream.return_value = kg_triples
|
||||
|
||||
# 5. Label resolution - return entity as its own label (simplify)
|
||||
async def mock_label_query(s=None, p=None, o=None, limit=1,
|
||||
user=None, collection=None, g=None):
|
||||
return [] # No labels found, will fall back to URI
|
||||
triples_client.query.side_effect = mock_label_query
|
||||
|
||||
# 6+7. Edge scoring and reasoning: dynamically score/reason about
|
||||
# whatever edges the query method sends us, since edge IDs are computed
|
||||
# from str(Term) representations which include the full dataclass repr.
|
||||
synthesis_answer = "Quantum computing applies physics principles to computation."
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return prompt_responses["extract-concepts"]
|
||||
elif template_id == "kg-edge-scoring":
|
||||
# Score all edges highly, using the IDs that GraphRag computed
|
||||
edges = variables.get("knowledge", [])
|
||||
return [
|
||||
{"id": e["id"], "score": 10 - i}
|
||||
for i, e in enumerate(edges)
|
||||
]
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
# Provide reasoning for each edge
|
||||
edges = variables.get("knowledge", [])
|
||||
return [
|
||||
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
|
||||
for i, e in enumerate(edges)
|
||||
]
|
||||
elif template_id == "kg-synthesis":
|
||||
return synthesis_answer
|
||||
return ""
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagQueryProvenance:
|
||||
"""
|
||||
Run a real GraphRag.query() and verify the provenance chain emitted
|
||||
via explain_callback.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_callback_receives_five_events(self):
|
||||
"""query() should emit exactly 5 explain events."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0, # skip semantic pre-filter for simplicity
|
||||
)
|
||||
|
||||
assert len(events) == 5, (
|
||||
f"Expected 5 explain events (question, grounding, exploration, "
|
||||
f"focus, synthesis), got {len(events)}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_events_have_correct_types_in_order(self):
|
||||
"""
|
||||
Events should arrive as:
|
||||
question, grounding, exploration, focus, synthesis.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
expected_types = [
|
||||
TG_GRAPH_RAG_QUESTION,
|
||||
TG_GROUNDING,
|
||||
TG_EXPLORATION,
|
||||
TG_FOCUS,
|
||||
TG_SYNTHESIS,
|
||||
]
|
||||
|
||||
for i, expected_type in enumerate(expected_types):
|
||||
uri = events[i]["explain_id"]
|
||||
triples = events[i]["triples"]
|
||||
assert has_type(triples, uri, expected_type), (
|
||||
f"Event {i} (uri={uri}) should have type {expected_type}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_derivation_chain_links_correctly(self):
|
||||
"""
|
||||
Each event's URI should link to the previous via wasDerivedFrom:
|
||||
grounding → question → (none)
|
||||
exploration → grounding
|
||||
focus → exploration
|
||||
synthesis → focus
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
uris = [e["explain_id"] for e in events]
|
||||
all_triples = []
|
||||
for e in events:
|
||||
all_triples.extend(e["triples"])
|
||||
|
||||
# question has no parent
|
||||
assert derived_from(all_triples, uris[0]) is None
|
||||
|
||||
# grounding → question
|
||||
assert derived_from(all_triples, uris[1]) == uris[0]
|
||||
|
||||
# exploration → grounding
|
||||
assert derived_from(all_triples, uris[2]) == uris[1]
|
||||
|
||||
# focus → exploration
|
||||
assert derived_from(all_triples, uris[3]) == uris[2]
|
||||
|
||||
# synthesis → focus
|
||||
assert derived_from(all_triples, uris[4]) == uris[3]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_question_event_carries_query_text(self):
|
||||
"""The question event should contain the original query string."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
q_uri = events[0]["explain_id"]
|
||||
q_triples = events[0]["triples"]
|
||||
t = find_triple(q_triples, TG_QUERY, q_uri)
|
||||
assert t is not None
|
||||
assert t.o.value == "What is quantum computing?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grounding_carries_concepts(self):
|
||||
"""The grounding event should list extracted concepts."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
gnd_uri = events[1]["explain_id"]
|
||||
gnd_triples = events[1]["triples"]
|
||||
concepts = find_triples(gnd_triples, TG_CONCEPT, gnd_uri)
|
||||
concept_values = {t.o.value for t in concepts}
|
||||
assert "quantum computing" in concept_values
|
||||
assert "physics" in concept_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exploration_has_edge_count(self):
|
||||
"""The exploration event should report how many edges were found."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
exp_uri = events[2]["explain_id"]
|
||||
exp_triples = events[2]["triples"]
|
||||
t = find_triple(exp_triples, TG_EDGE_COUNT, exp_uri)
|
||||
assert t is not None
|
||||
# Should be non-zero (we provided 3 edges, label edges filtered)
|
||||
assert int(t.o.value) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_focus_has_selected_edges_with_reasoning(self):
|
||||
"""
|
||||
The focus event should carry selected edges as quoted triples
|
||||
with reasoning text.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
foc_uri = events[3]["explain_id"]
|
||||
foc_triples = events[3]["triples"]
|
||||
|
||||
# Should have selected edges
|
||||
selected = find_triples(foc_triples, TG_SELECTED_EDGE, foc_uri)
|
||||
assert len(selected) > 0, "Focus should have at least one selected edge"
|
||||
|
||||
# Each edge selection should have a quoted triple
|
||||
edge_t = find_triples(foc_triples, TG_EDGE)
|
||||
assert len(edge_t) > 0, "Focus should have tg:edge with quoted triples"
|
||||
for t in edge_t:
|
||||
assert t.o.triple is not None, "tg:edge object must be a quoted triple"
|
||||
|
||||
# Should have reasoning
|
||||
reasoning = find_triples(foc_triples, TG_REASONING)
|
||||
assert len(reasoning) > 0, "Focus should have reasoning for selected edges"
|
||||
reasoning_texts = {t.o.value for t in reasoning}
|
||||
assert any(r for r in reasoning_texts), "Reasoning should not be empty"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_is_answer_type(self):
|
||||
"""The synthesis event should have tg:Answer type."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
syn_uri = events[4]["explain_id"]
|
||||
syn_triples = events[4]["triples"]
|
||||
assert has_type(syn_triples, syn_uri, TG_SYNTHESIS)
|
||||
assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_returns_answer_text(self):
|
||||
"""query() should still return the synthesised answer."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
result = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
assert result == "Quantum computing applies physics principles to computation."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parent_uri_links_question_to_parent(self):
|
||||
"""When parent_uri is provided, question should derive from it."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
parent = "urn:trustgraph:agent:iteration:xyz"
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
parent_uri=parent,
|
||||
)
|
||||
|
||||
q_uri = events[0]["explain_id"]
|
||||
q_triples = events[0]["triples"]
|
||||
assert derived_from(q_triples, q_uri) == parent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_explain_callback_still_works(self):
|
||||
"""query() without explain_callback should return answer normally."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
result = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
assert result == "Quantum computing applies physics principles to computation."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_in_retrieval_graph(self):
|
||||
"""All emitted triples should be in the urn:graph:retrieval graph."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
for event in events:
|
||||
for t in event["triples"]:
|
||||
assert t.g == "urn:graph:retrieval", (
|
||||
f"Triple {t.s.iri} {t.p.iri} should be in "
|
||||
f"urn:graph:retrieval, got {t.g}"
|
||||
)
|
||||
|
|
@ -28,7 +28,7 @@ class TestRequestTranslation:
|
|||
}
|
||||
|
||||
# Translate to Pulsar
|
||||
pulsar_msg = translator.to_pulsar(api_data)
|
||||
pulsar_msg = translator.decode(api_data)
|
||||
|
||||
assert pulsar_msg.operation == "schema-selection"
|
||||
assert pulsar_msg.sample == "test data sample"
|
||||
|
|
@ -46,7 +46,7 @@ class TestRequestTranslation:
|
|||
"options": {"delimiter": ","}
|
||||
}
|
||||
|
||||
pulsar_msg = translator.to_pulsar(api_data)
|
||||
pulsar_msg = translator.decode(api_data)
|
||||
|
||||
assert pulsar_msg.operation == "generate-descriptor"
|
||||
assert pulsar_msg.sample == "csv data"
|
||||
|
|
@ -70,7 +70,7 @@ class TestResponseTranslation:
|
|||
)
|
||||
|
||||
# Translate to API format
|
||||
api_data = translator.from_pulsar(pulsar_response)
|
||||
api_data = translator.encode(pulsar_response)
|
||||
|
||||
assert api_data["operation"] == "schema-selection"
|
||||
assert api_data["schema-matches"] == ["products", "inventory", "catalog"]
|
||||
|
|
@ -86,7 +86,7 @@ class TestResponseTranslation:
|
|||
error=None
|
||||
)
|
||||
|
||||
api_data = translator.from_pulsar(pulsar_response)
|
||||
api_data = translator.encode(pulsar_response)
|
||||
|
||||
assert api_data["operation"] == "schema-selection"
|
||||
assert api_data["schema-matches"] == []
|
||||
|
|
@ -103,7 +103,7 @@ class TestResponseTranslation:
|
|||
error=None
|
||||
)
|
||||
|
||||
api_data = translator.from_pulsar(pulsar_response)
|
||||
api_data = translator.encode(pulsar_response)
|
||||
|
||||
assert api_data["operation"] == "detect-type"
|
||||
assert api_data["detected-type"] == "xml"
|
||||
|
|
@ -123,7 +123,7 @@ class TestResponseTranslation:
|
|||
)
|
||||
)
|
||||
|
||||
api_data = translator.from_pulsar(pulsar_response)
|
||||
api_data = translator.encode(pulsar_response)
|
||||
|
||||
assert api_data["operation"] == "schema-selection"
|
||||
# Error objects are typically handled separately by the gateway
|
||||
|
|
@ -146,7 +146,7 @@ class TestResponseTranslation:
|
|||
error=None
|
||||
)
|
||||
|
||||
api_data = translator.from_pulsar(pulsar_response)
|
||||
api_data = translator.encode(pulsar_response)
|
||||
|
||||
assert api_data["operation"] == "diagnose"
|
||||
assert api_data["detected-type"] == "csv"
|
||||
|
|
@ -165,7 +165,7 @@ class TestResponseTranslation:
|
|||
error=None
|
||||
)
|
||||
|
||||
api_data, is_final = translator.from_response_with_completion(pulsar_response)
|
||||
api_data, is_final = translator.encode_with_completion(pulsar_response)
|
||||
|
||||
assert is_final is True # Structured-diag responses are always final
|
||||
assert api_data["operation"] == "schema-selection"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue