mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-15 18:35:15 +02:00
Feature/streaming triples (#676)
* Steaming triples * Also GraphRAG service uses this * Updated tests
This commit is contained in:
parent
3c3e11bef5
commit
d2d71f859d
11 changed files with 542 additions and 116 deletions
|
|
@ -48,7 +48,7 @@ class TestGraphRagIntegration:
|
|||
client = AsyncMock()
|
||||
|
||||
# Mock different queries return different triples
|
||||
async def query_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None):
|
||||
async def query_stream_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None, batch_size=20):
|
||||
# Mock label queries
|
||||
if p == "http://www.w3.org/2000/01/rdf-schema#label":
|
||||
if s == "http://trustgraph.ai/e/machine-learning":
|
||||
|
|
@ -76,7 +76,9 @@ class TestGraphRagIntegration:
|
|||
|
||||
return []
|
||||
|
||||
client.query.side_effect = query_side_effect
|
||||
client.query_stream.side_effect = query_stream_side_effect
|
||||
# Also mock query for label lookups (maybe_label uses query, not query_stream)
|
||||
client.query.side_effect = query_stream_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -137,7 +139,7 @@ class TestGraphRagIntegration:
|
|||
assert call_args.kwargs['collection'] == collection
|
||||
|
||||
# 3. Should query triples to build knowledge subgraph
|
||||
assert mock_triples_client.query.call_count > 0
|
||||
assert mock_triples_client.query_stream.call_count > 0
|
||||
|
||||
# 4. Should call prompt with knowledge graph
|
||||
mock_prompt_client.kg_prompt.assert_called_once()
|
||||
|
|
@ -202,7 +204,7 @@ class TestGraphRagIntegration:
|
|||
"""Test GraphRAG handles empty knowledge graph gracefully"""
|
||||
# Arrange
|
||||
mock_graph_embeddings_client.query.return_value = [] # No entities found
|
||||
mock_triples_client.query.return_value = [] # No triples found
|
||||
mock_triples_client.query_stream.return_value = [] # No triples found
|
||||
|
||||
# Act
|
||||
result = await graph_rag.query(
|
||||
|
|
@ -231,7 +233,7 @@ class TestGraphRagIntegration:
|
|||
collection="test_collection"
|
||||
)
|
||||
|
||||
first_call_count = mock_triples_client.query.call_count
|
||||
first_call_count = mock_triples_client.query_stream.call_count
|
||||
mock_triples_client.reset_mock()
|
||||
|
||||
# Second identical query
|
||||
|
|
@ -241,7 +243,7 @@ class TestGraphRagIntegration:
|
|||
collection="test_collection"
|
||||
)
|
||||
|
||||
second_call_count = mock_triples_client.query.call_count
|
||||
second_call_count = mock_triples_client.query_stream.call_count
|
||||
|
||||
# Assert - Second query should make fewer triple queries due to caching
|
||||
# Note: This is a weak assertion because caching behavior depends on
|
||||
|
|
|
|||
|
|
@ -193,15 +193,17 @@ class TestQuery:
|
|||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||
|
||||
# Mock EntityMatch objects with entity that has string representation
|
||||
# Mock EntityMatch objects with entity as Term-like object
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.__str__ = MagicMock(return_value="entity1")
|
||||
mock_entity1.type = "i" # IRI type
|
||||
mock_entity1.iri = "entity1"
|
||||
mock_match1 = MagicMock()
|
||||
mock_match1.entity = mock_entity1
|
||||
mock_match1.score = 0.95
|
||||
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.__str__ = MagicMock(return_value="entity2")
|
||||
mock_entity2.type = "i" # IRI type
|
||||
mock_entity2.iri = "entity2"
|
||||
mock_match2 = MagicMock()
|
||||
mock_match2.entity = mock_entity2
|
||||
mock_match2.score = 0.85
|
||||
|
|
@ -363,10 +365,10 @@ class TestQuery:
|
|||
mock_triple3 = MagicMock()
|
||||
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
|
||||
|
||||
# Setup query responses for s=ent, p=ent, o=ent patterns
|
||||
mock_triples_client.query.side_effect = [
|
||||
# Setup query_stream responses for s=ent, p=ent, o=ent patterns
|
||||
mock_triples_client.query_stream.side_effect = [
|
||||
[mock_triple1], # s=ent, p=None, o=None
|
||||
[mock_triple2], # s=None, p=ent, o=None
|
||||
[mock_triple2], # s=None, p=ent, o=None
|
||||
[mock_triple3], # s=None, p=None, o=ent
|
||||
]
|
||||
|
||||
|
|
@ -384,20 +386,20 @@ class TestQuery:
|
|||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
# Verify all three query patterns were called
|
||||
assert mock_triples_client.query.call_count == 3
|
||||
|
||||
# Verify query calls
|
||||
mock_triples_client.query.assert_any_call(
|
||||
assert mock_triples_client.query_stream.call_count == 3
|
||||
|
||||
# Verify query_stream calls
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s="entity1", p=None, o=None, limit=10,
|
||||
user="test_user", collection="test_collection"
|
||||
user="test_user", collection="test_collection", batch_size=20
|
||||
)
|
||||
mock_triples_client.query.assert_any_call(
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p="entity1", o=None, limit=10,
|
||||
user="test_user", collection="test_collection"
|
||||
user="test_user", collection="test_collection", batch_size=20
|
||||
)
|
||||
mock_triples_client.query.assert_any_call(
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p=None, o="entity1", limit=10,
|
||||
user="test_user", collection="test_collection"
|
||||
user="test_user", collection="test_collection", batch_size=20
|
||||
)
|
||||
|
||||
# Verify subgraph contains discovered triples
|
||||
|
|
@ -427,9 +429,9 @@ class TestQuery:
|
|||
# Call follow_edges with path_length=0
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=0)
|
||||
|
||||
|
||||
# Verify no queries were made
|
||||
mock_triples_client.query.assert_not_called()
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
|
||||
# Verify subgraph remains empty
|
||||
assert subgraph == set()
|
||||
|
|
@ -456,9 +458,9 @@ class TestQuery:
|
|||
|
||||
# Call follow_edges
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
|
||||
# Verify no queries were made due to size limit
|
||||
mock_triples_client.query.assert_not_called()
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
|
||||
# Verify subgraph unchanged
|
||||
assert len(subgraph) == 3
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue