mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 16:36:21 +02:00
release/v1.4 -> master (#548)
This commit is contained in:
parent
3ec2cd54f9
commit
2bd68ed7f4
94 changed files with 8571 additions and 1740 deletions
|
|
@ -34,7 +34,9 @@ class TestGraphRag:
|
|||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.verbose is False # Default value
|
||||
assert graph_rag.label_cache == {} # Empty cache initially
|
||||
# Verify label_cache is an LRUCacheWithTTL instance
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
def test_graph_rag_initialization_with_verbose(self):
|
||||
"""Test GraphRag initialization with verbose enabled"""
|
||||
|
|
@ -59,7 +61,9 @@ class TestGraphRag:
|
|||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.verbose is True
|
||||
assert graph_rag.label_cache == {} # Empty cache initially
|
||||
# Verify label_cache is an LRUCacheWithTTL instance
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
|
||||
class TestQuery:
|
||||
|
|
@ -228,8 +232,11 @@ class TestQuery:
|
|||
"""Test Query.maybe_label method with cached label"""
|
||||
# Create mock GraphRag with label cache
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.label_cache = {"entity1": "Entity One Label"}
|
||||
|
||||
# Create mock LRUCacheWithTTL
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = "Entity One Label"
|
||||
mock_rag.label_cache = mock_cache
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -237,27 +244,32 @@ class TestQuery:
|
|||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
|
||||
# Call maybe_label with cached entity
|
||||
result = await query.maybe_label("entity1")
|
||||
|
||||
|
||||
# Verify cached label is returned
|
||||
assert result == "Entity One Label"
|
||||
# Verify cache was checked with proper key format (user:collection:entity)
|
||||
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_label_lookup(self):
|
||||
"""Test Query.maybe_label method with database label lookup"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.label_cache = {} # Empty cache
|
||||
# Create mock LRUCacheWithTTL that returns None (cache miss)
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = None
|
||||
mock_rag.label_cache = mock_cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
|
||||
# Mock triple result with label
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.o = "Human Readable Label"
|
||||
mock_triples_client.query.return_value = [mock_triple]
|
||||
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -265,10 +277,10 @@ class TestQuery:
|
|||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
|
||||
# Call maybe_label
|
||||
result = await query.maybe_label("http://example.com/entity")
|
||||
|
||||
|
||||
# Verify triples client was called correctly
|
||||
mock_triples_client.query.assert_called_once_with(
|
||||
s="http://example.com/entity",
|
||||
|
|
@ -278,17 +290,21 @@ class TestQuery:
|
|||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify result and cache update
|
||||
|
||||
# Verify result and cache update with proper key
|
||||
assert result == "Human Readable Label"
|
||||
assert mock_rag.label_cache["http://example.com/entity"] == "Human Readable Label"
|
||||
cache_key = "test_user:test_collection:http://example.com/entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_no_label_found(self):
|
||||
"""Test Query.maybe_label method when no label is found"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.label_cache = {} # Empty cache
|
||||
# Create mock LRUCacheWithTTL that returns None (cache miss)
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = None
|
||||
mock_rag.label_cache = mock_cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
|
|
@ -318,7 +334,8 @@ class TestQuery:
|
|||
|
||||
# Verify result is entity itself and cache is updated
|
||||
assert result == "unlabeled_entity"
|
||||
assert mock_rag.label_cache["unlabeled_entity"] == "unlabeled_entity"
|
||||
cache_key = "test_user:test_collection:unlabeled_entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_basic_functionality(self):
|
||||
|
|
@ -441,40 +458,40 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_subgraph_method(self):
|
||||
"""Test Query.get_subgraph method orchestrates entity and edge discovery"""
|
||||
# Create mock Query that patches get_entities and follow_edges
|
||||
# Create mock Query that patches get_entities and follow_edges_batch
|
||||
mock_rag = MagicMock()
|
||||
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_path_length=1
|
||||
)
|
||||
|
||||
|
||||
# Mock get_entities to return test entities
|
||||
query.get_entities = AsyncMock(return_value=["entity1", "entity2"])
|
||||
|
||||
# Mock follow_edges to add triples to subgraph
|
||||
async def mock_follow_edges(ent, subgraph, path_length):
|
||||
subgraph.add((ent, "predicate", "object"))
|
||||
|
||||
query.follow_edges = AsyncMock(side_effect=mock_follow_edges)
|
||||
|
||||
|
||||
# Mock follow_edges_batch to return test triples
|
||||
query.follow_edges_batch = AsyncMock(return_value={
|
||||
("entity1", "predicate1", "object1"),
|
||||
("entity2", "predicate2", "object2")
|
||||
})
|
||||
|
||||
# Call get_subgraph
|
||||
result = await query.get_subgraph("test query")
|
||||
|
||||
|
||||
# Verify get_entities was called
|
||||
query.get_entities.assert_called_once_with("test query")
|
||||
|
||||
# Verify follow_edges was called for each entity
|
||||
assert query.follow_edges.call_count == 2
|
||||
query.follow_edges.assert_any_call("entity1", unittest.mock.ANY, 1)
|
||||
query.follow_edges.assert_any_call("entity2", unittest.mock.ANY, 1)
|
||||
|
||||
# Verify result is list format
|
||||
|
||||
# Verify follow_edges_batch was called with entities and max_path_length
|
||||
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
|
||||
|
||||
# Verify result is list format and contains expected triples
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert ("entity1", "predicate1", "object1") in result
|
||||
assert ("entity2", "predicate2", "object2") in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_labelgraph_method(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue