mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 00:46:22 +02:00
Feature/streaming triples (#676)
* Streaming triples * Also GraphRAG service uses this * Updated tests
This commit is contained in:
parent
3c3e11bef5
commit
d2d71f859d
11 changed files with 542 additions and 116 deletions
|
|
@ -22,11 +22,19 @@ def to_value(x):
|
|||
|
||||
|
||||
def from_value(x):
    """Convert a Uri, Literal, or plain string into a schema Term.

    Args:
        x: A Uri, Literal, str, None, or any other value.

    Returns:
        None when x is None.  Uri values become IRI terms; Literal values
        become LITERAL terms; strings that look like IRIs (http/https/urn
        prefixes) become IRI terms, other strings become LITERAL terms;
        any other value becomes a LITERAL term holding str(x).
    """
    if x is None:
        return None
    if isinstance(x, Uri):
        return Term(type=IRI, iri=str(x))
    elif isinstance(x, Literal):
        return Term(type=LITERAL, value=str(x))
    elif isinstance(x, str):
        # Detect IRIs by common prefixes; startswith accepts a tuple of
        # prefixes, replacing the chained boolean tests with one call.
        if x.startswith(("http://", "https://", "urn:")):
            return Term(type=IRI, iri=x)
        else:
            return Term(type=LITERAL, value=x)
    else:
        return Term(type=LITERAL, value=str(x))
|
||||
|
||||
|
|
@ -57,6 +65,61 @@ class TriplesClient(RequestResponse):
|
|||
|
||||
return triples
|
||||
|
||||
async def query_stream(self, s=None, p=None, o=None, limit=20,
                       user="trustgraph", collection="default",
                       batch_size=20, timeout=30,
                       batch_callback=None):
    """
    Issue a streaming triples query.

    Each batch of results is either handed to batch_callback as it
    arrives (when supplied) or accumulated and returned once the
    stream ends.

    Args:
        s, p, o: Triple pattern (None for wildcard)
        limit: Maximum total triples to return
        user: User/keyspace
        collection: Collection name
        batch_size: Triples per batch
        timeout: Request timeout in seconds
        batch_callback: Async callback(batch, is_final) invoked per batch

    Returns:
        List[Triple]: every triple received, flattened, when no
        callback is given; otherwise None.
    """
    collected = []

    async def on_response(resp):
        # Surface server-side failures immediately.
        if resp.error:
            raise RuntimeError(resp.error.message)

        converted = [
            Triple(to_value(t.s), to_value(t.p), to_value(t.o))
            for t in resp.triples
        ]

        if batch_callback:
            await batch_callback(converted, resp.is_final)
        else:
            collected.extend(converted)

        # Returning is_final tells the transport when the stream is done.
        return resp.is_final

    req = TriplesQueryRequest(
        s=from_value(s),
        p=from_value(p),
        o=from_value(o),
        limit=limit,
        user=user,
        collection=collection,
        streaming=True,
        batch_size=batch_size,
    )

    await self.request(req, timeout=timeout, recipient=on_response)

    if not batch_callback:
        return collected
|
||||
|
||||
class TriplesClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
self, request_name, response_name,
|
||||
|
|
|
|||
|
|
@ -52,13 +52,23 @@ class TriplesQueryService(FlowProcessor):
|
|||
|
||||
logger.debug(f"Handling triples query request {id}...")
|
||||
|
||||
triples = await self.query_triples(request)
|
||||
|
||||
logger.debug("Sending triples query response...")
|
||||
r = TriplesQueryResponse(triples=triples, error=None)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
||||
logger.debug("Triples query request completed")
|
||||
if request.streaming:
|
||||
# Streaming mode: send batches
|
||||
async for batch, is_final in self.query_triples_stream(request):
|
||||
r = TriplesQueryResponse(
|
||||
triples=batch,
|
||||
error=None,
|
||||
is_final=is_final,
|
||||
)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
logger.debug("Triples query streaming completed")
|
||||
else:
|
||||
# Non-streaming mode: single response
|
||||
triples = await self.query_triples(request)
|
||||
logger.debug("Sending triples query response...")
|
||||
r = TriplesQueryResponse(triples=triples, error=None)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
logger.debug("Triples query request completed")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
@ -76,6 +86,24 @@ class TriplesQueryService(FlowProcessor):
|
|||
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
||||
async def query_triples_stream(self, request):
    """
    Streaming query - yields (batch, is_final) tuples.

    Default implementation fetches the full result set via
    query_triples and slices it into batches.  Override for true
    streaming from the backend store.

    Args:
        request: Triples query request; its batch_size field controls
            the batch length.

    Yields:
        (list, bool): a batch of triples and whether it is the final
        batch.  An empty result still yields one ([], True) pair so
        the caller always receives a terminating response.
    """
    triples = await self.query_triples(request)

    # Guard against a missing batch_size: wire-deserialized fields may
    # arrive as None, and `None > 0` would raise TypeError.  Fall back
    # to the default of 20 for None or non-positive values.
    batch_size = (
        request.batch_size
        if request.batch_size and request.batch_size > 0
        else 20
    )

    # Empty result: emit a single, final, empty batch and stop.
    if not triples:
        yield [], True
        return

    total = len(triples)
    for start in range(0, total, batch_size):
        batch = triples[start:start + batch_size]
        yield batch, start + batch_size >= total
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue