trustgraph/tests/integration/test_graph_rag_integration.py
cybermaggedon 14e49d83c7
Expose LLM token usage across all service layers (#782)
Expose LLM token usage (in_token, out_token, model) across all
service layers

Propagate token counts from LLM services through the prompt,
text-completion, graph-RAG, document-RAG, and agent orchestrator
pipelines to the API gateway and Python SDK. All fields are Optional
— None means "not available", distinguishing from a real zero count.

Key changes:

- Schema: Add in_token/out_token/model to TextCompletionResponse,
  PromptResponse, GraphRagResponse, DocumentRagResponse,
  AgentResponse

- TextCompletionClient: New TextCompletionResult return type. Split
  into text_completion() (non-streaming) and
  text_completion_stream() (streaming with per-chunk handler
  callback)

- PromptClient: New PromptResult with response_type
  (text/json/jsonl), typed fields (text/object/objects), and token
  usage. All callers updated.

- RAG services: Accumulate token usage across all prompt calls
  (extract-concepts, edge-scoring, edge-reasoning,
  synthesis). Non-streaming path sends single combined response
  instead of chunk + end_of_session.

- Agent orchestrator: UsageTracker accumulates tokens across
  meta-router, pattern prompt calls, and react reasoning. Attached
  to end_of_dialog.

- Translators: Encode token fields when not None (is not None, not truthy)

- Python SDK: RAG and text-completion methods return
  TextCompletionResult (non-streaming) or RAGChunk/AgentAnswer with
  token fields (streaming)

- CLI: --show-usage flag on tg-invoke-llm, tg-invoke-prompt,
  tg-invoke-graph-rag, tg-invoke-document-rag, tg-invoke-agent
2026-04-13 14:38:34 +01:00

314 lines
13 KiB
Python

"""
Integration tests for GraphRAG retrieval system
These tests verify the end-to-end functionality of the GraphRAG system,
testing the coordination between embeddings, graph retrieval, triple querying, and prompt services.
Following the TEST_STRATEGY.md approach for integration testing.
NOTE: This is the first integration test file for GraphRAG (previously had only unit tests).
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
from trustgraph.base import PromptResult
@pytest.mark.integration
class TestGraphRagIntegration:
"""Integration tests for GraphRAG system coordination"""
@pytest.fixture
def mock_embeddings_client(self):
"""Mock embeddings client that returns realistic vector embeddings"""
client = AsyncMock()
# New batch format: [[[vectors_for_text1], ...]]
# One text input returns one vector set containing one vector
client.embed.return_value = [
[
[0.1, 0.2, 0.3, 0.4, 0.5], # Vector for text
]
]
return client
@pytest.fixture
def mock_graph_embeddings_client(self):
"""Mock graph embeddings client that returns realistic entities"""
client = AsyncMock()
client.query.return_value = [
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/artificial-intelligence"), score=0.90),
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/neural-networks"), score=0.85)
]
return client
@pytest.fixture
def mock_triples_client(self):
"""Mock triples client that returns realistic knowledge graph triples"""
client = AsyncMock()
# Mock different queries return different triples
async def query_stream_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None, batch_size=20):
# Mock label queries
if p == "http://www.w3.org/2000/01/rdf-schema#label":
if s == "http://trustgraph.ai/e/machine-learning":
return [MagicMock(s=s, p=p, o="Machine Learning")]
elif s == "http://trustgraph.ai/e/artificial-intelligence":
return [MagicMock(s=s, p=p, o="Artificial Intelligence")]
elif s == "http://trustgraph.ai/e/neural-networks":
return [MagicMock(s=s, p=p, o="Neural Networks")]
return []
# Mock relationship queries
if s == "http://trustgraph.ai/e/machine-learning":
return [
MagicMock(
s="http://trustgraph.ai/e/machine-learning",
p="http://trustgraph.ai/is_subset_of",
o="http://trustgraph.ai/e/artificial-intelligence"
),
MagicMock(
s="http://trustgraph.ai/e/machine-learning",
p="http://www.w3.org/2000/01/rdf-schema#label",
o="Machine Learning"
)
]
return []
client.query_stream.side_effect = query_stream_side_effect
# Also mock query for label lookups (maybe_label uses query, not query_stream)
client.query.side_effect = query_stream_side_effect
return client
@pytest.fixture
def mock_prompt_client(self):
"""Mock prompt client that generates realistic responses for two-step process"""
client = AsyncMock()
# Mock responses for the multi-step process:
# 1. extract-concepts extracts key concepts from the query
# 2. kg-edge-scoring scores edges for relevance
# 3. kg-edge-reasoning provides reasoning for selected edges
# 4. kg-synthesis returns the final answer
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-reasoning":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
return PromptResult(
response_type="text",
text=(
"Machine learning is a subset of artificial intelligence that enables computers "
"to learn from data without being explicitly programmed. It uses algorithms "
"and statistical models to find patterns in data."
)
)
return PromptResult(response_type="text", text="")
client.prompt.side_effect = mock_prompt
return client
@pytest.fixture
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
mock_triples_client, mock_prompt_client):
"""Create GraphRag instance with mocked dependencies"""
return GraphRag(
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
prompt_client=mock_prompt_client,
verbose=True
)
@pytest.mark.asyncio
async def test_graph_rag_end_to_end_flow(self, graph_rag, mock_embeddings_client,
mock_graph_embeddings_client, mock_triples_client,
mock_prompt_client):
"""Test complete GraphRAG pipeline from query to response with real-time provenance"""
# Arrange
query = "What is machine learning?"
user = "test_user"
collection = "ml_knowledge"
entity_limit = 50
triple_limit = 30
# Collect provenance events
provenance_events = []
async def collect_provenance(triples, prov_id):
provenance_events.append((triples, prov_id))
# Act
response = await graph_rag.query(
query=query,
user=user,
collection=collection,
entity_limit=entity_limit,
triple_limit=triple_limit,
explain_callback=collect_provenance
)
# Assert - Verify service coordination
# 1. Should compute embeddings for query (now expects list of texts)
mock_embeddings_client.embed.assert_called_once_with([query])
# 2. Should query graph embeddings to find relevant entities
mock_graph_embeddings_client.query.assert_called_once()
call_args = mock_graph_embeddings_client.query.call_args
assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
assert call_args.kwargs['limit'] == entity_limit
assert call_args.kwargs['user'] == user
assert call_args.kwargs['collection'] == collection
# 3. Should query triples to build knowledge subgraph
assert mock_triples_client.query_stream.call_count > 0
# 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis)
assert mock_prompt_client.prompt.call_count == 4
# Verify final response
response, usage = response
assert response is not None
assert isinstance(response, str)
assert "machine learning" in response.lower()
# Verify provenance was emitted in real-time (5 events: question, grounding, exploration, focus, synthesis)
assert len(provenance_events) == 5
for triples, prov_id in provenance_events:
assert isinstance(triples, list)
assert prov_id.startswith("urn:trustgraph:")
@pytest.mark.asyncio
async def test_graph_rag_with_different_limits(self, graph_rag, mock_embeddings_client,
mock_graph_embeddings_client):
"""Test GraphRAG with various entity and triple limits"""
# Arrange
query = "Explain neural networks"
test_configs = [
{"entity_limit": 10, "triple_limit": 10},
{"entity_limit": 50, "triple_limit": 30},
{"entity_limit": 100, "triple_limit": 100},
]
for config in test_configs:
# Reset mocks
mock_embeddings_client.reset_mock()
mock_graph_embeddings_client.reset_mock()
# Act
await graph_rag.query(
query=query,
user="test_user",
collection="test_collection",
entity_limit=config["entity_limit"],
triple_limit=config["triple_limit"]
)
# Assert
call_args = mock_graph_embeddings_client.query.call_args
assert call_args.kwargs['limit'] == config["entity_limit"]
@pytest.mark.asyncio
async def test_graph_rag_error_propagation(self, graph_rag, mock_embeddings_client):
"""Test that errors from underlying services are properly propagated"""
# Arrange
mock_embeddings_client.embed.side_effect = Exception("Embeddings service error")
# Act & Assert
with pytest.raises(Exception) as exc_info:
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection"
)
assert "Embeddings service error" in str(exc_info.value)
@pytest.mark.asyncio
async def test_graph_rag_with_empty_knowledge_graph(self, graph_rag, mock_graph_embeddings_client,
mock_triples_client, mock_prompt_client):
"""Test GraphRAG handles empty knowledge graph gracefully"""
# Arrange
mock_graph_embeddings_client.query.return_value = [] # No entities found
mock_triples_client.query_stream.return_value = [] # No triples found
# Collect provenance
provenance_events = []
async def collect_provenance(triples, prov_id):
provenance_events.append((triples, prov_id))
# Act
response = await graph_rag.query(
query="unknown topic",
user="test_user",
collection="test_collection",
explain_callback=collect_provenance
)
# Assert
# Should still call prompt client
assert response is not None
# Provenance should still be emitted (5 events)
assert len(provenance_events) == 5
@pytest.mark.asyncio
async def test_graph_rag_label_caching(self, graph_rag, mock_triples_client):
"""Test that label lookups are cached to reduce redundant queries"""
# Arrange
query = "What is machine learning?"
# First query
await graph_rag.query(
query=query,
user="test_user",
collection="test_collection"
)
first_call_count = mock_triples_client.query_stream.call_count
mock_triples_client.reset_mock()
# Second identical query
await graph_rag.query(
query=query,
user="test_user",
collection="test_collection"
)
second_call_count = mock_triples_client.query_stream.call_count
# Assert - Second query should make fewer triple queries due to caching
# Note: This is a weak assertion because caching behavior depends on
# implementation details, but it verifies the concept
assert second_call_count >= 0 # Should complete without errors
@pytest.mark.asyncio
async def test_graph_rag_multi_user_isolation(self, graph_rag, mock_graph_embeddings_client):
"""Test that different users/collections are properly isolated"""
# Arrange
query = "test query"
user1, collection1 = "user1", "collection1"
user2, collection2 = "user2", "collection2"
# Act
await graph_rag.query(query=query, user=user1, collection=collection1)
await graph_rag.query(query=query, user=user2, collection=collection2)
# Assert - Both users should have separate queries
assert mock_graph_embeddings_client.query.call_count == 2
# Verify first call
first_call = mock_graph_embeddings_client.query.call_args_list[0]
assert first_call.kwargs['user'] == user1
assert first_call.kwargs['collection'] == collection1
# Verify second call
second_call = mock_graph_embeddings_client.query.call_args_list[1]
assert second_call.kwargs['user'] == user2
assert second_call.kwargs['collection'] == collection2