release/v1.4 -> master (#548)

This commit is contained in:
cybermaggedon 2025-10-06 17:54:26 +01:00 committed by GitHub
parent 3ec2cd54f9
commit 2bd68ed7f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
94 changed files with 8571 additions and 1740 deletions

View file

@ -34,7 +34,9 @@ class TestGraphRag:
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.verbose is False # Default value
assert graph_rag.label_cache == {} # Empty cache initially
# Verify label_cache is an LRUCacheWithTTL instance
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
def test_graph_rag_initialization_with_verbose(self):
"""Test GraphRag initialization with verbose enabled"""
@ -59,7 +61,9 @@ class TestGraphRag:
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.verbose is True
assert graph_rag.label_cache == {} # Empty cache initially
# Verify label_cache is an LRUCacheWithTTL instance
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
class TestQuery:
@ -228,8 +232,11 @@ class TestQuery:
"""Test Query.maybe_label method with cached label"""
# Create mock GraphRag with label cache
mock_rag = MagicMock()
mock_rag.label_cache = {"entity1": "Entity One Label"}
# Create mock LRUCacheWithTTL
mock_cache = MagicMock()
mock_cache.get.return_value = "Entity One Label"
mock_rag.label_cache = mock_cache
# Initialize Query
query = Query(
rag=mock_rag,
@ -237,27 +244,32 @@ class TestQuery:
collection="test_collection",
verbose=False
)
# Call maybe_label with cached entity
result = await query.maybe_label("entity1")
# Verify cached label is returned
assert result == "Entity One Label"
# Verify cache was checked with proper key format (user:collection:entity)
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
@pytest.mark.asyncio
async def test_maybe_label_with_label_lookup(self):
"""Test Query.maybe_label method with database label lookup"""
# Create mock GraphRag with triples client
mock_rag = MagicMock()
mock_rag.label_cache = {} # Empty cache
# Create mock LRUCacheWithTTL that returns None (cache miss)
mock_cache = MagicMock()
mock_cache.get.return_value = None
mock_rag.label_cache = mock_cache
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
# Mock triple result with label
mock_triple = MagicMock()
mock_triple.o = "Human Readable Label"
mock_triples_client.query.return_value = [mock_triple]
# Initialize Query
query = Query(
rag=mock_rag,
@ -265,10 +277,10 @@ class TestQuery:
collection="test_collection",
verbose=False
)
# Call maybe_label
result = await query.maybe_label("http://example.com/entity")
# Verify triples client was called correctly
mock_triples_client.query.assert_called_once_with(
s="http://example.com/entity",
@ -278,17 +290,21 @@ class TestQuery:
user="test_user",
collection="test_collection"
)
# Verify result and cache update
# Verify result and cache update with proper key
assert result == "Human Readable Label"
assert mock_rag.label_cache["http://example.com/entity"] == "Human Readable Label"
cache_key = "test_user:test_collection:http://example.com/entity"
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
@pytest.mark.asyncio
async def test_maybe_label_with_no_label_found(self):
"""Test Query.maybe_label method when no label is found"""
# Create mock GraphRag with triples client
mock_rag = MagicMock()
mock_rag.label_cache = {} # Empty cache
# Create mock LRUCacheWithTTL that returns None (cache miss)
mock_cache = MagicMock()
mock_cache.get.return_value = None
mock_rag.label_cache = mock_cache
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
@ -318,7 +334,8 @@ class TestQuery:
# Verify result is entity itself and cache is updated
assert result == "unlabeled_entity"
assert mock_rag.label_cache["unlabeled_entity"] == "unlabeled_entity"
cache_key = "test_user:test_collection:unlabeled_entity"
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
@pytest.mark.asyncio
async def test_follow_edges_basic_functionality(self):
@ -441,40 +458,40 @@ class TestQuery:
@pytest.mark.asyncio
async def test_get_subgraph_method(self):
"""Test Query.get_subgraph method orchestrates entity and edge discovery"""
# Create mock Query that patches get_entities and follow_edges
# Create mock Query that patches get_entities and follow_edges_batch
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
user="test_user",
user="test_user",
collection="test_collection",
verbose=False,
max_path_length=1
)
# Mock get_entities to return test entities
query.get_entities = AsyncMock(return_value=["entity1", "entity2"])
# Mock follow_edges to add triples to subgraph
async def mock_follow_edges(ent, subgraph, path_length):
subgraph.add((ent, "predicate", "object"))
query.follow_edges = AsyncMock(side_effect=mock_follow_edges)
# Mock follow_edges_batch to return test triples
query.follow_edges_batch = AsyncMock(return_value={
("entity1", "predicate1", "object1"),
("entity2", "predicate2", "object2")
})
# Call get_subgraph
result = await query.get_subgraph("test query")
# Verify get_entities was called
query.get_entities.assert_called_once_with("test query")
# Verify follow_edges was called for each entity
assert query.follow_edges.call_count == 2
query.follow_edges.assert_any_call("entity1", unittest.mock.ANY, 1)
query.follow_edges.assert_any_call("entity2", unittest.mock.ANY, 1)
# Verify result is list format
# Verify follow_edges_batch was called with entities and max_path_length
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
# Verify result is list format and contains expected triples
assert isinstance(result, list)
assert len(result) == 2
assert ("entity1", "predicate1", "object1") in result
assert ("entity2", "predicate2", "object2") in result
@pytest.mark.asyncio
async def test_get_labelgraph_method(self):