mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Feature/streaming triples (#676)
* Steaming triples * Also GraphRAG service uses this * Updated tests
This commit is contained in:
parent
3c3e11bef5
commit
d2d71f859d
11 changed files with 542 additions and 116 deletions
|
|
@ -48,7 +48,7 @@ class TestGraphRagIntegration:
|
|||
client = AsyncMock()
|
||||
|
||||
# Mock different queries return different triples
|
||||
async def query_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None):
|
||||
async def query_stream_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None, batch_size=20):
|
||||
# Mock label queries
|
||||
if p == "http://www.w3.org/2000/01/rdf-schema#label":
|
||||
if s == "http://trustgraph.ai/e/machine-learning":
|
||||
|
|
@ -76,7 +76,9 @@ class TestGraphRagIntegration:
|
|||
|
||||
return []
|
||||
|
||||
client.query.side_effect = query_side_effect
|
||||
client.query_stream.side_effect = query_stream_side_effect
|
||||
# Also mock query for label lookups (maybe_label uses query, not query_stream)
|
||||
client.query.side_effect = query_stream_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -137,7 +139,7 @@ class TestGraphRagIntegration:
|
|||
assert call_args.kwargs['collection'] == collection
|
||||
|
||||
# 3. Should query triples to build knowledge subgraph
|
||||
assert mock_triples_client.query.call_count > 0
|
||||
assert mock_triples_client.query_stream.call_count > 0
|
||||
|
||||
# 4. Should call prompt with knowledge graph
|
||||
mock_prompt_client.kg_prompt.assert_called_once()
|
||||
|
|
@ -202,7 +204,7 @@ class TestGraphRagIntegration:
|
|||
"""Test GraphRAG handles empty knowledge graph gracefully"""
|
||||
# Arrange
|
||||
mock_graph_embeddings_client.query.return_value = [] # No entities found
|
||||
mock_triples_client.query.return_value = [] # No triples found
|
||||
mock_triples_client.query_stream.return_value = [] # No triples found
|
||||
|
||||
# Act
|
||||
result = await graph_rag.query(
|
||||
|
|
@ -231,7 +233,7 @@ class TestGraphRagIntegration:
|
|||
collection="test_collection"
|
||||
)
|
||||
|
||||
first_call_count = mock_triples_client.query.call_count
|
||||
first_call_count = mock_triples_client.query_stream.call_count
|
||||
mock_triples_client.reset_mock()
|
||||
|
||||
# Second identical query
|
||||
|
|
@ -241,7 +243,7 @@ class TestGraphRagIntegration:
|
|||
collection="test_collection"
|
||||
)
|
||||
|
||||
second_call_count = mock_triples_client.query.call_count
|
||||
second_call_count = mock_triples_client.query_stream.call_count
|
||||
|
||||
# Assert - Second query should make fewer triple queries due to caching
|
||||
# Note: This is a weak assertion because caching behavior depends on
|
||||
|
|
|
|||
|
|
@ -193,15 +193,17 @@ class TestQuery:
|
|||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||
|
||||
# Mock EntityMatch objects with entity that has string representation
|
||||
# Mock EntityMatch objects with entity as Term-like object
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.__str__ = MagicMock(return_value="entity1")
|
||||
mock_entity1.type = "i" # IRI type
|
||||
mock_entity1.iri = "entity1"
|
||||
mock_match1 = MagicMock()
|
||||
mock_match1.entity = mock_entity1
|
||||
mock_match1.score = 0.95
|
||||
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.__str__ = MagicMock(return_value="entity2")
|
||||
mock_entity2.type = "i" # IRI type
|
||||
mock_entity2.iri = "entity2"
|
||||
mock_match2 = MagicMock()
|
||||
mock_match2.entity = mock_entity2
|
||||
mock_match2.score = 0.85
|
||||
|
|
@ -363,10 +365,10 @@ class TestQuery:
|
|||
mock_triple3 = MagicMock()
|
||||
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
|
||||
|
||||
# Setup query responses for s=ent, p=ent, o=ent patterns
|
||||
mock_triples_client.query.side_effect = [
|
||||
# Setup query_stream responses for s=ent, p=ent, o=ent patterns
|
||||
mock_triples_client.query_stream.side_effect = [
|
||||
[mock_triple1], # s=ent, p=None, o=None
|
||||
[mock_triple2], # s=None, p=ent, o=None
|
||||
[mock_triple2], # s=None, p=ent, o=None
|
||||
[mock_triple3], # s=None, p=None, o=ent
|
||||
]
|
||||
|
||||
|
|
@ -384,20 +386,20 @@ class TestQuery:
|
|||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
# Verify all three query patterns were called
|
||||
assert mock_triples_client.query.call_count == 3
|
||||
|
||||
# Verify query calls
|
||||
mock_triples_client.query.assert_any_call(
|
||||
assert mock_triples_client.query_stream.call_count == 3
|
||||
|
||||
# Verify query_stream calls
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s="entity1", p=None, o=None, limit=10,
|
||||
user="test_user", collection="test_collection"
|
||||
user="test_user", collection="test_collection", batch_size=20
|
||||
)
|
||||
mock_triples_client.query.assert_any_call(
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p="entity1", o=None, limit=10,
|
||||
user="test_user", collection="test_collection"
|
||||
user="test_user", collection="test_collection", batch_size=20
|
||||
)
|
||||
mock_triples_client.query.assert_any_call(
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p=None, o="entity1", limit=10,
|
||||
user="test_user", collection="test_collection"
|
||||
user="test_user", collection="test_collection", batch_size=20
|
||||
)
|
||||
|
||||
# Verify subgraph contains discovered triples
|
||||
|
|
@ -427,9 +429,9 @@ class TestQuery:
|
|||
# Call follow_edges with path_length=0
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=0)
|
||||
|
||||
|
||||
# Verify no queries were made
|
||||
mock_triples_client.query.assert_not_called()
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
|
||||
# Verify subgraph remains empty
|
||||
assert subgraph == set()
|
||||
|
|
@ -456,9 +458,9 @@ class TestQuery:
|
|||
|
||||
# Call follow_edges
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
|
||||
# Verify no queries were made due to size limit
|
||||
mock_triples_client.query.assert_not_called()
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
|
||||
# Verify subgraph unchanged
|
||||
assert len(subgraph) == 3
|
||||
|
|
|
|||
|
|
@ -91,9 +91,18 @@ class SocketClient:
|
|||
service: str,
|
||||
flow: Optional[str],
|
||||
request: Dict[str, Any],
|
||||
streaming: bool = False
|
||||
) -> Union[Dict[str, Any], Iterator[StreamingChunk]]:
|
||||
"""Synchronous wrapper around async WebSocket communication"""
|
||||
streaming: bool = False,
|
||||
streaming_raw: bool = False
|
||||
) -> Union[Dict[str, Any], Iterator[StreamingChunk], Iterator[Dict[str, Any]]]:
|
||||
"""Synchronous wrapper around async WebSocket communication.
|
||||
|
||||
Args:
|
||||
service: Service name
|
||||
flow: Flow ID (optional)
|
||||
request: Request payload
|
||||
streaming: Use parsed streaming (for agent/RAG chunk types)
|
||||
streaming_raw: Use raw streaming (for data batches like triples)
|
||||
"""
|
||||
# Create event loop if needed
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
|
@ -105,12 +114,14 @@ class SocketClient:
|
|||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
if streaming:
|
||||
# For streaming, we need to return an iterator
|
||||
# Create a generator that runs async code
|
||||
if streaming_raw:
|
||||
# Raw streaming for data batches (triples, rows, etc.)
|
||||
return self._streaming_generator_raw(service, flow, request, loop)
|
||||
elif streaming:
|
||||
# Parsed streaming for agent/RAG chunk types
|
||||
return self._streaming_generator(service, flow, request, loop)
|
||||
else:
|
||||
# For non-streaming, just run the async code and return result
|
||||
# Non-streaming single response
|
||||
return loop.run_until_complete(self._send_request_async(service, flow, request))
|
||||
|
||||
def _streaming_generator(
|
||||
|
|
@ -120,7 +131,7 @@ class SocketClient:
|
|||
request: Dict[str, Any],
|
||||
loop: asyncio.AbstractEventLoop
|
||||
) -> Iterator[StreamingChunk]:
|
||||
"""Generator that yields streaming chunks"""
|
||||
"""Generator that yields streaming chunks (for agent/RAG responses)"""
|
||||
async_gen = self._send_request_async_streaming(service, flow, request)
|
||||
|
||||
try:
|
||||
|
|
@ -137,6 +148,74 @@ class SocketClient:
|
|||
except:
|
||||
pass
|
||||
|
||||
def _streaming_generator_raw(
|
||||
self,
|
||||
service: str,
|
||||
flow: Optional[str],
|
||||
request: Dict[str, Any],
|
||||
loop: asyncio.AbstractEventLoop
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
"""Generator that yields raw response dicts (for data streaming like triples)"""
|
||||
async_gen = self._send_request_async_streaming_raw(service, flow, request)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
data = loop.run_until_complete(async_gen.__anext__())
|
||||
yield data
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
try:
|
||||
loop.run_until_complete(async_gen.aclose())
|
||||
except:
|
||||
pass
|
||||
|
||||
async def _send_request_async_streaming_raw(
|
||||
self,
|
||||
service: str,
|
||||
flow: Optional[str],
|
||||
request: Dict[str, Any]
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
"""Async streaming that yields raw response dicts without parsing.
|
||||
|
||||
Used for data streaming (triples, rows, etc.) where responses are
|
||||
just batches of data, not agent/RAG chunk types.
|
||||
"""
|
||||
with self._lock:
|
||||
self._request_counter += 1
|
||||
request_id = f"req-{self._request_counter}"
|
||||
|
||||
ws_url = f"{self.url}/api/v1/socket"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
|
||||
message = {
|
||||
"id": request_id,
|
||||
"service": service,
|
||||
"request": request
|
||||
}
|
||||
if flow:
|
||||
message["flow"] = flow
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
await websocket.send(json.dumps(message))
|
||||
|
||||
async for raw_message in websocket:
|
||||
response = json.loads(raw_message)
|
||||
|
||||
if response.get("id") != request_id:
|
||||
continue
|
||||
|
||||
if "error" in response:
|
||||
raise_from_error_dict(response["error"])
|
||||
|
||||
if "response" in response:
|
||||
yield response["response"]
|
||||
|
||||
if response.get("complete"):
|
||||
break
|
||||
|
||||
async def _send_request_async(
|
||||
self,
|
||||
service: str,
|
||||
|
|
@ -790,6 +869,74 @@ class SocketFlowInstance:
|
|||
|
||||
return self.client._send_request_sync("triples", self.flow_id, request, False)
|
||||
|
||||
def triples_query_stream(
|
||||
self,
|
||||
s: Optional[str] = None,
|
||||
p: Optional[str] = None,
|
||||
o: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
collection: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
batch_size: int = 20,
|
||||
**kwargs: Any
|
||||
) -> Iterator[List[Dict[str, Any]]]:
|
||||
"""
|
||||
Query knowledge graph triples with streaming batches.
|
||||
|
||||
Yields batches of triples as they arrive, reducing time-to-first-result
|
||||
and memory overhead for large result sets.
|
||||
|
||||
Args:
|
||||
s: Subject URI (optional, use None for wildcard)
|
||||
p: Predicate URI (optional, use None for wildcard)
|
||||
o: Object URI or Literal (optional, use None for wildcard)
|
||||
user: User/keyspace identifier (optional)
|
||||
collection: Collection identifier (optional)
|
||||
limit: Maximum results to return (default: 100)
|
||||
batch_size: Triples per batch (default: 20)
|
||||
**kwargs: Additional parameters passed to the service
|
||||
|
||||
Yields:
|
||||
List[Dict]: Batches of triples in wire format
|
||||
|
||||
Example:
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
for batch in flow.triples_query_stream(
|
||||
user="trustgraph",
|
||||
collection="default"
|
||||
):
|
||||
for triple in batch:
|
||||
print(triple["s"], triple["p"], triple["o"])
|
||||
```
|
||||
"""
|
||||
request = {
|
||||
"limit": limit,
|
||||
"streaming": True,
|
||||
"batch-size": batch_size,
|
||||
}
|
||||
if s is not None:
|
||||
request["s"] = str(s)
|
||||
if p is not None:
|
||||
request["p"] = str(p)
|
||||
if o is not None:
|
||||
request["o"] = str(o)
|
||||
if user is not None:
|
||||
request["user"] = user
|
||||
if collection is not None:
|
||||
request["collection"] = collection
|
||||
request.update(kwargs)
|
||||
|
||||
# Use raw streaming - yields response dicts directly without parsing
|
||||
for response in self.client._send_request_sync("triples", self.flow_id, request, streaming_raw=True):
|
||||
# Response is {"response": [...triples...]} from translator
|
||||
if isinstance(response, dict) and "response" in response:
|
||||
yield response["response"]
|
||||
else:
|
||||
yield response
|
||||
|
||||
def rows_query(
|
||||
self,
|
||||
query: str,
|
||||
|
|
|
|||
|
|
@ -22,11 +22,19 @@ def to_value(x):
|
|||
|
||||
|
||||
def from_value(x):
|
||||
"""Convert Uri or Literal to schema Term."""
|
||||
"""Convert Uri, Literal, or string to schema Term."""
|
||||
if x is None:
|
||||
return None
|
||||
if isinstance(x, Uri):
|
||||
return Term(type=IRI, iri=str(x))
|
||||
elif isinstance(x, Literal):
|
||||
return Term(type=LITERAL, value=str(x))
|
||||
elif isinstance(x, str):
|
||||
# Detect IRIs by common prefixes
|
||||
if x.startswith("http://") or x.startswith("https://") or x.startswith("urn:"):
|
||||
return Term(type=IRI, iri=x)
|
||||
else:
|
||||
return Term(type=LITERAL, value=x)
|
||||
else:
|
||||
return Term(type=LITERAL, value=str(x))
|
||||
|
||||
|
|
@ -57,6 +65,61 @@ class TriplesClient(RequestResponse):
|
|||
|
||||
return triples
|
||||
|
||||
async def query_stream(self, s=None, p=None, o=None, limit=20,
|
||||
user="trustgraph", collection="default",
|
||||
batch_size=20, timeout=30,
|
||||
batch_callback=None):
|
||||
"""
|
||||
Streaming triple query - calls callback for each batch as it arrives.
|
||||
|
||||
Args:
|
||||
s, p, o: Triple pattern (None for wildcard)
|
||||
limit: Maximum total triples to return
|
||||
user: User/keyspace
|
||||
collection: Collection name
|
||||
batch_size: Triples per batch
|
||||
timeout: Request timeout in seconds
|
||||
batch_callback: Async callback(batch, is_final) called for each batch
|
||||
|
||||
Returns:
|
||||
List[Triple]: All triples (flattened) if no callback provided
|
||||
"""
|
||||
all_triples = []
|
||||
|
||||
async def recipient(resp):
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
batch = [
|
||||
Triple(to_value(v.s), to_value(v.p), to_value(v.o))
|
||||
for v in resp.triples
|
||||
]
|
||||
|
||||
if batch_callback:
|
||||
await batch_callback(batch, resp.is_final)
|
||||
else:
|
||||
all_triples.extend(batch)
|
||||
|
||||
return resp.is_final
|
||||
|
||||
await self.request(
|
||||
TriplesQueryRequest(
|
||||
s=from_value(s),
|
||||
p=from_value(p),
|
||||
o=from_value(o),
|
||||
limit=limit,
|
||||
user=user,
|
||||
collection=collection,
|
||||
streaming=True,
|
||||
batch_size=batch_size,
|
||||
),
|
||||
timeout=timeout,
|
||||
recipient=recipient,
|
||||
)
|
||||
|
||||
if not batch_callback:
|
||||
return all_triples
|
||||
|
||||
class TriplesClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
self, request_name, response_name,
|
||||
|
|
|
|||
|
|
@ -52,13 +52,23 @@ class TriplesQueryService(FlowProcessor):
|
|||
|
||||
logger.debug(f"Handling triples query request {id}...")
|
||||
|
||||
triples = await self.query_triples(request)
|
||||
|
||||
logger.debug("Sending triples query response...")
|
||||
r = TriplesQueryResponse(triples=triples, error=None)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
||||
logger.debug("Triples query request completed")
|
||||
if request.streaming:
|
||||
# Streaming mode: send batches
|
||||
async for batch, is_final in self.query_triples_stream(request):
|
||||
r = TriplesQueryResponse(
|
||||
triples=batch,
|
||||
error=None,
|
||||
is_final=is_final,
|
||||
)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
logger.debug("Triples query streaming completed")
|
||||
else:
|
||||
# Non-streaming mode: single response
|
||||
triples = await self.query_triples(request)
|
||||
logger.debug("Sending triples query response...")
|
||||
r = TriplesQueryResponse(triples=triples, error=None)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
logger.debug("Triples query request completed")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
@ -76,6 +86,24 @@ class TriplesQueryService(FlowProcessor):
|
|||
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
||||
async def query_triples_stream(self, request):
|
||||
"""
|
||||
Streaming query - yields (batch, is_final) tuples.
|
||||
Default implementation batches results from query_triples.
|
||||
Override for true streaming from backend.
|
||||
"""
|
||||
triples = await self.query_triples(request)
|
||||
batch_size = request.batch_size if request.batch_size > 0 else 20
|
||||
|
||||
for i in range(0, len(triples), batch_size):
|
||||
batch = triples[i:i + batch_size]
|
||||
is_final = (i + batch_size >= len(triples))
|
||||
yield batch, is_final
|
||||
|
||||
# Handle empty result
|
||||
if len(triples) == 0:
|
||||
yield [], True
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
|
|
|
|||
|
|
@ -23,14 +23,18 @@ class TriplesQueryRequestTranslator(MessageTranslator):
|
|||
g=g,
|
||||
limit=int(data.get("limit", 10000)),
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default")
|
||||
collection=data.get("collection", "default"),
|
||||
streaming=data.get("streaming", False),
|
||||
batch_size=int(data.get("batch-size", 20)),
|
||||
)
|
||||
|
||||
def from_pulsar(self, obj: TriplesQueryRequest) -> Dict[str, Any]:
|
||||
result = {
|
||||
"limit": obj.limit,
|
||||
"user": obj.user,
|
||||
"collection": obj.collection
|
||||
"collection": obj.collection,
|
||||
"streaming": obj.streaming,
|
||||
"batch-size": obj.batch_size,
|
||||
}
|
||||
|
||||
if obj.s:
|
||||
|
|
@ -61,4 +65,4 @@ class TriplesQueryResponseTranslator(MessageTranslator):
|
|||
|
||||
def from_response_with_completion(self, obj: TriplesQueryResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), True
|
||||
return self.from_pulsar(obj), obj.is_final
|
||||
|
|
@ -38,11 +38,14 @@ class TriplesQueryRequest:
|
|||
o: Term | None = None
|
||||
g: str | None = None # Graph IRI. None=default graph, "*"=all graphs
|
||||
limit: int = 0
|
||||
streaming: bool = False # Enable streaming mode (multiple batched responses)
|
||||
batch_size: int = 20 # Triples per batch in streaming mode
|
||||
|
||||
@dataclass
|
||||
class TriplesQueryResponse:
|
||||
error: Error | None = None
|
||||
triples: list[Triple] = field(default_factory=list)
|
||||
is_final: bool = True # False for intermediate batches in streaming mode
|
||||
|
||||
############################################################################
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Connects to the graph query service and dumps all graph edges in Turtle
|
||||
format with RDF-star support for quoted triples.
|
||||
Uses streaming mode for lower time-to-first-processing.
|
||||
"""
|
||||
|
||||
import rdflib
|
||||
|
|
@ -9,66 +10,82 @@ import sys
|
|||
import argparse
|
||||
import os
|
||||
|
||||
from trustgraph.api import Api, Uri
|
||||
from trustgraph.knowledge import QuotedTriple
|
||||
from trustgraph.api import Api
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_user = 'trustgraph'
|
||||
default_collection = 'default'
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
|
||||
|
||||
def value_to_rdflib(val):
|
||||
"""Convert a TrustGraph value to an rdflib term."""
|
||||
if isinstance(val, Uri):
|
||||
def term_to_rdflib(term):
|
||||
"""Convert a wire-format term to an rdflib term."""
|
||||
if term is None:
|
||||
return None
|
||||
|
||||
t = term.get("t", "")
|
||||
|
||||
if t == "i": # IRI
|
||||
iri = term.get("i", "")
|
||||
# Skip malformed URLs with spaces
|
||||
if " " in val:
|
||||
if " " in iri:
|
||||
return None
|
||||
return rdflib.term.URIRef(val)
|
||||
elif isinstance(val, QuotedTriple):
|
||||
# RDF-star quoted triple
|
||||
s_term = value_to_rdflib(val.s)
|
||||
p_term = value_to_rdflib(val.p)
|
||||
o_term = value_to_rdflib(val.o)
|
||||
return rdflib.term.URIRef(iri)
|
||||
elif t == "l": # Literal
|
||||
value = term.get("v", "")
|
||||
datatype = term.get("d")
|
||||
language = term.get("l")
|
||||
if language:
|
||||
return rdflib.term.Literal(value, lang=language)
|
||||
elif datatype:
|
||||
return rdflib.term.Literal(value, datatype=rdflib.term.URIRef(datatype))
|
||||
else:
|
||||
return rdflib.term.Literal(value)
|
||||
elif t == "r": # Quoted triple (RDF-star)
|
||||
triple = term.get("r", {})
|
||||
s_term = term_to_rdflib(triple.get("s"))
|
||||
p_term = term_to_rdflib(triple.get("p"))
|
||||
o_term = term_to_rdflib(triple.get("o"))
|
||||
if s_term is None or p_term is None or o_term is None:
|
||||
return None
|
||||
# rdflib 6.x+ supports Triple as a term type
|
||||
try:
|
||||
return rdflib.term.Triple((s_term, p_term, o_term))
|
||||
except AttributeError:
|
||||
# Fallback for older rdflib versions - represent as string
|
||||
return rdflib.term.Literal(f"<<{val.s} {val.p} {val.o}>>")
|
||||
# Fallback for older rdflib versions
|
||||
return rdflib.term.Literal(f"<<{s_term} {p_term} {o_term}>>")
|
||||
else:
|
||||
return rdflib.term.Literal(str(val))
|
||||
# Fallback
|
||||
return rdflib.term.Literal(str(term))
|
||||
|
||||
|
||||
def show_graph(url, flow_id, user, collection):
|
||||
def show_graph(url, flow_id, user, collection, limit, batch_size, token=None):
|
||||
|
||||
api = Api(url).flow().id(flow_id)
|
||||
|
||||
rows = api.triples_query(
|
||||
s=None, p=None, o=None,
|
||||
user=user, collection=collection,
|
||||
limit=10_000)
|
||||
socket = Api(url, token=token).socket()
|
||||
flow = socket.flow(flow_id)
|
||||
|
||||
g = rdflib.Graph()
|
||||
|
||||
for row in rows:
|
||||
try:
|
||||
for batch in flow.triples_query_stream(
|
||||
s=None, p=None, o=None,
|
||||
user=user, collection=collection,
|
||||
limit=limit,
|
||||
batch_size=batch_size,
|
||||
):
|
||||
for triple in batch:
|
||||
sv = term_to_rdflib(triple.get("s"))
|
||||
pv = term_to_rdflib(triple.get("p"))
|
||||
ov = term_to_rdflib(triple.get("o"))
|
||||
|
||||
sv = rdflib.term.URIRef(row.s)
|
||||
pv = rdflib.term.URIRef(row.p)
|
||||
ov = value_to_rdflib(row.o)
|
||||
if sv is None or pv is None or ov is None:
|
||||
continue
|
||||
|
||||
if ov is None:
|
||||
continue
|
||||
|
||||
g.add((sv, pv, ov))
|
||||
|
||||
g.serialize(destination="output.ttl", format="turtle")
|
||||
g.add((sv, pv, ov))
|
||||
finally:
|
||||
socket.close()
|
||||
|
||||
buf = io.BytesIO()
|
||||
|
||||
g.serialize(destination=buf, format="turtle")
|
||||
|
||||
sys.stdout.write(buf.getvalue().decode("utf-8"))
|
||||
|
||||
|
||||
|
|
@ -103,6 +120,26 @@ def main():
|
|||
help=f'Collection ID (default: {default_collection})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--token',
|
||||
default=default_token,
|
||||
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-l', '--limit',
|
||||
type=int,
|
||||
default=10000,
|
||||
help='Maximum number of triples to return (default: 10000)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-b', '--batch-size',
|
||||
type=int,
|
||||
default=20,
|
||||
help='Triples per streaming batch (default: 20)',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -112,6 +149,9 @@ def main():
|
|||
flow_id = args.flow_id,
|
||||
user = args.user,
|
||||
collection = args.collection,
|
||||
limit = args.limit,
|
||||
batch_size = args.batch_size,
|
||||
token = args.token,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""
|
||||
Connects to the graph query service and dumps all graph edges.
|
||||
Uses streaming mode for lower time-to-first-result and reduced memory overhead.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
|
@ -11,17 +12,30 @@ default_user = 'trustgraph'
|
|||
default_collection = 'default'
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
|
||||
def show_graph(url, flow_id, user, collection, token=None):
|
||||
def show_graph(url, flow_id, user, collection, limit, batch_size, token=None):
|
||||
|
||||
api = Api(url, token=token).flow().id(flow_id)
|
||||
socket = Api(url, token=token).socket()
|
||||
flow = socket.flow(flow_id)
|
||||
|
||||
rows = api.triples_query(
|
||||
user=user, collection=collection,
|
||||
s=None, p=None, o=None, limit=10_000,
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
print(row.s, row.p, row.o)
|
||||
try:
|
||||
for batch in flow.triples_query_stream(
|
||||
user=user,
|
||||
collection=collection,
|
||||
s=None, p=None, o=None,
|
||||
limit=limit,
|
||||
batch_size=batch_size,
|
||||
):
|
||||
for triple in batch:
|
||||
s = triple.get("s", {})
|
||||
p = triple.get("p", {})
|
||||
o = triple.get("o", {})
|
||||
# Format terms for display
|
||||
s_str = s.get("v", s.get("i", str(s)))
|
||||
p_str = p.get("v", p.get("i", str(p)))
|
||||
o_str = o.get("v", o.get("i", str(o)))
|
||||
print(s_str, p_str, o_str)
|
||||
finally:
|
||||
socket.close()
|
||||
|
||||
def main():
|
||||
|
||||
|
|
@ -60,6 +74,20 @@ def main():
|
|||
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-l', '--limit',
|
||||
type=int,
|
||||
default=10000,
|
||||
help='Maximum number of triples to return (default: 10000)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-b', '--batch-size',
|
||||
type=int,
|
||||
default=20,
|
||||
help='Triples per streaming batch (default: 20)',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -69,6 +97,8 @@ def main():
|
|||
flow_id = args.flow_id,
|
||||
user = args.user,
|
||||
collection = args.collection,
|
||||
limit = args.limit,
|
||||
batch_size = args.batch_size,
|
||||
token = args.token,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ null. Output is a list of quads.
|
|||
import logging
|
||||
|
||||
import json
|
||||
from cassandra.query import SimpleStatement
|
||||
|
||||
from .... direct.cassandra_kg import (
|
||||
EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH
|
||||
|
|
@ -144,28 +145,30 @@ class Processor(TriplesQueryService):
|
|||
self.cassandra_password = password
|
||||
self.table = None
|
||||
|
||||
def ensure_connection(self, user):
|
||||
"""Ensure we have a connection to the correct keyspace."""
|
||||
if user != self.table:
|
||||
KGClass = EntityCentricKnowledgeGraph
|
||||
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=user,
|
||||
username=self.cassandra_username,
|
||||
password=self.cassandra_password
|
||||
)
|
||||
else:
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=user,
|
||||
)
|
||||
self.table = user
|
||||
|
||||
async def query_triples(self, query):
|
||||
|
||||
try:
|
||||
|
||||
user = query.user
|
||||
|
||||
if user != self.table:
|
||||
# Use factory function to select implementation
|
||||
KGClass = EntityCentricKnowledgeGraph
|
||||
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=query.user,
|
||||
username=self.cassandra_username, password=self.cassandra_password
|
||||
)
|
||||
else:
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=query.user,
|
||||
)
|
||||
self.table = user
|
||||
self.ensure_connection(query.user)
|
||||
|
||||
# Extract values from query
|
||||
s_val = get_term_value(query.s)
|
||||
|
|
@ -291,6 +294,93 @@ class Processor(TriplesQueryService):
|
|||
logger.error(f"Exception querying triples: {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
async def query_triples_stream(self, query):
|
||||
"""
|
||||
Streaming query - yields (batch, is_final) tuples.
|
||||
Uses Cassandra's paging to fetch results incrementally.
|
||||
"""
|
||||
try:
|
||||
self.ensure_connection(query.user)
|
||||
|
||||
batch_size = query.batch_size if query.batch_size > 0 else 20
|
||||
limit = query.limit if query.limit > 0 else 10000
|
||||
|
||||
# Extract query pattern
|
||||
s_val = get_term_value(query.s)
|
||||
p_val = get_term_value(query.p)
|
||||
o_val = get_term_value(query.o)
|
||||
g_val = query.g
|
||||
|
||||
# Helper to extract object metadata from result row
|
||||
def get_o_metadata(t):
|
||||
otype = getattr(t, 'otype', None)
|
||||
dtype = getattr(t, 'dtype', None)
|
||||
lang = getattr(t, 'lang', None)
|
||||
return otype, dtype, lang
|
||||
|
||||
# For streaming, we need to execute with fetch_size
|
||||
# Use the collection table for get_all queries (most common streaming case)
|
||||
|
||||
# Determine which query to use based on pattern
|
||||
if s_val is None and p_val is None and o_val is None:
|
||||
# Get all - use collection table with paging
|
||||
cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s"
|
||||
params = [query.collection]
|
||||
else:
|
||||
# For specific patterns, fall back to non-streaming
|
||||
# (these typically return small result sets anyway)
|
||||
async for batch, is_final in self._fallback_stream(query, batch_size):
|
||||
yield batch, is_final
|
||||
return
|
||||
|
||||
# Create statement with fetch_size for true streaming
|
||||
statement = SimpleStatement(cql, fetch_size=batch_size)
|
||||
result_set = self.tg.session.execute(statement, params)
|
||||
|
||||
batch = []
|
||||
count = 0
|
||||
|
||||
for row in result_set:
|
||||
if count >= limit:
|
||||
break
|
||||
|
||||
g = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(row)
|
||||
|
||||
triple = Triple(
|
||||
s=create_term(row.s),
|
||||
p=create_term(row.p),
|
||||
o=create_term(row.o, otype=otype, dtype=dtype, lang=lang),
|
||||
g=g if g != DEFAULT_GRAPH else None
|
||||
)
|
||||
batch.append(triple)
|
||||
count += 1
|
||||
|
||||
# Yield batch when full (never mark as final mid-stream)
|
||||
if len(batch) >= batch_size:
|
||||
yield batch, False
|
||||
batch = []
|
||||
|
||||
# Always yield final batch to signal completion
|
||||
# This handles: remaining rows, empty result set, or exact batch boundary
|
||||
yield batch, True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception in streaming query: {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
async def _fallback_stream(self, query, batch_size):
|
||||
"""Fallback to non-streaming query with post-hoc batching."""
|
||||
triples = await self.query_triples(query)
|
||||
|
||||
for i in range(0, len(triples), batch_size):
|
||||
batch = triples[i:i + batch_size]
|
||||
is_final = (i + batch_size >= len(triples))
|
||||
yield batch, is_final
|
||||
|
||||
if len(triples) == 0:
|
||||
yield [], True
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
|
|
|
|||
|
|
@ -4,11 +4,25 @@ import logging
|
|||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
from ... schema import IRI, LITERAL
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
|
||||
def term_to_string(term):
|
||||
"""Extract string value from a Term object."""
|
||||
if term is None:
|
||||
return None
|
||||
if term.type == IRI:
|
||||
return term.iri
|
||||
elif term.type == LITERAL:
|
||||
return term.value
|
||||
# Fallback
|
||||
return term.iri or term.value or str(term)
|
||||
|
||||
class LRUCacheWithTTL:
|
||||
"""LRU cache with TTL for label caching
|
||||
|
||||
|
|
@ -93,7 +107,7 @@ class Query:
|
|||
)
|
||||
|
||||
entities = [
|
||||
str(e.entity)
|
||||
term_to_string(e.entity)
|
||||
for e in entity_matches
|
||||
]
|
||||
|
||||
|
|
@ -129,26 +143,29 @@ class Query:
|
|||
return label
|
||||
|
||||
async def execute_batch_triple_queries(self, entities, limit_per_entity):
|
||||
"""Execute triple queries for multiple entities concurrently"""
|
||||
"""Execute triple queries for multiple entities concurrently using streaming"""
|
||||
tasks = []
|
||||
|
||||
for entity in entities:
|
||||
# Create concurrent tasks for all 3 query types per entity
|
||||
# Create concurrent streaming tasks for all 3 query types per entity
|
||||
tasks.extend([
|
||||
self.rag.triples_client.query(
|
||||
self.rag.triples_client.query_stream(
|
||||
s=entity, p=None, o=None,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection
|
||||
user=self.user, collection=self.collection,
|
||||
batch_size=20,
|
||||
),
|
||||
self.rag.triples_client.query(
|
||||
self.rag.triples_client.query_stream(
|
||||
s=None, p=entity, o=None,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection
|
||||
user=self.user, collection=self.collection,
|
||||
batch_size=20,
|
||||
),
|
||||
self.rag.triples_client.query(
|
||||
self.rag.triples_client.query_stream(
|
||||
s=None, p=None, o=entity,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection
|
||||
user=self.user, collection=self.collection,
|
||||
batch_size=20,
|
||||
)
|
||||
])
|
||||
|
||||
|
|
@ -158,7 +175,7 @@ class Query:
|
|||
# Combine all results
|
||||
all_triples = []
|
||||
for result in results:
|
||||
if not isinstance(result, Exception):
|
||||
if not isinstance(result, Exception) and result is not None:
|
||||
all_triples.extend(result)
|
||||
|
||||
return all_triples
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue