Feature/streaming triples (#676)

* Steaming triples

* Also GraphRAG service uses this

* Updated tests
This commit is contained in:
cybermaggedon 2026-03-09 15:46:33 +00:00 committed by GitHub
parent 3c3e11bef5
commit d2d71f859d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 542 additions and 116 deletions

View file

@ -48,7 +48,7 @@ class TestGraphRagIntegration:
client = AsyncMock()
# Mock different queries return different triples
async def query_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None):
async def query_stream_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None, batch_size=20):
# Mock label queries
if p == "http://www.w3.org/2000/01/rdf-schema#label":
if s == "http://trustgraph.ai/e/machine-learning":
@ -76,7 +76,9 @@ class TestGraphRagIntegration:
return []
client.query.side_effect = query_side_effect
client.query_stream.side_effect = query_stream_side_effect
# Also mock query for label lookups (maybe_label uses query, not query_stream)
client.query.side_effect = query_stream_side_effect
return client
@pytest.fixture
@ -137,7 +139,7 @@ class TestGraphRagIntegration:
assert call_args.kwargs['collection'] == collection
# 3. Should query triples to build knowledge subgraph
assert mock_triples_client.query.call_count > 0
assert mock_triples_client.query_stream.call_count > 0
# 4. Should call prompt with knowledge graph
mock_prompt_client.kg_prompt.assert_called_once()
@ -202,7 +204,7 @@ class TestGraphRagIntegration:
"""Test GraphRAG handles empty knowledge graph gracefully"""
# Arrange
mock_graph_embeddings_client.query.return_value = [] # No entities found
mock_triples_client.query.return_value = [] # No triples found
mock_triples_client.query_stream.return_value = [] # No triples found
# Act
result = await graph_rag.query(
@ -231,7 +233,7 @@ class TestGraphRagIntegration:
collection="test_collection"
)
first_call_count = mock_triples_client.query.call_count
first_call_count = mock_triples_client.query_stream.call_count
mock_triples_client.reset_mock()
# Second identical query
@ -241,7 +243,7 @@ class TestGraphRagIntegration:
collection="test_collection"
)
second_call_count = mock_triples_client.query.call_count
second_call_count = mock_triples_client.query_stream.call_count
# Assert - Second query should make fewer triple queries due to caching
# Note: This is a weak assertion because caching behavior depends on

View file

@ -193,15 +193,17 @@ class TestQuery:
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock EntityMatch objects with entity that has string representation
# Mock EntityMatch objects with entity as Term-like object
mock_entity1 = MagicMock()
mock_entity1.__str__ = MagicMock(return_value="entity1")
mock_entity1.type = "i" # IRI type
mock_entity1.iri = "entity1"
mock_match1 = MagicMock()
mock_match1.entity = mock_entity1
mock_match1.score = 0.95
mock_entity2 = MagicMock()
mock_entity2.__str__ = MagicMock(return_value="entity2")
mock_entity2.type = "i" # IRI type
mock_entity2.iri = "entity2"
mock_match2 = MagicMock()
mock_match2.entity = mock_entity2
mock_match2.score = 0.85
@ -363,10 +365,10 @@ class TestQuery:
mock_triple3 = MagicMock()
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
# Setup query responses for s=ent, p=ent, o=ent patterns
mock_triples_client.query.side_effect = [
# Setup query_stream responses for s=ent, p=ent, o=ent patterns
mock_triples_client.query_stream.side_effect = [
[mock_triple1], # s=ent, p=None, o=None
[mock_triple2], # s=None, p=ent, o=None
[mock_triple2], # s=None, p=ent, o=None
[mock_triple3], # s=None, p=None, o=ent
]
@ -384,20 +386,20 @@ class TestQuery:
await query.follow_edges("entity1", subgraph, path_length=1)
# Verify all three query patterns were called
assert mock_triples_client.query.call_count == 3
# Verify query calls
mock_triples_client.query.assert_any_call(
assert mock_triples_client.query_stream.call_count == 3
# Verify query_stream calls
mock_triples_client.query_stream.assert_any_call(
s="entity1", p=None, o=None, limit=10,
user="test_user", collection="test_collection"
user="test_user", collection="test_collection", batch_size=20
)
mock_triples_client.query.assert_any_call(
mock_triples_client.query_stream.assert_any_call(
s=None, p="entity1", o=None, limit=10,
user="test_user", collection="test_collection"
user="test_user", collection="test_collection", batch_size=20
)
mock_triples_client.query.assert_any_call(
mock_triples_client.query_stream.assert_any_call(
s=None, p=None, o="entity1", limit=10,
user="test_user", collection="test_collection"
user="test_user", collection="test_collection", batch_size=20
)
# Verify subgraph contains discovered triples
@ -427,9 +429,9 @@ class TestQuery:
# Call follow_edges with path_length=0
subgraph = set()
await query.follow_edges("entity1", subgraph, path_length=0)
# Verify no queries were made
mock_triples_client.query.assert_not_called()
mock_triples_client.query_stream.assert_not_called()
# Verify subgraph remains empty
assert subgraph == set()
@ -456,9 +458,9 @@ class TestQuery:
# Call follow_edges
await query.follow_edges("entity1", subgraph, path_length=1)
# Verify no queries were made due to size limit
mock_triples_client.query.assert_not_called()
mock_triples_client.query_stream.assert_not_called()
# Verify subgraph unchanged
assert len(subgraph) == 3