mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Feature/streaming triples (#676)
* Streaming triples * Also GraphRAG service uses this * Updated tests
This commit is contained in:
parent
3c3e11bef5
commit
d2d71f859d
11 changed files with 542 additions and 116 deletions
|
|
@ -91,9 +91,18 @@ class SocketClient:
|
|||
service: str,
|
||||
flow: Optional[str],
|
||||
request: Dict[str, Any],
|
||||
streaming: bool = False
|
||||
) -> Union[Dict[str, Any], Iterator[StreamingChunk]]:
|
||||
"""Synchronous wrapper around async WebSocket communication"""
|
||||
streaming: bool = False,
|
||||
streaming_raw: bool = False
|
||||
) -> Union[Dict[str, Any], Iterator[StreamingChunk], Iterator[Dict[str, Any]]]:
|
||||
"""Synchronous wrapper around async WebSocket communication.
|
||||
|
||||
Args:
|
||||
service: Service name
|
||||
flow: Flow ID (optional)
|
||||
request: Request payload
|
||||
streaming: Use parsed streaming (for agent/RAG chunk types)
|
||||
streaming_raw: Use raw streaming (for data batches like triples)
|
||||
"""
|
||||
# Create event loop if needed
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
|
@ -105,12 +114,14 @@ class SocketClient:
|
|||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
if streaming:
|
||||
# For streaming, we need to return an iterator
|
||||
# Create a generator that runs async code
|
||||
if streaming_raw:
|
||||
# Raw streaming for data batches (triples, rows, etc.)
|
||||
return self._streaming_generator_raw(service, flow, request, loop)
|
||||
elif streaming:
|
||||
# Parsed streaming for agent/RAG chunk types
|
||||
return self._streaming_generator(service, flow, request, loop)
|
||||
else:
|
||||
# For non-streaming, just run the async code and return result
|
||||
# Non-streaming single response
|
||||
return loop.run_until_complete(self._send_request_async(service, flow, request))
|
||||
|
||||
def _streaming_generator(
|
||||
|
|
@ -120,7 +131,7 @@ class SocketClient:
|
|||
request: Dict[str, Any],
|
||||
loop: asyncio.AbstractEventLoop
|
||||
) -> Iterator[StreamingChunk]:
|
||||
"""Generator that yields streaming chunks"""
|
||||
"""Generator that yields streaming chunks (for agent/RAG responses)"""
|
||||
async_gen = self._send_request_async_streaming(service, flow, request)
|
||||
|
||||
try:
|
||||
|
|
@ -137,6 +148,74 @@ class SocketClient:
|
|||
except:
|
||||
pass
|
||||
|
||||
def _streaming_generator_raw(
|
||||
self,
|
||||
service: str,
|
||||
flow: Optional[str],
|
||||
request: Dict[str, Any],
|
||||
loop: asyncio.AbstractEventLoop
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
"""Generator that yields raw response dicts (for data streaming like triples)"""
|
||||
async_gen = self._send_request_async_streaming_raw(service, flow, request)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
data = loop.run_until_complete(async_gen.__anext__())
|
||||
yield data
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
try:
|
||||
loop.run_until_complete(async_gen.aclose())
|
||||
except:
|
||||
pass
|
||||
|
||||
async def _send_request_async_streaming_raw(
    self,
    service: str,
    flow: Optional[str],
    request: Dict[str, Any]
) -> Iterator[Dict[str, Any]]:
    """Stream the raw ``"response"`` payloads of one websocket request.

    No chunk parsing is applied: every matching message's ``"response"``
    value is yielded as received. Intended for data streaming (triples,
    rows, etc.) where responses are plain batches rather than agent/RAG
    chunk types.

    Args:
        service: Service name placed in the outgoing message.
        flow: Flow ID; added to the message only when truthy.
        request: Request payload placed in the outgoing message.

    Yields:
        The ``"response"`` value of each incoming message whose ``"id"``
        matches this request, until a message with a truthy ``"complete"``
        field is seen or the socket stops producing messages.

    NOTE(review): this is an async generator, so ``AsyncIterator`` would be
    the more accurate return annotation; left unchanged here.
    """
    # Allocate a request id under the lock so concurrent callers never
    # share a counter value.
    with self._lock:
        self._request_counter += 1
        req_id = f"req-{self._request_counter}"

    endpoint = f"{self.url}/api/v1/socket"
    if self.token:
        endpoint = f"{endpoint}?token={self.token}"

    envelope = {
        "id": req_id,
        "service": service,
        "request": request
    }
    if flow:
        envelope["flow"] = flow

    connection = websockets.connect(
        endpoint, ping_interval=20, ping_timeout=self.timeout
    )
    async with connection as ws:
        await ws.send(json.dumps(envelope))

        async for frame in ws:
            reply = json.loads(frame)

            # Only messages addressed to this request are processed
            if reply.get("id") == req_id:

                if "error" in reply:
                    # Presumably raises; if it returns, processing continues
                    raise_from_error_dict(reply["error"])

                if "response" in reply:
                    yield reply["response"]

                if reply.get("complete"):
                    break
|
||||
|
||||
async def _send_request_async(
|
||||
self,
|
||||
service: str,
|
||||
|
|
@ -790,6 +869,74 @@ class SocketFlowInstance:
|
|||
|
||||
return self.client._send_request_sync("triples", self.flow_id, request, False)
|
||||
|
||||
def triples_query_stream(
    self,
    s: Optional[str] = None,
    p: Optional[str] = None,
    o: Optional[str] = None,
    user: Optional[str] = None,
    collection: Optional[str] = None,
    limit: int = 100,
    batch_size: int = 20,
    **kwargs: Any
) -> Iterator[List[Dict[str, Any]]]:
    """
    Stream knowledge graph triples in batches.

    Batches are yielded as they arrive instead of being collected into a
    single result, which lowers time-to-first-result and memory use for
    large result sets.

    Args:
        s: Subject URI (optional, None means wildcard)
        p: Predicate URI (optional, None means wildcard)
        o: Object URI or Literal (optional, None means wildcard)
        user: User/keyspace identifier (optional)
        collection: Collection identifier (optional)
        limit: Maximum results to return (default: 100)
        batch_size: Triples per batch (default: 20)
        **kwargs: Extra request fields; applied last, so they override
            the fields built from the named arguments

    Yields:
        List[Dict]: Batches of triples in wire format

    Example:
        ```python
        socket = api.socket()
        flow = socket.flow("default")

        for batch in flow.triples_query_stream(
            user="trustgraph",
            collection="default"
        ):
            for triple in batch:
                print(triple["s"], triple["p"], triple["o"])
        ```
    """
    payload: Dict[str, Any] = {
        "limit": limit,
        "streaming": True,
        "batch-size": batch_size,
    }

    # Pattern terms are sent as strings; unset terms are left out entirely
    for key, term in (("s", s), ("p", p), ("o", o)):
        if term is not None:
            payload[key] = str(term)

    # Scoping fields are passed through as given
    for key, value in (("user", user), ("collection", collection)):
        if value is not None:
            payload[key] = value

    payload.update(kwargs)

    # Raw streaming: response dicts come back without chunk parsing
    stream = self.client._send_request_sync(
        "triples", self.flow_id, payload, streaming_raw=True
    )
    for item in stream:
        # Unwrap {"response": [...triples...]}; anything else passes through
        wrapped = isinstance(item, dict) and "response" in item
        yield item["response"] if wrapped else item
|
||||
|
||||
def rows_query(
|
||||
self,
|
||||
query: str,
|
||||
|
|
|
|||
|
|
@ -22,11 +22,19 @@ def to_value(x):
|
|||
|
||||
|
||||
def from_value(x):
    """Convert a Uri, Literal, string, or other value to a schema Term.

    Returns None for None. Uri instances become IRI terms and Literal
    instances become LITERAL terms. A plain string becomes an IRI term
    when it starts with "http://", "https://" or "urn:", otherwise a
    LITERAL term. Any other value is stringified into a LITERAL term.
    """
    if x is None:
        return None

    if isinstance(x, Uri):
        return Term(type=IRI, iri=str(x))

    if isinstance(x, Literal):
        return Term(type=LITERAL, value=str(x))

    if isinstance(x, str):
        # Detect IRIs by common prefixes
        looks_like_iri = x.startswith(("http://", "https://", "urn:"))
        if looks_like_iri:
            return Term(type=IRI, iri=x)
        return Term(type=LITERAL, value=x)

    return Term(type=LITERAL, value=str(x))
|
||||
|
||||
|
|
@ -57,6 +65,61 @@ class TriplesClient(RequestResponse):
|
|||
|
||||
return triples
|
||||
|
||||
async def query_stream(self, s=None, p=None, o=None, limit=20,
                       user="trustgraph", collection="default",
                       batch_size=20, timeout=30,
                       batch_callback=None):
    """
    Run a streaming triple query, handling each batch as it arrives.

    Args:
        s, p, o: Triple pattern (None for wildcard)
        limit: Maximum total triples to return
        user: User/keyspace
        collection: Collection name
        batch_size: Triples per batch
        timeout: Request timeout in seconds
        batch_callback: Async callback(batch, is_final) awaited once per
            batch; when given, batches are not accumulated

    Returns:
        List[Triple]: All triples (flattened) if no callback was provided,
        otherwise None

    Raises:
        RuntimeError: If a response carries an error
    """
    collected = []

    async def on_response(resp):
        if resp.error:
            raise RuntimeError(resp.error.message)

        converted = [
            Triple(to_value(t.s), to_value(t.p), to_value(t.o))
            for t in resp.triples
        ]

        if batch_callback:
            await batch_callback(converted, resp.is_final)
        else:
            collected.extend(converted)

        # Returned to self.request; presumably signals that the stream
        # is finished - confirm against RequestResponse.request
        return resp.is_final

    query = TriplesQueryRequest(
        s=from_value(s),
        p=from_value(p),
        o=from_value(o),
        limit=limit,
        user=user,
        collection=collection,
        streaming=True,
        batch_size=batch_size,
    )

    await self.request(query, timeout=timeout, recipient=on_response)

    return None if batch_callback else collected
|
||||
|
||||
class TriplesClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
self, request_name, response_name,
|
||||
|
|
|
|||
|
|
@ -52,13 +52,23 @@ class TriplesQueryService(FlowProcessor):
|
|||
|
||||
logger.debug(f"Handling triples query request {id}...")
|
||||
|
||||
triples = await self.query_triples(request)
|
||||
|
||||
logger.debug("Sending triples query response...")
|
||||
r = TriplesQueryResponse(triples=triples, error=None)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
||||
logger.debug("Triples query request completed")
|
||||
if request.streaming:
|
||||
# Streaming mode: send batches
|
||||
async for batch, is_final in self.query_triples_stream(request):
|
||||
r = TriplesQueryResponse(
|
||||
triples=batch,
|
||||
error=None,
|
||||
is_final=is_final,
|
||||
)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
logger.debug("Triples query streaming completed")
|
||||
else:
|
||||
# Non-streaming mode: single response
|
||||
triples = await self.query_triples(request)
|
||||
logger.debug("Sending triples query response...")
|
||||
r = TriplesQueryResponse(triples=triples, error=None)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
logger.debug("Triples query request completed")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
@ -76,6 +86,24 @@ class TriplesQueryService(FlowProcessor):
|
|||
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
||||
async def query_triples_stream(self, request):
    """
    Yield (batch, is_final) tuples for a streaming query.

    This default implementation fetches the full result with
    query_triples and slices it into batches of request.batch_size
    (20 is used when that value is not positive). An empty result
    produces a single ([], True) tuple. Override for true streaming
    from the backend.
    """
    results = await self.query_triples(request)

    step = request.batch_size
    if not step > 0:
        step = 20

    start = 0
    while start < len(results):
        stop = start + step
        # The batch reaching the end of the result is flagged as final
        yield results[start:stop], stop >= len(results)
        start = stop

    # Handle empty result: consumers still get one final (empty) batch
    if len(results) == 0:
        yield [], True
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
|
|
|
|||
|
|
@ -23,14 +23,18 @@ class TriplesQueryRequestTranslator(MessageTranslator):
|
|||
g=g,
|
||||
limit=int(data.get("limit", 10000)),
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default")
|
||||
collection=data.get("collection", "default"),
|
||||
streaming=data.get("streaming", False),
|
||||
batch_size=int(data.get("batch-size", 20)),
|
||||
)
|
||||
|
||||
def from_pulsar(self, obj: TriplesQueryRequest) -> Dict[str, Any]:
|
||||
result = {
|
||||
"limit": obj.limit,
|
||||
"user": obj.user,
|
||||
"collection": obj.collection
|
||||
"collection": obj.collection,
|
||||
"streaming": obj.streaming,
|
||||
"batch-size": obj.batch_size,
|
||||
}
|
||||
|
||||
if obj.s:
|
||||
|
|
@ -61,4 +65,4 @@ class TriplesQueryResponseTranslator(MessageTranslator):
|
|||
|
||||
def from_response_with_completion(self, obj: TriplesQueryResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), True
|
||||
return self.from_pulsar(obj), obj.is_final
|
||||
|
|
@ -38,11 +38,14 @@ class TriplesQueryRequest:
|
|||
o: Term | None = None
|
||||
g: str | None = None # Graph IRI. None=default graph, "*"=all graphs
|
||||
limit: int = 0
|
||||
streaming: bool = False # Enable streaming mode (multiple batched responses)
|
||||
batch_size: int = 20 # Triples per batch in streaming mode
|
||||
|
||||
@dataclass
class TriplesQueryResponse:
    """Response to a triples query: an optional error plus a list of triples.

    In streaming mode several responses may be sent for one request;
    ``is_final`` distinguishes intermediate batches from the last one.
    """

    # Error details, or None (the default)
    error: Error | None = None

    # Triples carried by this response; defaults to a fresh empty list
    triples: list[Triple] = field(default_factory=list)

    # False for intermediate batches in streaming mode
    is_final: bool = True
|
||||
|
||||
############################################################################
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue