mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Feature/streaming triples (#676)
* Streaming triples * Also GraphRAG service uses this * Updated tests
This commit is contained in:
parent
3c3e11bef5
commit
d2d71f859d
11 changed files with 542 additions and 116 deletions
|
|
@ -48,7 +48,7 @@ class TestGraphRagIntegration:
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
|
|
||||||
# Mock different queries return different triples
|
# Mock different queries return different triples
|
||||||
async def query_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None):
|
async def query_stream_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None, batch_size=20):
|
||||||
# Mock label queries
|
# Mock label queries
|
||||||
if p == "http://www.w3.org/2000/01/rdf-schema#label":
|
if p == "http://www.w3.org/2000/01/rdf-schema#label":
|
||||||
if s == "http://trustgraph.ai/e/machine-learning":
|
if s == "http://trustgraph.ai/e/machine-learning":
|
||||||
|
|
@ -76,7 +76,9 @@ class TestGraphRagIntegration:
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
client.query.side_effect = query_side_effect
|
client.query_stream.side_effect = query_stream_side_effect
|
||||||
|
# Also mock query for label lookups (maybe_label uses query, not query_stream)
|
||||||
|
client.query.side_effect = query_stream_side_effect
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -137,7 +139,7 @@ class TestGraphRagIntegration:
|
||||||
assert call_args.kwargs['collection'] == collection
|
assert call_args.kwargs['collection'] == collection
|
||||||
|
|
||||||
# 3. Should query triples to build knowledge subgraph
|
# 3. Should query triples to build knowledge subgraph
|
||||||
assert mock_triples_client.query.call_count > 0
|
assert mock_triples_client.query_stream.call_count > 0
|
||||||
|
|
||||||
# 4. Should call prompt with knowledge graph
|
# 4. Should call prompt with knowledge graph
|
||||||
mock_prompt_client.kg_prompt.assert_called_once()
|
mock_prompt_client.kg_prompt.assert_called_once()
|
||||||
|
|
@ -202,7 +204,7 @@ class TestGraphRagIntegration:
|
||||||
"""Test GraphRAG handles empty knowledge graph gracefully"""
|
"""Test GraphRAG handles empty knowledge graph gracefully"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_graph_embeddings_client.query.return_value = [] # No entities found
|
mock_graph_embeddings_client.query.return_value = [] # No entities found
|
||||||
mock_triples_client.query.return_value = [] # No triples found
|
mock_triples_client.query_stream.return_value = [] # No triples found
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = await graph_rag.query(
|
result = await graph_rag.query(
|
||||||
|
|
@ -231,7 +233,7 @@ class TestGraphRagIntegration:
|
||||||
collection="test_collection"
|
collection="test_collection"
|
||||||
)
|
)
|
||||||
|
|
||||||
first_call_count = mock_triples_client.query.call_count
|
first_call_count = mock_triples_client.query_stream.call_count
|
||||||
mock_triples_client.reset_mock()
|
mock_triples_client.reset_mock()
|
||||||
|
|
||||||
# Second identical query
|
# Second identical query
|
||||||
|
|
@ -241,7 +243,7 @@ class TestGraphRagIntegration:
|
||||||
collection="test_collection"
|
collection="test_collection"
|
||||||
)
|
)
|
||||||
|
|
||||||
second_call_count = mock_triples_client.query.call_count
|
second_call_count = mock_triples_client.query_stream.call_count
|
||||||
|
|
||||||
# Assert - Second query should make fewer triple queries due to caching
|
# Assert - Second query should make fewer triple queries due to caching
|
||||||
# Note: This is a weak assertion because caching behavior depends on
|
# Note: This is a weak assertion because caching behavior depends on
|
||||||
|
|
|
||||||
|
|
@ -193,15 +193,17 @@ class TestQuery:
|
||||||
test_vectors = [[0.1, 0.2, 0.3]]
|
test_vectors = [[0.1, 0.2, 0.3]]
|
||||||
mock_embeddings_client.embed.return_value = [test_vectors]
|
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||||
|
|
||||||
# Mock EntityMatch objects with entity that has string representation
|
# Mock EntityMatch objects with entity as Term-like object
|
||||||
mock_entity1 = MagicMock()
|
mock_entity1 = MagicMock()
|
||||||
mock_entity1.__str__ = MagicMock(return_value="entity1")
|
mock_entity1.type = "i" # IRI type
|
||||||
|
mock_entity1.iri = "entity1"
|
||||||
mock_match1 = MagicMock()
|
mock_match1 = MagicMock()
|
||||||
mock_match1.entity = mock_entity1
|
mock_match1.entity = mock_entity1
|
||||||
mock_match1.score = 0.95
|
mock_match1.score = 0.95
|
||||||
|
|
||||||
mock_entity2 = MagicMock()
|
mock_entity2 = MagicMock()
|
||||||
mock_entity2.__str__ = MagicMock(return_value="entity2")
|
mock_entity2.type = "i" # IRI type
|
||||||
|
mock_entity2.iri = "entity2"
|
||||||
mock_match2 = MagicMock()
|
mock_match2 = MagicMock()
|
||||||
mock_match2.entity = mock_entity2
|
mock_match2.entity = mock_entity2
|
||||||
mock_match2.score = 0.85
|
mock_match2.score = 0.85
|
||||||
|
|
@ -363,10 +365,10 @@ class TestQuery:
|
||||||
mock_triple3 = MagicMock()
|
mock_triple3 = MagicMock()
|
||||||
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
|
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
|
||||||
|
|
||||||
# Setup query responses for s=ent, p=ent, o=ent patterns
|
# Setup query_stream responses for s=ent, p=ent, o=ent patterns
|
||||||
mock_triples_client.query.side_effect = [
|
mock_triples_client.query_stream.side_effect = [
|
||||||
[mock_triple1], # s=ent, p=None, o=None
|
[mock_triple1], # s=ent, p=None, o=None
|
||||||
[mock_triple2], # s=None, p=ent, o=None
|
[mock_triple2], # s=None, p=ent, o=None
|
||||||
[mock_triple3], # s=None, p=None, o=ent
|
[mock_triple3], # s=None, p=None, o=ent
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -384,20 +386,20 @@ class TestQuery:
|
||||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||||
|
|
||||||
# Verify all three query patterns were called
|
# Verify all three query patterns were called
|
||||||
assert mock_triples_client.query.call_count == 3
|
assert mock_triples_client.query_stream.call_count == 3
|
||||||
|
|
||||||
# Verify query calls
|
# Verify query_stream calls
|
||||||
mock_triples_client.query.assert_any_call(
|
mock_triples_client.query_stream.assert_any_call(
|
||||||
s="entity1", p=None, o=None, limit=10,
|
s="entity1", p=None, o=None, limit=10,
|
||||||
user="test_user", collection="test_collection"
|
user="test_user", collection="test_collection", batch_size=20
|
||||||
)
|
)
|
||||||
mock_triples_client.query.assert_any_call(
|
mock_triples_client.query_stream.assert_any_call(
|
||||||
s=None, p="entity1", o=None, limit=10,
|
s=None, p="entity1", o=None, limit=10,
|
||||||
user="test_user", collection="test_collection"
|
user="test_user", collection="test_collection", batch_size=20
|
||||||
)
|
)
|
||||||
mock_triples_client.query.assert_any_call(
|
mock_triples_client.query_stream.assert_any_call(
|
||||||
s=None, p=None, o="entity1", limit=10,
|
s=None, p=None, o="entity1", limit=10,
|
||||||
user="test_user", collection="test_collection"
|
user="test_user", collection="test_collection", batch_size=20
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify subgraph contains discovered triples
|
# Verify subgraph contains discovered triples
|
||||||
|
|
@ -427,9 +429,9 @@ class TestQuery:
|
||||||
# Call follow_edges with path_length=0
|
# Call follow_edges with path_length=0
|
||||||
subgraph = set()
|
subgraph = set()
|
||||||
await query.follow_edges("entity1", subgraph, path_length=0)
|
await query.follow_edges("entity1", subgraph, path_length=0)
|
||||||
|
|
||||||
# Verify no queries were made
|
# Verify no queries were made
|
||||||
mock_triples_client.query.assert_not_called()
|
mock_triples_client.query_stream.assert_not_called()
|
||||||
|
|
||||||
# Verify subgraph remains empty
|
# Verify subgraph remains empty
|
||||||
assert subgraph == set()
|
assert subgraph == set()
|
||||||
|
|
@ -456,9 +458,9 @@ class TestQuery:
|
||||||
|
|
||||||
# Call follow_edges
|
# Call follow_edges
|
||||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||||
|
|
||||||
# Verify no queries were made due to size limit
|
# Verify no queries were made due to size limit
|
||||||
mock_triples_client.query.assert_not_called()
|
mock_triples_client.query_stream.assert_not_called()
|
||||||
|
|
||||||
# Verify subgraph unchanged
|
# Verify subgraph unchanged
|
||||||
assert len(subgraph) == 3
|
assert len(subgraph) == 3
|
||||||
|
|
|
||||||
|
|
@ -91,9 +91,18 @@ class SocketClient:
|
||||||
service: str,
|
service: str,
|
||||||
flow: Optional[str],
|
flow: Optional[str],
|
||||||
request: Dict[str, Any],
|
request: Dict[str, Any],
|
||||||
streaming: bool = False
|
streaming: bool = False,
|
||||||
) -> Union[Dict[str, Any], Iterator[StreamingChunk]]:
|
streaming_raw: bool = False
|
||||||
"""Synchronous wrapper around async WebSocket communication"""
|
) -> Union[Dict[str, Any], Iterator[StreamingChunk], Iterator[Dict[str, Any]]]:
|
||||||
|
"""Synchronous wrapper around async WebSocket communication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
service: Service name
|
||||||
|
flow: Flow ID (optional)
|
||||||
|
request: Request payload
|
||||||
|
streaming: Use parsed streaming (for agent/RAG chunk types)
|
||||||
|
streaming_raw: Use raw streaming (for data batches like triples)
|
||||||
|
"""
|
||||||
# Create event loop if needed
|
# Create event loop if needed
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
@ -105,12 +114,14 @@ class SocketClient:
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
if streaming:
|
if streaming_raw:
|
||||||
# For streaming, we need to return an iterator
|
# Raw streaming for data batches (triples, rows, etc.)
|
||||||
# Create a generator that runs async code
|
return self._streaming_generator_raw(service, flow, request, loop)
|
||||||
|
elif streaming:
|
||||||
|
# Parsed streaming for agent/RAG chunk types
|
||||||
return self._streaming_generator(service, flow, request, loop)
|
return self._streaming_generator(service, flow, request, loop)
|
||||||
else:
|
else:
|
||||||
# For non-streaming, just run the async code and return result
|
# Non-streaming single response
|
||||||
return loop.run_until_complete(self._send_request_async(service, flow, request))
|
return loop.run_until_complete(self._send_request_async(service, flow, request))
|
||||||
|
|
||||||
def _streaming_generator(
|
def _streaming_generator(
|
||||||
|
|
@ -120,7 +131,7 @@ class SocketClient:
|
||||||
request: Dict[str, Any],
|
request: Dict[str, Any],
|
||||||
loop: asyncio.AbstractEventLoop
|
loop: asyncio.AbstractEventLoop
|
||||||
) -> Iterator[StreamingChunk]:
|
) -> Iterator[StreamingChunk]:
|
||||||
"""Generator that yields streaming chunks"""
|
"""Generator that yields streaming chunks (for agent/RAG responses)"""
|
||||||
async_gen = self._send_request_async_streaming(service, flow, request)
|
async_gen = self._send_request_async_streaming(service, flow, request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -137,6 +148,74 @@ class SocketClient:
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _streaming_generator_raw(
|
||||||
|
self,
|
||||||
|
service: str,
|
||||||
|
flow: Optional[str],
|
||||||
|
request: Dict[str, Any],
|
||||||
|
loop: asyncio.AbstractEventLoop
|
||||||
|
) -> Iterator[Dict[str, Any]]:
|
||||||
|
"""Generator that yields raw response dicts (for data streaming like triples)"""
|
||||||
|
async_gen = self._send_request_async_streaming_raw(service, flow, request)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data = loop.run_until_complete(async_gen.__anext__())
|
||||||
|
yield data
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(async_gen.aclose())
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _send_request_async_streaming_raw(
|
||||||
|
self,
|
||||||
|
service: str,
|
||||||
|
flow: Optional[str],
|
||||||
|
request: Dict[str, Any]
|
||||||
|
) -> Iterator[Dict[str, Any]]:
|
||||||
|
"""Async streaming that yields raw response dicts without parsing.
|
||||||
|
|
||||||
|
Used for data streaming (triples, rows, etc.) where responses are
|
||||||
|
just batches of data, not agent/RAG chunk types.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
self._request_counter += 1
|
||||||
|
request_id = f"req-{self._request_counter}"
|
||||||
|
|
||||||
|
ws_url = f"{self.url}/api/v1/socket"
|
||||||
|
if self.token:
|
||||||
|
ws_url = f"{ws_url}?token={self.token}"
|
||||||
|
|
||||||
|
message = {
|
||||||
|
"id": request_id,
|
||||||
|
"service": service,
|
||||||
|
"request": request
|
||||||
|
}
|
||||||
|
if flow:
|
||||||
|
message["flow"] = flow
|
||||||
|
|
||||||
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
|
await websocket.send(json.dumps(message))
|
||||||
|
|
||||||
|
async for raw_message in websocket:
|
||||||
|
response = json.loads(raw_message)
|
||||||
|
|
||||||
|
if response.get("id") != request_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "error" in response:
|
||||||
|
raise_from_error_dict(response["error"])
|
||||||
|
|
||||||
|
if "response" in response:
|
||||||
|
yield response["response"]
|
||||||
|
|
||||||
|
if response.get("complete"):
|
||||||
|
break
|
||||||
|
|
||||||
async def _send_request_async(
|
async def _send_request_async(
|
||||||
self,
|
self,
|
||||||
service: str,
|
service: str,
|
||||||
|
|
@ -790,6 +869,74 @@ class SocketFlowInstance:
|
||||||
|
|
||||||
return self.client._send_request_sync("triples", self.flow_id, request, False)
|
return self.client._send_request_sync("triples", self.flow_id, request, False)
|
||||||
|
|
||||||
|
def triples_query_stream(
|
||||||
|
self,
|
||||||
|
s: Optional[str] = None,
|
||||||
|
p: Optional[str] = None,
|
||||||
|
o: Optional[str] = None,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
collection: Optional[str] = None,
|
||||||
|
limit: int = 100,
|
||||||
|
batch_size: int = 20,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> Iterator[List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Query knowledge graph triples with streaming batches.
|
||||||
|
|
||||||
|
Yields batches of triples as they arrive, reducing time-to-first-result
|
||||||
|
and memory overhead for large result sets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s: Subject URI (optional, use None for wildcard)
|
||||||
|
p: Predicate URI (optional, use None for wildcard)
|
||||||
|
o: Object URI or Literal (optional, use None for wildcard)
|
||||||
|
user: User/keyspace identifier (optional)
|
||||||
|
collection: Collection identifier (optional)
|
||||||
|
limit: Maximum results to return (default: 100)
|
||||||
|
batch_size: Triples per batch (default: 20)
|
||||||
|
**kwargs: Additional parameters passed to the service
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
List[Dict]: Batches of triples in wire format
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
socket = api.socket()
|
||||||
|
flow = socket.flow("default")
|
||||||
|
|
||||||
|
for batch in flow.triples_query_stream(
|
||||||
|
user="trustgraph",
|
||||||
|
collection="default"
|
||||||
|
):
|
||||||
|
for triple in batch:
|
||||||
|
print(triple["s"], triple["p"], triple["o"])
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
request = {
|
||||||
|
"limit": limit,
|
||||||
|
"streaming": True,
|
||||||
|
"batch-size": batch_size,
|
||||||
|
}
|
||||||
|
if s is not None:
|
||||||
|
request["s"] = str(s)
|
||||||
|
if p is not None:
|
||||||
|
request["p"] = str(p)
|
||||||
|
if o is not None:
|
||||||
|
request["o"] = str(o)
|
||||||
|
if user is not None:
|
||||||
|
request["user"] = user
|
||||||
|
if collection is not None:
|
||||||
|
request["collection"] = collection
|
||||||
|
request.update(kwargs)
|
||||||
|
|
||||||
|
# Use raw streaming - yields response dicts directly without parsing
|
||||||
|
for response in self.client._send_request_sync("triples", self.flow_id, request, streaming_raw=True):
|
||||||
|
# Response is {"response": [...triples...]} from translator
|
||||||
|
if isinstance(response, dict) and "response" in response:
|
||||||
|
yield response["response"]
|
||||||
|
else:
|
||||||
|
yield response
|
||||||
|
|
||||||
def rows_query(
|
def rows_query(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
|
||||||
|
|
@ -22,11 +22,19 @@ def to_value(x):
|
||||||
|
|
||||||
|
|
||||||
def from_value(x):
|
def from_value(x):
|
||||||
"""Convert Uri or Literal to schema Term."""
|
"""Convert Uri, Literal, or string to schema Term."""
|
||||||
if x is None:
|
if x is None:
|
||||||
return None
|
return None
|
||||||
if isinstance(x, Uri):
|
if isinstance(x, Uri):
|
||||||
return Term(type=IRI, iri=str(x))
|
return Term(type=IRI, iri=str(x))
|
||||||
|
elif isinstance(x, Literal):
|
||||||
|
return Term(type=LITERAL, value=str(x))
|
||||||
|
elif isinstance(x, str):
|
||||||
|
# Detect IRIs by common prefixes
|
||||||
|
if x.startswith("http://") or x.startswith("https://") or x.startswith("urn:"):
|
||||||
|
return Term(type=IRI, iri=x)
|
||||||
|
else:
|
||||||
|
return Term(type=LITERAL, value=x)
|
||||||
else:
|
else:
|
||||||
return Term(type=LITERAL, value=str(x))
|
return Term(type=LITERAL, value=str(x))
|
||||||
|
|
||||||
|
|
@ -57,6 +65,61 @@ class TriplesClient(RequestResponse):
|
||||||
|
|
||||||
return triples
|
return triples
|
||||||
|
|
||||||
|
async def query_stream(self, s=None, p=None, o=None, limit=20,
|
||||||
|
user="trustgraph", collection="default",
|
||||||
|
batch_size=20, timeout=30,
|
||||||
|
batch_callback=None):
|
||||||
|
"""
|
||||||
|
Streaming triple query - calls callback for each batch as it arrives.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s, p, o: Triple pattern (None for wildcard)
|
||||||
|
limit: Maximum total triples to return
|
||||||
|
user: User/keyspace
|
||||||
|
collection: Collection name
|
||||||
|
batch_size: Triples per batch
|
||||||
|
timeout: Request timeout in seconds
|
||||||
|
batch_callback: Async callback(batch, is_final) called for each batch
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Triple]: All triples (flattened) if no callback provided
|
||||||
|
"""
|
||||||
|
all_triples = []
|
||||||
|
|
||||||
|
async def recipient(resp):
|
||||||
|
if resp.error:
|
||||||
|
raise RuntimeError(resp.error.message)
|
||||||
|
|
||||||
|
batch = [
|
||||||
|
Triple(to_value(v.s), to_value(v.p), to_value(v.o))
|
||||||
|
for v in resp.triples
|
||||||
|
]
|
||||||
|
|
||||||
|
if batch_callback:
|
||||||
|
await batch_callback(batch, resp.is_final)
|
||||||
|
else:
|
||||||
|
all_triples.extend(batch)
|
||||||
|
|
||||||
|
return resp.is_final
|
||||||
|
|
||||||
|
await self.request(
|
||||||
|
TriplesQueryRequest(
|
||||||
|
s=from_value(s),
|
||||||
|
p=from_value(p),
|
||||||
|
o=from_value(o),
|
||||||
|
limit=limit,
|
||||||
|
user=user,
|
||||||
|
collection=collection,
|
||||||
|
streaming=True,
|
||||||
|
batch_size=batch_size,
|
||||||
|
),
|
||||||
|
timeout=timeout,
|
||||||
|
recipient=recipient,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not batch_callback:
|
||||||
|
return all_triples
|
||||||
|
|
||||||
class TriplesClientSpec(RequestResponseSpec):
|
class TriplesClientSpec(RequestResponseSpec):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, request_name, response_name,
|
self, request_name, response_name,
|
||||||
|
|
|
||||||
|
|
@ -52,13 +52,23 @@ class TriplesQueryService(FlowProcessor):
|
||||||
|
|
||||||
logger.debug(f"Handling triples query request {id}...")
|
logger.debug(f"Handling triples query request {id}...")
|
||||||
|
|
||||||
triples = await self.query_triples(request)
|
if request.streaming:
|
||||||
|
# Streaming mode: send batches
|
||||||
logger.debug("Sending triples query response...")
|
async for batch, is_final in self.query_triples_stream(request):
|
||||||
r = TriplesQueryResponse(triples=triples, error=None)
|
r = TriplesQueryResponse(
|
||||||
await flow("response").send(r, properties={"id": id})
|
triples=batch,
|
||||||
|
error=None,
|
||||||
logger.debug("Triples query request completed")
|
is_final=is_final,
|
||||||
|
)
|
||||||
|
await flow("response").send(r, properties={"id": id})
|
||||||
|
logger.debug("Triples query streaming completed")
|
||||||
|
else:
|
||||||
|
# Non-streaming mode: single response
|
||||||
|
triples = await self.query_triples(request)
|
||||||
|
logger.debug("Sending triples query response...")
|
||||||
|
r = TriplesQueryResponse(triples=triples, error=None)
|
||||||
|
await flow("response").send(r, properties={"id": id})
|
||||||
|
logger.debug("Triples query request completed")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
|
|
@ -76,6 +86,24 @@ class TriplesQueryService(FlowProcessor):
|
||||||
|
|
||||||
await flow("response").send(r, properties={"id": id})
|
await flow("response").send(r, properties={"id": id})
|
||||||
|
|
||||||
|
async def query_triples_stream(self, request):
|
||||||
|
"""
|
||||||
|
Streaming query - yields (batch, is_final) tuples.
|
||||||
|
Default implementation batches results from query_triples.
|
||||||
|
Override for true streaming from backend.
|
||||||
|
"""
|
||||||
|
triples = await self.query_triples(request)
|
||||||
|
batch_size = request.batch_size if request.batch_size > 0 else 20
|
||||||
|
|
||||||
|
for i in range(0, len(triples), batch_size):
|
||||||
|
batch = triples[i:i + batch_size]
|
||||||
|
is_final = (i + batch_size >= len(triples))
|
||||||
|
yield batch, is_final
|
||||||
|
|
||||||
|
# Handle empty result
|
||||||
|
if len(triples) == 0:
|
||||||
|
yield [], True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,14 +23,18 @@ class TriplesQueryRequestTranslator(MessageTranslator):
|
||||||
g=g,
|
g=g,
|
||||||
limit=int(data.get("limit", 10000)),
|
limit=int(data.get("limit", 10000)),
|
||||||
user=data.get("user", "trustgraph"),
|
user=data.get("user", "trustgraph"),
|
||||||
collection=data.get("collection", "default")
|
collection=data.get("collection", "default"),
|
||||||
|
streaming=data.get("streaming", False),
|
||||||
|
batch_size=int(data.get("batch-size", 20)),
|
||||||
)
|
)
|
||||||
|
|
||||||
def from_pulsar(self, obj: TriplesQueryRequest) -> Dict[str, Any]:
|
def from_pulsar(self, obj: TriplesQueryRequest) -> Dict[str, Any]:
|
||||||
result = {
|
result = {
|
||||||
"limit": obj.limit,
|
"limit": obj.limit,
|
||||||
"user": obj.user,
|
"user": obj.user,
|
||||||
"collection": obj.collection
|
"collection": obj.collection,
|
||||||
|
"streaming": obj.streaming,
|
||||||
|
"batch-size": obj.batch_size,
|
||||||
}
|
}
|
||||||
|
|
||||||
if obj.s:
|
if obj.s:
|
||||||
|
|
@ -61,4 +65,4 @@ class TriplesQueryResponseTranslator(MessageTranslator):
|
||||||
|
|
||||||
def from_response_with_completion(self, obj: TriplesQueryResponse) -> Tuple[Dict[str, Any], bool]:
|
def from_response_with_completion(self, obj: TriplesQueryResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
"""Returns (response_dict, is_final)"""
|
"""Returns (response_dict, is_final)"""
|
||||||
return self.from_pulsar(obj), True
|
return self.from_pulsar(obj), obj.is_final
|
||||||
|
|
@ -38,11 +38,14 @@ class TriplesQueryRequest:
|
||||||
o: Term | None = None
|
o: Term | None = None
|
||||||
g: str | None = None # Graph IRI. None=default graph, "*"=all graphs
|
g: str | None = None # Graph IRI. None=default graph, "*"=all graphs
|
||||||
limit: int = 0
|
limit: int = 0
|
||||||
|
streaming: bool = False # Enable streaming mode (multiple batched responses)
|
||||||
|
batch_size: int = 20 # Triples per batch in streaming mode
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TriplesQueryResponse:
|
class TriplesQueryResponse:
|
||||||
error: Error | None = None
|
error: Error | None = None
|
||||||
triples: list[Triple] = field(default_factory=list)
|
triples: list[Triple] = field(default_factory=list)
|
||||||
|
is_final: bool = True # False for intermediate batches in streaming mode
|
||||||
|
|
||||||
############################################################################
|
############################################################################
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""
|
"""
|
||||||
Connects to the graph query service and dumps all graph edges in Turtle
|
Connects to the graph query service and dumps all graph edges in Turtle
|
||||||
format with RDF-star support for quoted triples.
|
format with RDF-star support for quoted triples.
|
||||||
|
Uses streaming mode for lower time-to-first-processing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import rdflib
|
import rdflib
|
||||||
|
|
@ -9,66 +10,82 @@ import sys
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from trustgraph.api import Api, Uri
|
from trustgraph.api import Api
|
||||||
from trustgraph.knowledge import QuotedTriple
|
|
||||||
|
|
||||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||||
default_user = 'trustgraph'
|
default_user = 'trustgraph'
|
||||||
default_collection = 'default'
|
default_collection = 'default'
|
||||||
|
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||||
|
|
||||||
|
|
||||||
def value_to_rdflib(val):
|
def term_to_rdflib(term):
|
||||||
"""Convert a TrustGraph value to an rdflib term."""
|
"""Convert a wire-format term to an rdflib term."""
|
||||||
if isinstance(val, Uri):
|
if term is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
t = term.get("t", "")
|
||||||
|
|
||||||
|
if t == "i": # IRI
|
||||||
|
iri = term.get("i", "")
|
||||||
# Skip malformed URLs with spaces
|
# Skip malformed URLs with spaces
|
||||||
if " " in val:
|
if " " in iri:
|
||||||
return None
|
return None
|
||||||
return rdflib.term.URIRef(val)
|
return rdflib.term.URIRef(iri)
|
||||||
elif isinstance(val, QuotedTriple):
|
elif t == "l": # Literal
|
||||||
# RDF-star quoted triple
|
value = term.get("v", "")
|
||||||
s_term = value_to_rdflib(val.s)
|
datatype = term.get("d")
|
||||||
p_term = value_to_rdflib(val.p)
|
language = term.get("l")
|
||||||
o_term = value_to_rdflib(val.o)
|
if language:
|
||||||
|
return rdflib.term.Literal(value, lang=language)
|
||||||
|
elif datatype:
|
||||||
|
return rdflib.term.Literal(value, datatype=rdflib.term.URIRef(datatype))
|
||||||
|
else:
|
||||||
|
return rdflib.term.Literal(value)
|
||||||
|
elif t == "r": # Quoted triple (RDF-star)
|
||||||
|
triple = term.get("r", {})
|
||||||
|
s_term = term_to_rdflib(triple.get("s"))
|
||||||
|
p_term = term_to_rdflib(triple.get("p"))
|
||||||
|
o_term = term_to_rdflib(triple.get("o"))
|
||||||
if s_term is None or p_term is None or o_term is None:
|
if s_term is None or p_term is None or o_term is None:
|
||||||
return None
|
return None
|
||||||
# rdflib 6.x+ supports Triple as a term type
|
|
||||||
try:
|
try:
|
||||||
return rdflib.term.Triple((s_term, p_term, o_term))
|
return rdflib.term.Triple((s_term, p_term, o_term))
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# Fallback for older rdflib versions - represent as string
|
# Fallback for older rdflib versions
|
||||||
return rdflib.term.Literal(f"<<{val.s} {val.p} {val.o}>>")
|
return rdflib.term.Literal(f"<<{s_term} {p_term} {o_term}>>")
|
||||||
else:
|
else:
|
||||||
return rdflib.term.Literal(str(val))
|
# Fallback
|
||||||
|
return rdflib.term.Literal(str(term))
|
||||||
|
|
||||||
|
|
||||||
def show_graph(url, flow_id, user, collection):
|
def show_graph(url, flow_id, user, collection, limit, batch_size, token=None):
|
||||||
|
|
||||||
api = Api(url).flow().id(flow_id)
|
socket = Api(url, token=token).socket()
|
||||||
|
flow = socket.flow(flow_id)
|
||||||
rows = api.triples_query(
|
|
||||||
s=None, p=None, o=None,
|
|
||||||
user=user, collection=collection,
|
|
||||||
limit=10_000)
|
|
||||||
|
|
||||||
g = rdflib.Graph()
|
g = rdflib.Graph()
|
||||||
|
|
||||||
for row in rows:
|
try:
|
||||||
|
for batch in flow.triples_query_stream(
|
||||||
|
s=None, p=None, o=None,
|
||||||
|
user=user, collection=collection,
|
||||||
|
limit=limit,
|
||||||
|
batch_size=batch_size,
|
||||||
|
):
|
||||||
|
for triple in batch:
|
||||||
|
sv = term_to_rdflib(triple.get("s"))
|
||||||
|
pv = term_to_rdflib(triple.get("p"))
|
||||||
|
ov = term_to_rdflib(triple.get("o"))
|
||||||
|
|
||||||
sv = rdflib.term.URIRef(row.s)
|
if sv is None or pv is None or ov is None:
|
||||||
pv = rdflib.term.URIRef(row.p)
|
continue
|
||||||
ov = value_to_rdflib(row.o)
|
|
||||||
|
|
||||||
if ov is None:
|
g.add((sv, pv, ov))
|
||||||
continue
|
finally:
|
||||||
|
socket.close()
|
||||||
g.add((sv, pv, ov))
|
|
||||||
|
|
||||||
g.serialize(destination="output.ttl", format="turtle")
|
|
||||||
|
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
|
|
||||||
g.serialize(destination=buf, format="turtle")
|
g.serialize(destination=buf, format="turtle")
|
||||||
|
|
||||||
sys.stdout.write(buf.getvalue().decode("utf-8"))
|
sys.stdout.write(buf.getvalue().decode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -103,6 +120,26 @@ def main():
|
||||||
help=f'Collection ID (default: {default_collection})'
|
help=f'Collection ID (default: {default_collection})'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-t', '--token',
|
||||||
|
default=default_token,
|
||||||
|
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-l', '--limit',
|
||||||
|
type=int,
|
||||||
|
default=10000,
|
||||||
|
help='Maximum number of triples to return (default: 10000)',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-b', '--batch-size',
|
||||||
|
type=int,
|
||||||
|
default=20,
|
||||||
|
help='Triples per streaming batch (default: 20)',
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -112,6 +149,9 @@ def main():
|
||||||
flow_id = args.flow_id,
|
flow_id = args.flow_id,
|
||||||
user = args.user,
|
user = args.user,
|
||||||
collection = args.collection,
|
collection = args.collection,
|
||||||
|
limit = args.limit,
|
||||||
|
batch_size = args.batch_size,
|
||||||
|
token = args.token,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""
|
"""
|
||||||
Connects to the graph query service and dumps all graph edges.
|
Connects to the graph query service and dumps all graph edges.
|
||||||
|
Uses streaming mode for lower time-to-first-result and reduced memory overhead.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
@ -11,17 +12,30 @@ default_user = 'trustgraph'
|
||||||
default_collection = 'default'
|
default_collection = 'default'
|
||||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||||
|
|
||||||
def show_graph(url, flow_id, user, collection, limit=10000, batch_size=20,
               token=None):
    """
    Stream the triples of a collection from the graph query service and
    print one "subject predicate object" line per triple.

    Args:
        url: Base URL handed to Api().
        flow_id: Identifier of the flow whose triples service is queried.
        user: User passed through to the triples query.
        collection: Collection passed through to the triples query.
        limit: Maximum number of triples to request. Defaults to 10000,
            the same value the command-line option uses.
        batch_size: Triples per streaming batch. Defaults to 20, the same
            value the command-line option uses.
        token: Optional authentication token handed to Api().
    """

    def term_text(term):
        # Terms arrive as dicts: prefer the "v" entry, then "i", and fall
        # back to the raw representation.  Anything that is not a dict is
        # printed as-is instead of raising AttributeError.
        if not isinstance(term, dict):
            return str(term)
        if "v" in term:
            return term["v"]
        if "i" in term:
            return term["i"]
        return str(term)

    socket = Api(url, token=token).socket()

    try:
        # Opened inside the try so the socket is closed even if obtaining
        # the flow handle fails.
        flow = socket.flow(flow_id)

        for batch in flow.triples_query_stream(
            user=user,
            collection=collection,
            s=None, p=None, o=None,
            limit=limit,
            batch_size=batch_size,
        ):
            for triple in batch:
                print(
                    term_text(triple.get("s", {})),
                    term_text(triple.get("p", {})),
                    term_text(triple.get("o", {})),
                )
    finally:
        socket.close()
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
||||||
|
|
@ -60,6 +74,20 @@ def main():
|
||||||
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
|
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-l', '--limit',
|
||||||
|
type=int,
|
||||||
|
default=10000,
|
||||||
|
help='Maximum number of triples to return (default: 10000)',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-b', '--batch-size',
|
||||||
|
type=int,
|
||||||
|
default=20,
|
||||||
|
help='Triples per streaming batch (default: 20)',
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -69,6 +97,8 @@ def main():
|
||||||
flow_id = args.flow_id,
|
flow_id = args.flow_id,
|
||||||
user = args.user,
|
user = args.user,
|
||||||
collection = args.collection,
|
collection = args.collection,
|
||||||
|
limit = args.limit,
|
||||||
|
batch_size = args.batch_size,
|
||||||
token = args.token,
|
token = args.token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ null. Output is a list of quads.
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from cassandra.query import SimpleStatement
|
||||||
|
|
||||||
from .... direct.cassandra_kg import (
|
from .... direct.cassandra_kg import (
|
||||||
EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH
|
EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH
|
||||||
|
|
@ -144,28 +145,30 @@ class Processor(TriplesQueryService):
|
||||||
self.cassandra_password = password
|
self.cassandra_password = password
|
||||||
self.table = None
|
self.table = None
|
||||||
|
|
||||||
|
def ensure_connection(self, user):
    """
    Make sure self.tg is a knowledge-graph handle bound to the keyspace
    named after `user`.

    self.table remembers the user of the current handle; a new handle is
    only built when a different user is requested.
    """
    # Already connected for this user: nothing to do
    if self.table == user:
        return

    kg_args = {"hosts": self.cassandra_host, "keyspace": user}

    # Credentials are only passed when both parts are configured
    if self.cassandra_username and self.cassandra_password:
        kg_args["username"] = self.cassandra_username
        kg_args["password"] = self.cassandra_password

    self.tg = EntityCentricKnowledgeGraph(**kg_args)
    self.table = user
|
||||||
|
|
||||||
async def query_triples(self, query):
|
async def query_triples(self, query):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
user = query.user
|
self.ensure_connection(query.user)
|
||||||
|
|
||||||
if user != self.table:
|
|
||||||
# Use factory function to select implementation
|
|
||||||
KGClass = EntityCentricKnowledgeGraph
|
|
||||||
|
|
||||||
if self.cassandra_username and self.cassandra_password:
|
|
||||||
self.tg = KGClass(
|
|
||||||
hosts=self.cassandra_host,
|
|
||||||
keyspace=query.user,
|
|
||||||
username=self.cassandra_username, password=self.cassandra_password
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.tg = KGClass(
|
|
||||||
hosts=self.cassandra_host,
|
|
||||||
keyspace=query.user,
|
|
||||||
)
|
|
||||||
self.table = user
|
|
||||||
|
|
||||||
# Extract values from query
|
# Extract values from query
|
||||||
s_val = get_term_value(query.s)
|
s_val = get_term_value(query.s)
|
||||||
|
|
@ -291,6 +294,93 @@ class Processor(TriplesQueryService):
|
||||||
logger.error(f"Exception querying triples: {e}", exc_info=True)
|
logger.error(f"Exception querying triples: {e}", exc_info=True)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
async def query_triples_stream(self, query):
    """
    Streaming query - async generator yielding (batch, is_final) tuples.

    When subject, predicate and object are all unset, rows are read from
    the collection table through a statement whose fetch_size is the
    batch size, so results are consumed page by page.  Any other pattern
    is delegated to _fallback_stream(), which runs the regular query and
    slices its result.

    The last tuple always has is_final=True; its batch may be empty
    (empty result set, or row count an exact multiple of the batch size).

    Args:
        query: Request object; reads user, collection, s, p, o, limit and
            batch_size.  A missing or non-positive batch_size / limit
            falls back to 20 / 10000.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    # Local import keeps the block self-contained
    from itertools import islice

    try:
        self.ensure_connection(query.user)

        # `or 0` keeps a None value from raising TypeError in the comparison
        batch_size = query.batch_size if (query.batch_size or 0) > 0 else 20
        limit = query.limit if (query.limit or 0) > 0 else 10000

        # Extract query pattern
        s_val = get_term_value(query.s)
        p_val = get_term_value(query.p)
        o_val = get_term_value(query.o)

        # NOTE(review): query.g is not applied on this path, so rows of
        # every graph in the collection are returned - confirm whether a
        # graph filter is expected here.

        if s_val is not None or p_val is not None or o_val is not None:
            # For specific patterns, fall back to non-streaming
            # (these typically return small result sets anyway)
            async for batch, is_final in self._fallback_stream(query, batch_size):
                yield batch, is_final
            return

        # Get all - use collection table with paging
        cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s"

        # fetch_size drives the paging; never ask for a page larger than
        # the overall limit.
        statement = SimpleStatement(cql, fetch_size=min(batch_size, limit))

        # NOTE(review): this call and the row iteration below look
        # synchronous, which would block the event loop while a page is
        # fetched - confirm against the session implementation.
        result_set = self.tg.session.execute(statement, [query.collection])

        batch = []

        # islice stops after `limit` rows without pulling one more row,
        # which could otherwise trigger a needless extra page fetch.
        for row in islice(result_set, limit):

            g = getattr(row, 'd', DEFAULT_GRAPH)

            batch.append(Triple(
                s=create_term(row.s),
                p=create_term(row.p),
                o=create_term(
                    row.o,
                    otype=getattr(row, 'otype', None),
                    dtype=getattr(row, 'dtype', None),
                    lang=getattr(row, 'lang', None),
                ),
                g=g if g != DEFAULT_GRAPH else None
            ))

            # Yield batch when full (never mark as final mid-stream)
            if len(batch) >= batch_size:
                yield batch, False
                batch = []

        # Always yield final batch to signal completion
        # This handles: remaining rows, empty result set, or exact batch boundary
        yield batch, True

    except Exception as e:
        logger.error(f"Exception in streaming query: {e}", exc_info=True)
        raise
|
||||||
|
|
||||||
|
async def _fallback_stream(self, query, batch_size):
    """
    Fallback to non-streaming query with post-hoc batching.

    Runs query_triples() once and yields its result as (batch, is_final)
    tuples of at most batch_size triples each.  Exactly one tuple carries
    is_final=True, and it is always the last one; an empty result yields
    a single ([], True).

    A missing or non-positive batch_size delivers everything in a single
    batch, so the final marker is always emitted.
    """
    triples = await self.query_triples(query)

    # Nothing found: still signal completion
    if not triples:
        yield [], True
        return

    total = len(triples)

    # range() rejects a zero step and is empty for a negative one, which
    # would leave the consumer without a final marker.
    if batch_size is None or batch_size <= 0:
        batch_size = total

    for start in range(0, total, batch_size):
        end = start + batch_size
        yield triples[start:end], end >= total
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,11 +4,25 @@ import logging
|
||||||
import time
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from ... schema import IRI, LITERAL
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||||
|
|
||||||
|
|
||||||
|
def term_to_string(term):
    """
    Extract string value from a Term object.

    Returns None for None, the string itself for a plain string, the IRI
    for IRI terms and the value for literal terms.  Any other object
    yields its first non-empty `iri` or `value` attribute, or str(term)
    when neither is available.
    """
    if term is None:
        return None

    # Entities that are already plain strings pass through unchanged
    if isinstance(term, str):
        return term

    kind = getattr(term, "type", None)
    if kind == IRI:
        return term.iri
    if kind == LITERAL:
        return term.value

    # Fallback: tolerate objects that lack either attribute
    return (
        getattr(term, "iri", None)
        or getattr(term, "value", None)
        or str(term)
    )
|
||||||
|
|
||||||
class LRUCacheWithTTL:
|
class LRUCacheWithTTL:
|
||||||
"""LRU cache with TTL for label caching
|
"""LRU cache with TTL for label caching
|
||||||
|
|
||||||
|
|
@ -93,7 +107,7 @@ class Query:
|
||||||
)
|
)
|
||||||
|
|
||||||
entities = [
|
entities = [
|
||||||
str(e.entity)
|
term_to_string(e.entity)
|
||||||
for e in entity_matches
|
for e in entity_matches
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -129,26 +143,29 @@ class Query:
|
||||||
return label
|
return label
|
||||||
|
|
||||||
async def execute_batch_triple_queries(self, entities, limit_per_entity):
|
async def execute_batch_triple_queries(self, entities, limit_per_entity):
|
||||||
"""Execute triple queries for multiple entities concurrently"""
|
"""Execute triple queries for multiple entities concurrently using streaming"""
|
||||||
tasks = []
|
tasks = []
|
||||||
|
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
# Create concurrent tasks for all 3 query types per entity
|
# Create concurrent streaming tasks for all 3 query types per entity
|
||||||
tasks.extend([
|
tasks.extend([
|
||||||
self.rag.triples_client.query(
|
self.rag.triples_client.query_stream(
|
||||||
s=entity, p=None, o=None,
|
s=entity, p=None, o=None,
|
||||||
limit=limit_per_entity,
|
limit=limit_per_entity,
|
||||||
user=self.user, collection=self.collection
|
user=self.user, collection=self.collection,
|
||||||
|
batch_size=20,
|
||||||
),
|
),
|
||||||
self.rag.triples_client.query(
|
self.rag.triples_client.query_stream(
|
||||||
s=None, p=entity, o=None,
|
s=None, p=entity, o=None,
|
||||||
limit=limit_per_entity,
|
limit=limit_per_entity,
|
||||||
user=self.user, collection=self.collection
|
user=self.user, collection=self.collection,
|
||||||
|
batch_size=20,
|
||||||
),
|
),
|
||||||
self.rag.triples_client.query(
|
self.rag.triples_client.query_stream(
|
||||||
s=None, p=None, o=entity,
|
s=None, p=None, o=entity,
|
||||||
limit=limit_per_entity,
|
limit=limit_per_entity,
|
||||||
user=self.user, collection=self.collection
|
user=self.user, collection=self.collection,
|
||||||
|
batch_size=20,
|
||||||
)
|
)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
@ -158,7 +175,7 @@ class Query:
|
||||||
# Combine all results
|
# Combine all results
|
||||||
all_triples = []
|
all_triples = []
|
||||||
for result in results:
|
for result in results:
|
||||||
if not isinstance(result, Exception):
|
if not isinstance(result, Exception) and result is not None:
|
||||||
all_triples.extend(result)
|
all_triples.extend(result)
|
||||||
|
|
||||||
return all_triples
|
return all_triples
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue