mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-28 08:45:13 +02:00
Merge pull request #953 from trustgraph-ai/release/v2.5
release/v2.5 -> master
This commit is contained in:
commit
36eadbda3a
22 changed files with 2632 additions and 1140 deletions
|
|
@ -25,7 +25,7 @@ BUCKET_URL = "https://storage.googleapis.com/trustgraph-library"
|
||||||
INDEX_URL = f"{BUCKET_URL}/index.json"
|
INDEX_URL = f"{BUCKET_URL}/index.json"
|
||||||
|
|
||||||
default_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/")
|
default_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/")
|
||||||
default_user = "trustgraph"
|
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
|
||||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -113,7 +113,7 @@ def convert_metadata(metadata_json):
|
||||||
return triples
|
return triples
|
||||||
|
|
||||||
|
|
||||||
def load_document(api, user, doc_entry):
|
def load_document(api, doc_entry):
|
||||||
"""Fetch metadata and content for a document, then load into TrustGraph."""
|
"""Fetch metadata and content for a document, then load into TrustGraph."""
|
||||||
doc_id = doc_entry["id"]
|
doc_id = doc_entry["id"]
|
||||||
title = doc_entry["title"]
|
title = doc_entry["title"]
|
||||||
|
|
@ -133,7 +133,6 @@ def load_document(api, user, doc_entry):
|
||||||
api.add_document(
|
api.add_document(
|
||||||
id=doc["id"],
|
id=doc["id"],
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
user=user,
|
|
||||||
kind=doc["kind"],
|
kind=doc["kind"],
|
||||||
title=doc["title"],
|
title=doc["title"],
|
||||||
comments=doc["comments"],
|
comments=doc["comments"],
|
||||||
|
|
@ -144,12 +143,12 @@ def load_document(api, user, doc_entry):
|
||||||
print(f" done.")
|
print(f" done.")
|
||||||
|
|
||||||
|
|
||||||
def load_documents(api, user, docs):
|
def load_documents(api, docs):
|
||||||
"""Load a list of documents."""
|
"""Load a list of documents."""
|
||||||
print(f"Loading {len(docs)} document(s)...\n")
|
print(f"Loading {len(docs)} document(s)...\n")
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
try:
|
try:
|
||||||
load_document(api, user, doc)
|
load_document(api, doc)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" FAILED: {e}", file=sys.stderr)
|
print(f" FAILED: {e}", file=sys.stderr)
|
||||||
print()
|
print()
|
||||||
|
|
@ -166,8 +165,8 @@ def main():
|
||||||
help=f"TrustGraph API URL (default: {default_url})",
|
help=f"TrustGraph API URL (default: {default_url})",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-U", "--user", default=default_user,
|
"-w", "--workspace", default=default_workspace,
|
||||||
help=f"User ID (default: {default_user})",
|
help=f"Workspace (default: {default_workspace})",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-t", "--token", default=default_token,
|
"-t", "--token", default=default_token,
|
||||||
|
|
@ -212,22 +211,22 @@ def main():
|
||||||
return
|
return
|
||||||
|
|
||||||
# Load commands need the API
|
# Load commands need the API
|
||||||
api = Api(args.url, token=args.token).library()
|
api = Api(args.url, token=args.token, workspace=args.workspace).library()
|
||||||
|
|
||||||
if args.command == "load-all":
|
if args.command == "load-all":
|
||||||
load_documents(api, args.user, index)
|
load_documents(api, index)
|
||||||
|
|
||||||
elif args.command == "load-doc":
|
elif args.command == "load-doc":
|
||||||
matches = [d for d in index if str(d.get("id")) == args.id]
|
matches = [d for d in index if str(d.get("id")) == args.id]
|
||||||
if not matches:
|
if not matches:
|
||||||
print(f"No document with ID '{args.id}' found.", file=sys.stderr)
|
print(f"No document with ID '{args.id}' found.", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
load_documents(api, args.user, matches)
|
load_documents(api, matches)
|
||||||
|
|
||||||
elif args.command == "load-match":
|
elif args.command == "load-match":
|
||||||
results = search_index(index, args.query)
|
results = search_index(index, args.query)
|
||||||
if results:
|
if results:
|
||||||
load_documents(api, args.user, results)
|
load_documents(api, results)
|
||||||
else:
|
else:
|
||||||
print("No matches found.", file=sys.stderr)
|
print("No matches found.", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
|
||||||
|
|
@ -3,170 +3,237 @@
|
||||||
WebSocket Relay Test Harness
|
WebSocket Relay Test Harness
|
||||||
|
|
||||||
This script creates a relay server with two WebSocket endpoints:
|
This script creates a relay server with two WebSocket endpoints:
|
||||||
- /in - for test clients to connect to
|
- /in - for test clients to connect to (speaks api-gateway protocol)
|
||||||
- /out - for reverse gateway to connect to
|
- /out - for reverse gateway to connect to (speaks rev-gateway protocol)
|
||||||
|
|
||||||
Messages are bidirectionally relayed between the two connections.
|
Clients on /in authenticate with a first-frame auth message:
|
||||||
|
{"type": "auth", "token": "..."}
|
||||||
|
|
||||||
|
The relay stores the token and injects it into each subsequent message
|
||||||
|
before forwarding to /out. Responses from /out are forwarded back to
|
||||||
|
the originating /in connection unchanged.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python websocket_relay.py [--port PORT] [--host HOST]
|
python websocket_relay.py [--port PORT] [--host HOST]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import argparse
|
import argparse
|
||||||
from aiohttp import web, WSMsgType
|
from aiohttp import web, WSMsgType
|
||||||
import weakref
|
from typing import Dict, Optional
|
||||||
from typing import Optional, Set
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
)
|
)
|
||||||
logger = logging.getLogger("websocket_relay")
|
logger = logging.getLogger("websocket_relay")
|
||||||
|
|
||||||
|
|
||||||
|
class InConnection:
|
||||||
|
def __init__(self, ws, conn_id):
|
||||||
|
self.ws = ws
|
||||||
|
self.conn_id = conn_id
|
||||||
|
self.token: Optional[str] = None
|
||||||
|
self.authenticated = False
|
||||||
|
|
||||||
|
|
||||||
class WebSocketRelay:
|
class WebSocketRelay:
|
||||||
"""WebSocket relay that forwards messages between 'in' and 'out' connections"""
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.in_connections: Set = weakref.WeakSet()
|
self.in_connections: Dict[str, InConnection] = {}
|
||||||
self.out_connections: Set = weakref.WeakSet()
|
self.out_connections: set = set()
|
||||||
|
self._conn_counter = 0
|
||||||
|
|
||||||
|
def _next_conn_id(self):
|
||||||
|
self._conn_counter += 1
|
||||||
|
return f"conn-{self._conn_counter}"
|
||||||
|
|
||||||
async def handle_in_connection(self, request):
|
async def handle_in_connection(self, request):
|
||||||
"""Handle incoming connections on /in endpoint"""
|
|
||||||
ws = web.WebSocketResponse()
|
ws = web.WebSocketResponse()
|
||||||
await ws.prepare(request)
|
await ws.prepare(request)
|
||||||
|
|
||||||
self.in_connections.add(ws)
|
conn_id = self._next_conn_id()
|
||||||
logger.info(f"New 'in' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
conn = InConnection(ws, conn_id)
|
||||||
|
self.in_connections[conn_id] = conn
|
||||||
|
logger.info(
|
||||||
|
f"New 'in' connection {conn_id}. "
|
||||||
|
f"Total in: {len(self.in_connections)}, "
|
||||||
|
f"out: {len(self.out_connections)}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for msg in ws:
|
async for msg in ws:
|
||||||
if msg.type == WSMsgType.TEXT:
|
if msg.type == WSMsgType.TEXT:
|
||||||
data = msg.data
|
await self._handle_in_message(conn, msg.data)
|
||||||
logger.info(f"IN → OUT: {data}")
|
|
||||||
await self._forward_to_out(data)
|
|
||||||
elif msg.type == WSMsgType.BINARY:
|
|
||||||
data = msg.data
|
|
||||||
logger.info(f"IN → OUT: {len(data)} bytes (binary)")
|
|
||||||
await self._forward_to_out(data, binary=True)
|
|
||||||
elif msg.type == WSMsgType.ERROR:
|
elif msg.type == WSMsgType.ERROR:
|
||||||
logger.error(f"WebSocket error on 'in' connection: {ws.exception()}")
|
logger.error(
|
||||||
|
f"WebSocket error on 'in' connection "
|
||||||
|
f"{conn_id}: {ws.exception()}"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in 'in' connection handler: {e}")
|
logger.error(
|
||||||
|
f"Error in 'in' connection {conn_id}: {e}"
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
logger.info(f"'in' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
del self.in_connections[conn_id]
|
||||||
|
logger.info(
|
||||||
|
f"'in' connection {conn_id} closed. "
|
||||||
|
f"Remaining in: {len(self.in_connections)}, "
|
||||||
|
f"out: {len(self.out_connections)}"
|
||||||
|
)
|
||||||
|
|
||||||
return ws
|
return ws
|
||||||
|
|
||||||
|
async def _handle_in_message(self, conn, data):
|
||||||
|
try:
|
||||||
|
message = json.loads(data)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(
|
||||||
|
f"{conn.conn_id}: received non-JSON message"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(message, dict) and message.get("type") == "auth":
|
||||||
|
conn.token = message.get("token", "")
|
||||||
|
conn.authenticated = True
|
||||||
|
logger.info(f"{conn.conn_id}: authenticated")
|
||||||
|
await conn.ws.send_json({
|
||||||
|
"type": "auth-ok",
|
||||||
|
"workspace": "relayed",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
if not conn.authenticated:
|
||||||
|
await conn.ws.send_json({
|
||||||
|
"error": {
|
||||||
|
"message": "auth required",
|
||||||
|
"type": "auth-required",
|
||||||
|
},
|
||||||
|
"complete": True,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
message["token"] = conn.token
|
||||||
|
message["_relay_conn"] = conn.conn_id
|
||||||
|
|
||||||
|
forwarded = json.dumps(message)
|
||||||
|
logger.info(f"IN {conn.conn_id} → OUT: {forwarded}")
|
||||||
|
await self._forward_to_out(forwarded)
|
||||||
|
|
||||||
async def handle_out_connection(self, request):
|
async def handle_out_connection(self, request):
|
||||||
"""Handle outgoing connections on /out endpoint"""
|
|
||||||
ws = web.WebSocketResponse()
|
ws = web.WebSocketResponse()
|
||||||
await ws.prepare(request)
|
await ws.prepare(request)
|
||||||
|
|
||||||
self.out_connections.add(ws)
|
self.out_connections.add(ws)
|
||||||
logger.info(f"New 'out' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
logger.info(
|
||||||
|
f"New 'out' connection. "
|
||||||
|
f"Total in: {len(self.in_connections)}, "
|
||||||
|
f"out: {len(self.out_connections)}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for msg in ws:
|
async for msg in ws:
|
||||||
if msg.type == WSMsgType.TEXT:
|
if msg.type == WSMsgType.TEXT:
|
||||||
data = msg.data
|
await self._handle_out_message(msg.data)
|
||||||
logger.info(f"OUT → IN: {data}")
|
|
||||||
await self._forward_to_in(data)
|
|
||||||
elif msg.type == WSMsgType.BINARY:
|
|
||||||
data = msg.data
|
|
||||||
logger.info(f"OUT → IN: {len(data)} bytes (binary)")
|
|
||||||
await self._forward_to_in(data, binary=True)
|
|
||||||
elif msg.type == WSMsgType.ERROR:
|
elif msg.type == WSMsgType.ERROR:
|
||||||
logger.error(f"WebSocket error on 'out' connection: {ws.exception()}")
|
logger.error(
|
||||||
|
f"WebSocket error on 'out' connection: "
|
||||||
|
f"{ws.exception()}"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in 'out' connection handler: {e}")
|
logger.error(f"Error in 'out' connection: {e}")
|
||||||
finally:
|
finally:
|
||||||
logger.info(f"'out' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
self.out_connections.discard(ws)
|
||||||
|
logger.info(
|
||||||
|
f"'out' connection closed. "
|
||||||
|
f"Remaining in: {len(self.in_connections)}, "
|
||||||
|
f"out: {len(self.out_connections)}"
|
||||||
|
)
|
||||||
|
|
||||||
return ws
|
return ws
|
||||||
|
|
||||||
async def _forward_to_out(self, data, binary=False):
|
async def _handle_out_message(self, data):
|
||||||
"""Forward message from 'in' to all 'out' connections"""
|
try:
|
||||||
if not self.out_connections:
|
message = json.loads(data)
|
||||||
logger.warning("No 'out' connections available to forward message")
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("OUT: received non-JSON message")
|
||||||
return
|
return
|
||||||
|
|
||||||
closed_connections = []
|
conn_id = message.pop("_relay_conn", None)
|
||||||
|
|
||||||
|
forwarded = json.dumps(message)
|
||||||
|
logger.info(f"OUT → IN {conn_id or 'broadcast'}: {forwarded}")
|
||||||
|
|
||||||
|
if conn_id and conn_id in self.in_connections:
|
||||||
|
conn = self.in_connections[conn_id]
|
||||||
|
try:
|
||||||
|
if not conn.ws.closed:
|
||||||
|
await conn.ws.send_str(forwarded)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error forwarding to 'in' {conn_id}: {e}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self._broadcast_to_in(forwarded)
|
||||||
|
|
||||||
|
async def _broadcast_to_in(self, data):
|
||||||
|
closed = []
|
||||||
|
for conn_id, conn in list(self.in_connections.items()):
|
||||||
|
try:
|
||||||
|
if conn.ws.closed:
|
||||||
|
closed.append(conn_id)
|
||||||
|
continue
|
||||||
|
await conn.ws.send_str(data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error broadcasting to 'in' {conn_id}: {e}"
|
||||||
|
)
|
||||||
|
closed.append(conn_id)
|
||||||
|
for conn_id in closed:
|
||||||
|
self.in_connections.pop(conn_id, None)
|
||||||
|
|
||||||
|
async def _forward_to_out(self, data):
|
||||||
|
closed = []
|
||||||
for ws in list(self.out_connections):
|
for ws in list(self.out_connections):
|
||||||
try:
|
try:
|
||||||
if ws.closed:
|
if ws.closed:
|
||||||
closed_connections.append(ws)
|
closed.append(ws)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if binary:
|
|
||||||
await ws.send_bytes(data)
|
|
||||||
else:
|
|
||||||
await ws.send_str(data)
|
await ws.send_str(data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error forwarding to 'out' connection: {e}")
|
logger.error(f"Error forwarding to 'out': {e}")
|
||||||
closed_connections.append(ws)
|
closed.append(ws)
|
||||||
|
for ws in closed:
|
||||||
# Clean up closed connections
|
|
||||||
for ws in closed_connections:
|
|
||||||
if ws in self.out_connections:
|
|
||||||
self.out_connections.discard(ws)
|
self.out_connections.discard(ws)
|
||||||
|
|
||||||
async def _forward_to_in(self, data, binary=False):
|
|
||||||
"""Forward message from 'out' to all 'in' connections"""
|
|
||||||
if not self.in_connections:
|
|
||||||
logger.warning("No 'in' connections available to forward message")
|
|
||||||
return
|
|
||||||
|
|
||||||
closed_connections = []
|
|
||||||
for ws in list(self.in_connections):
|
|
||||||
try:
|
|
||||||
if ws.closed:
|
|
||||||
closed_connections.append(ws)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if binary:
|
|
||||||
await ws.send_bytes(data)
|
|
||||||
else:
|
|
||||||
await ws.send_str(data)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error forwarding to 'in' connection: {e}")
|
|
||||||
closed_connections.append(ws)
|
|
||||||
|
|
||||||
# Clean up closed connections
|
|
||||||
for ws in closed_connections:
|
|
||||||
if ws in self.in_connections:
|
|
||||||
self.in_connections.discard(ws)
|
|
||||||
|
|
||||||
async def create_app(relay):
|
async def create_app(relay):
|
||||||
"""Create the web application with routes"""
|
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
|
|
||||||
# Add routes
|
app.router.add_get('/in/api/v1/socket', relay.handle_in_connection)
|
||||||
app.router.add_get('/in', relay.handle_in_connection)
|
|
||||||
app.router.add_get('/out', relay.handle_out_connection)
|
app.router.add_get('/out', relay.handle_out_connection)
|
||||||
|
|
||||||
# Add a simple status endpoint
|
|
||||||
async def status(request):
|
async def status(request):
|
||||||
status_info = {
|
return web.json_response({
|
||||||
'in_connections': len(relay.in_connections),
|
'in_connections': len(relay.in_connections),
|
||||||
'out_connections': len(relay.out_connections),
|
'out_connections': len(relay.out_connections),
|
||||||
'status': 'running'
|
'status': 'running',
|
||||||
}
|
})
|
||||||
return web.json_response(status_info)
|
|
||||||
|
|
||||||
app.router.add_get('/status', status)
|
app.router.add_get('/status', status)
|
||||||
app.router.add_get('/', status) # Root also shows status
|
app.router.add_get('/', status)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="WebSocket Relay Test Harness"
|
description="WebSocket Relay Test Harness"
|
||||||
|
|
@ -174,18 +241,18 @@ def main():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--host',
|
'--host',
|
||||||
default='localhost',
|
default='localhost',
|
||||||
help='Host to bind to (default: localhost)'
|
help='Host to bind to (default: localhost)',
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--port',
|
'--port',
|
||||||
type=int,
|
type=int,
|
||||||
default=8080,
|
default=8080,
|
||||||
help='Port to bind to (default: 8080)'
|
help='Port to bind to (default: 8080)',
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--verbose', '-v',
|
'--verbose', '-v',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='Enable verbose logging'
|
help='Enable verbose logging',
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
@ -196,15 +263,18 @@ def main():
|
||||||
relay = WebSocketRelay()
|
relay = WebSocketRelay()
|
||||||
|
|
||||||
print(f"Starting WebSocket Relay on {args.host}:{args.port}")
|
print(f"Starting WebSocket Relay on {args.host}:{args.port}")
|
||||||
print(f" 'in' endpoint: ws://{args.host}:{args.port}/in")
|
print(f" 'in' endpoint: ws://{args.host}:{args.port}/in/api/v1/socket")
|
||||||
print(f" 'out' endpoint: ws://{args.host}:{args.port}/out")
|
print(f" 'out' endpoint: ws://{args.host}:{args.port}/out")
|
||||||
print(f" Status: http://{args.host}:{args.port}/status")
|
print(f" Status: http://{args.host}:{args.port}/status")
|
||||||
print()
|
print()
|
||||||
print("Usage:")
|
print("Client protocol (same as api-gateway):")
|
||||||
print(f" Test client connects to: ws://{args.host}:{args.port}/in")
|
print(' 1. Connect to /in/api/v1/socket')
|
||||||
print(f" Reverse gateway connects to: ws://{args.host}:{args.port}/out")
|
print(' 2. Send: {"type": "auth", "token": "tg_..."}')
|
||||||
|
print(' 3. Receive: {"type": "auth-ok", "workspace": "relayed"}')
|
||||||
|
print(' 4. Send requests as normal')
|
||||||
|
|
||||||
web.run_app(create_app(relay), host=args.host, port=args.port)
|
web.run_app(create_app(relay), host=args.host, port=args.port)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
@ -25,16 +25,17 @@ class TestSemaphoreEnforcement:
|
||||||
max_concurrent = 0
|
max_concurrent = 0
|
||||||
processing_event = asyncio.Event()
|
processing_event = asyncio.Event()
|
||||||
|
|
||||||
async def slow_process(message):
|
async def slow_process(message, sender):
|
||||||
nonlocal concurrent_count, max_concurrent
|
nonlocal concurrent_count, max_concurrent
|
||||||
concurrent_count += 1
|
concurrent_count += 1
|
||||||
max_concurrent = max(max_concurrent, concurrent_count)
|
max_concurrent = max(max_concurrent, concurrent_count)
|
||||||
await asyncio.sleep(0.05)
|
await asyncio.sleep(0.05)
|
||||||
concurrent_count -= 1
|
concurrent_count -= 1
|
||||||
return {"id": message.get("id"), "response": {"ok": True}}
|
|
||||||
|
|
||||||
dispatcher._process_message = slow_process
|
dispatcher._process_message = slow_process
|
||||||
|
|
||||||
|
sender = AsyncMock()
|
||||||
|
|
||||||
# Launch more tasks than max_workers
|
# Launch more tasks than max_workers
|
||||||
messages = [
|
messages = [
|
||||||
{"id": f"msg-{i}", "service": "test", "request": {}}
|
{"id": f"msg-{i}", "service": "test", "request": {}}
|
||||||
|
|
@ -42,7 +43,7 @@ class TestSemaphoreEnforcement:
|
||||||
]
|
]
|
||||||
|
|
||||||
tasks = [
|
tasks = [
|
||||||
asyncio.create_task(dispatcher.handle_message(m))
|
asyncio.create_task(dispatcher.handle_message(m, sender))
|
||||||
for m in messages
|
for m in messages
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -66,17 +67,17 @@ class TestSemaphoreEnforcement:
|
||||||
|
|
||||||
original_process = dispatcher._process_message
|
original_process = dispatcher._process_message
|
||||||
|
|
||||||
async def tracking_process(message):
|
async def tracking_process(message, sender):
|
||||||
nonlocal task_was_tracked
|
nonlocal task_was_tracked
|
||||||
# During processing, our task should be in active_tasks
|
# During processing, our task should be in active_tasks
|
||||||
if len(dispatcher.active_tasks) > 0:
|
if len(dispatcher.active_tasks) > 0:
|
||||||
task_was_tracked = True
|
task_was_tracked = True
|
||||||
return {"id": message.get("id"), "response": {"ok": True}}
|
|
||||||
|
|
||||||
dispatcher._process_message = tracking_process
|
dispatcher._process_message = tracking_process
|
||||||
|
|
||||||
await dispatcher.handle_message(
|
await dispatcher.handle_message(
|
||||||
{"id": "test", "service": "test", "request": {}}
|
{"id": "test", "service": "test", "request": {}},
|
||||||
|
AsyncMock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert task_was_tracked
|
assert task_was_tracked
|
||||||
|
|
@ -88,7 +89,7 @@ class TestSemaphoreEnforcement:
|
||||||
"""Semaphore should be released even if processing raises."""
|
"""Semaphore should be released even if processing raises."""
|
||||||
dispatcher = MessageDispatcher(max_workers=2)
|
dispatcher = MessageDispatcher(max_workers=2)
|
||||||
|
|
||||||
async def failing_process(message):
|
async def failing_process(message, sender):
|
||||||
raise RuntimeError("process failed")
|
raise RuntimeError("process failed")
|
||||||
|
|
||||||
dispatcher._process_message = failing_process
|
dispatcher._process_message = failing_process
|
||||||
|
|
@ -96,7 +97,8 @@ class TestSemaphoreEnforcement:
|
||||||
# Should not deadlock — semaphore must be released on error
|
# Should not deadlock — semaphore must be released on error
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
await dispatcher.handle_message(
|
await dispatcher.handle_message(
|
||||||
{"id": "test", "service": "test", "request": {}}
|
{"id": "test", "service": "test", "request": {}},
|
||||||
|
AsyncMock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Semaphore should be back at max
|
# Semaphore should be back at max
|
||||||
|
|
@ -109,17 +111,18 @@ class TestSemaphoreEnforcement:
|
||||||
|
|
||||||
order = []
|
order = []
|
||||||
|
|
||||||
async def ordered_process(message):
|
async def ordered_process(message, sender):
|
||||||
msg_id = message["id"]
|
msg_id = message["id"]
|
||||||
order.append(f"start-{msg_id}")
|
order.append(f"start-{msg_id}")
|
||||||
await asyncio.sleep(0.02)
|
await asyncio.sleep(0.02)
|
||||||
order.append(f"end-{msg_id}")
|
order.append(f"end-{msg_id}")
|
||||||
return {"id": msg_id, "response": {"ok": True}}
|
|
||||||
|
|
||||||
dispatcher._process_message = ordered_process
|
dispatcher._process_message = ordered_process
|
||||||
|
|
||||||
|
sender = AsyncMock()
|
||||||
|
|
||||||
messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)]
|
messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)]
|
||||||
tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages]
|
tasks = [asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages]
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
# With semaphore=1, each message should complete before next starts
|
# With semaphore=1, each message should complete before next starts
|
||||||
|
|
|
||||||
56
tests/unit/test_query/test_ontology_monitoring.py
Normal file
56
tests/unit/test_query/test_ontology_monitoring.py
Normal file
|
|
@ -0,0 +1,56 @@
|
||||||
|
"""
|
||||||
|
Tests for ontology monitoring metrics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from trustgraph.query.ontology.monitoring import (
|
||||||
|
PerformanceMonitor,
|
||||||
|
_extract_metric_label,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_metric_label_reads_unquoted_label_value():
|
||||||
|
metric_name = "cache_requests_total{cache_type=entity,component=ontology}"
|
||||||
|
|
||||||
|
assert _extract_metric_label(metric_name, "cache_type") == "entity"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_metric_label_reads_quoted_label_value():
|
||||||
|
metric_name = 'cache_requests_total{cache_type="entity",component="ontology"}'
|
||||||
|
|
||||||
|
assert _extract_metric_label(metric_name, "cache_type") == "entity"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_metric_label_returns_none_when_label_missing():
|
||||||
|
metric_name = "cache_requests_total{component=ontology}"
|
||||||
|
|
||||||
|
assert _extract_metric_label(metric_name, "cache_type") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_performance_report_ignores_counters_without_cache_type_label():
|
||||||
|
monitor = PerformanceMonitor({"enabled": False})
|
||||||
|
monitor.metrics_collector.increment(
|
||||||
|
"cache_requests_total",
|
||||||
|
labels={"component": "ontology"},
|
||||||
|
)
|
||||||
|
monitor.metrics_collector.increment(
|
||||||
|
"cache_type=not_a_label",
|
||||||
|
labels={"component": "ontology"},
|
||||||
|
)
|
||||||
|
monitor.metrics_collector.increment(
|
||||||
|
"cache_requests_total",
|
||||||
|
labels={"cache_type": "entity"},
|
||||||
|
)
|
||||||
|
monitor.metrics_collector.increment(
|
||||||
|
"cache_hits_total",
|
||||||
|
labels={"cache_type": "entity"},
|
||||||
|
)
|
||||||
|
|
||||||
|
report = monitor.get_performance_report()
|
||||||
|
|
||||||
|
assert report["cache_performance"] == {
|
||||||
|
"entity": {
|
||||||
|
"hit_rate": 1.0,
|
||||||
|
"total_requests": 1.0,
|
||||||
|
"total_hits": 1.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -14,7 +14,7 @@ from rdflib.plugins.sparql.parserutils import CompValue
|
||||||
|
|
||||||
from trustgraph.schema import Term, IRI, LITERAL
|
from trustgraph.schema import Term, IRI, LITERAL
|
||||||
from trustgraph.query.sparql.algebra import (
|
from trustgraph.query.sparql.algebra import (
|
||||||
evaluate, _query_pattern, _eval_bgp,
|
evaluate, materialise, _query_pattern, _eval_bgp,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -28,6 +28,32 @@ def lit(v):
|
||||||
return Term(type=LITERAL, value=v)
|
return Term(type=LITERAL, value=v)
|
||||||
|
|
||||||
|
|
||||||
|
def make_tc(query_return=None, query_side_effect=None):
|
||||||
|
"""Create a mock TriplesClient with both query() and query_gen() support."""
|
||||||
|
tc = AsyncMock()
|
||||||
|
|
||||||
|
if query_side_effect is not None:
|
||||||
|
tc.query.side_effect = query_side_effect
|
||||||
|
|
||||||
|
async def gen_side_effect(**kwargs):
|
||||||
|
results = await query_side_effect(**kwargs)
|
||||||
|
for r in results:
|
||||||
|
yield r
|
||||||
|
|
||||||
|
tc.query_gen = gen_side_effect
|
||||||
|
else:
|
||||||
|
items = query_return or []
|
||||||
|
tc.query.return_value = items
|
||||||
|
|
||||||
|
async def gen(**kwargs):
|
||||||
|
for item in items:
|
||||||
|
yield item
|
||||||
|
|
||||||
|
tc.query_gen = gen
|
||||||
|
|
||||||
|
return tc
|
||||||
|
|
||||||
|
|
||||||
def make_triple(s, p, o):
|
def make_triple(s, p, o):
|
||||||
t = MagicMock()
|
t = MagicMock()
|
||||||
t.s = s
|
t.s = s
|
||||||
|
|
@ -84,6 +110,20 @@ def make_distinct(inner):
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
def make_filter(inner, expr):
|
||||||
|
node = CompValue("Filter")
|
||||||
|
node.p = inner
|
||||||
|
node.expr = expr
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
def make_minus(left, right):
|
||||||
|
node = CompValue("Minus")
|
||||||
|
node.p1 = left
|
||||||
|
node.p2 = right
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
class TestQueryPattern:
|
class TestQueryPattern:
|
||||||
"""Tests for _query_pattern — the leaf that calls TriplesClient."""
|
"""Tests for _query_pattern — the leaf that calls TriplesClient."""
|
||||||
|
|
||||||
|
|
@ -136,15 +176,14 @@ class TestEvalBgp:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_single_pattern_all_variables(self):
|
async def test_single_pattern_all_variables(self):
|
||||||
tc = AsyncMock()
|
|
||||||
triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
|
triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
|
||||||
tc.query.return_value = [triple]
|
tc = make_tc(query_return=[triple])
|
||||||
|
|
||||||
bgp = make_bgp(
|
bgp = make_bgp(
|
||||||
(Variable("s"), Variable("p"), Variable("o")),
|
(Variable("s"), Variable("p"), Variable("o")),
|
||||||
)
|
)
|
||||||
|
|
||||||
solutions = await evaluate(bgp, tc, collection="default", limit=100)
|
solutions = await materialise(bgp, tc, collection="default", limit=100)
|
||||||
|
|
||||||
assert len(solutions) == 1
|
assert len(solutions) == 1
|
||||||
assert solutions[0]["s"].iri == "http://s"
|
assert solutions[0]["s"].iri == "http://s"
|
||||||
|
|
@ -153,43 +192,37 @@ class TestEvalBgp:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_single_pattern_bound_subject(self):
|
async def test_single_pattern_bound_subject(self):
|
||||||
tc = AsyncMock()
|
tc = make_tc(query_return=[
|
||||||
tc.query.return_value = [
|
|
||||||
make_triple(iri("http://s"), iri("http://p"), lit("val")),
|
make_triple(iri("http://s"), iri("http://p"), lit("val")),
|
||||||
]
|
])
|
||||||
|
|
||||||
bgp = make_bgp(
|
bgp = make_bgp(
|
||||||
(URIRef("http://s"), Variable("p"), Variable("o")),
|
(URIRef("http://s"), Variable("p"), Variable("o")),
|
||||||
)
|
)
|
||||||
|
|
||||||
solutions = await evaluate(bgp, tc, collection="default")
|
solutions = await materialise(bgp, tc, collection="default")
|
||||||
|
|
||||||
tc.query.assert_called_once()
|
assert len(solutions) == 1
|
||||||
kwargs = tc.query.call_args.kwargs
|
|
||||||
assert "workspace" not in kwargs
|
|
||||||
assert kwargs["collection"] == "default"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_empty_bgp_returns_empty_solution(self):
|
async def test_empty_bgp_returns_empty_solution(self):
|
||||||
tc = AsyncMock()
|
tc = make_tc()
|
||||||
|
|
||||||
bgp = make_bgp()
|
bgp = make_bgp()
|
||||||
|
|
||||||
solutions = await evaluate(bgp, tc, collection="default")
|
solutions = await materialise(bgp, tc, collection="default")
|
||||||
|
|
||||||
assert solutions == [{}]
|
assert solutions == [{}]
|
||||||
tc.query.assert_not_called()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_no_results_returns_empty(self):
|
async def test_no_results_returns_empty(self):
|
||||||
tc = AsyncMock()
|
tc = make_tc(query_return=[])
|
||||||
tc.query.return_value = []
|
|
||||||
|
|
||||||
bgp = make_bgp(
|
bgp = make_bgp(
|
||||||
(Variable("s"), Variable("p"), Variable("o")),
|
(Variable("s"), Variable("p"), Variable("o")),
|
||||||
)
|
)
|
||||||
|
|
||||||
solutions = await evaluate(bgp, tc, collection="default")
|
solutions = await materialise(bgp, tc, collection="default")
|
||||||
|
|
||||||
assert solutions == []
|
assert solutions == []
|
||||||
|
|
||||||
|
|
@ -199,17 +232,16 @@ class TestEvaluate:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_select_query_node(self):
|
async def test_select_query_node(self):
|
||||||
tc = AsyncMock()
|
tc = make_tc(query_return=[
|
||||||
tc.query.return_value = [
|
|
||||||
make_triple(iri("http://s"), iri("http://p"), lit("o")),
|
make_triple(iri("http://s"), iri("http://p"), lit("o")),
|
||||||
]
|
])
|
||||||
|
|
||||||
bgp = make_bgp(
|
bgp = make_bgp(
|
||||||
(Variable("s"), Variable("p"), Variable("o")),
|
(Variable("s"), Variable("p"), Variable("o")),
|
||||||
)
|
)
|
||||||
select = make_select(make_project(bgp, ["s", "p"]))
|
select = make_select(make_project(bgp, ["s", "p"]))
|
||||||
|
|
||||||
solutions = await evaluate(select, tc, collection="default")
|
solutions = await materialise(select, tc, collection="default")
|
||||||
|
|
||||||
assert len(solutions) == 1
|
assert len(solutions) == 1
|
||||||
assert "s" in solutions[0]
|
assert "s" in solutions[0]
|
||||||
|
|
@ -220,10 +252,9 @@ class TestEvaluate:
|
||||||
async def test_workspace_never_in_query_calls(self):
|
async def test_workspace_never_in_query_calls(self):
|
||||||
"""Verify that no matter the algebra structure, workspace is never
|
"""Verify that no matter the algebra structure, workspace is never
|
||||||
passed to TriplesClient.query()."""
|
passed to TriplesClient.query()."""
|
||||||
tc = AsyncMock()
|
tc = make_tc(query_return=[
|
||||||
tc.query.return_value = [
|
|
||||||
make_triple(iri("http://s"), iri("http://p"), lit("o")),
|
make_triple(iri("http://s"), iri("http://p"), lit("o")),
|
||||||
]
|
])
|
||||||
|
|
||||||
bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o")))
|
bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o")))
|
||||||
bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c")))
|
bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c")))
|
||||||
|
|
@ -231,72 +262,319 @@ class TestEvaluate:
|
||||||
make_union(bgp1, bgp2), ["s", "p", "o"]
|
make_union(bgp1, bgp2), ["s", "p", "o"]
|
||||||
))
|
))
|
||||||
|
|
||||||
await evaluate(tree, tc, collection="test-coll")
|
await materialise(tree, tc, collection="test-coll")
|
||||||
|
|
||||||
for c in tc.query.call_args_list:
|
|
||||||
assert "workspace" not in c.kwargs
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_join(self):
|
async def test_join(self):
|
||||||
tc = AsyncMock()
|
call_count = 0
|
||||||
tc.query.side_effect = [
|
|
||||||
[make_triple(iri("http://a"), iri("http://p"), lit("v"))],
|
async def mock_query(**kwargs):
|
||||||
[make_triple(iri("http://a"), iri("http://q"), lit("w"))],
|
nonlocal call_count
|
||||||
]
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return [make_triple(iri("http://a"), iri("http://p"), lit("v"))]
|
||||||
|
else:
|
||||||
|
return [make_triple(iri("http://a"), iri("http://q"), lit("w"))]
|
||||||
|
|
||||||
|
tc = make_tc(query_side_effect=mock_query)
|
||||||
|
|
||||||
bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1")))
|
bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1")))
|
||||||
bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2")))
|
bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2")))
|
||||||
tree = make_join(bgp1, bgp2)
|
tree = make_join(bgp1, bgp2)
|
||||||
|
|
||||||
solutions = await evaluate(tree, tc, collection="default")
|
solutions = await materialise(tree, tc, collection="default")
|
||||||
|
|
||||||
assert len(solutions) == 1
|
assert len(solutions) == 1
|
||||||
assert solutions[0]["s"].iri == "http://a"
|
assert solutions[0]["s"].iri == "http://a"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_slice(self):
|
async def test_slice(self):
|
||||||
tc = AsyncMock()
|
|
||||||
triples = [
|
triples = [
|
||||||
make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}"))
|
make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}"))
|
||||||
for i in range(5)
|
for i in range(5)
|
||||||
]
|
]
|
||||||
tc.query.return_value = triples
|
tc = make_tc(query_return=triples)
|
||||||
|
|
||||||
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
|
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
|
||||||
tree = make_slice(bgp, start=1, length=2)
|
tree = make_slice(bgp, start=1, length=2)
|
||||||
|
|
||||||
solutions = await evaluate(tree, tc, collection="default")
|
solutions = await materialise(tree, tc, collection="default")
|
||||||
|
|
||||||
assert len(solutions) == 2
|
assert len(solutions) == 2
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_distinct(self):
|
async def test_distinct(self):
|
||||||
tc = AsyncMock()
|
|
||||||
triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
|
triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
|
||||||
tc.query.return_value = [triple, triple]
|
tc = make_tc(query_return=[triple, triple])
|
||||||
|
|
||||||
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
|
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
|
||||||
tree = make_distinct(bgp)
|
tree = make_distinct(bgp)
|
||||||
|
|
||||||
solutions = await evaluate(tree, tc, collection="default")
|
solutions = await materialise(tree, tc, collection="default")
|
||||||
|
|
||||||
assert len(solutions) == 1
|
assert len(solutions) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_minus_removes_matching(self):
|
||||||
|
alice = iri("http://example.com/alice")
|
||||||
|
bob = iri("http://example.com/bob")
|
||||||
|
knows = iri("http://example.com/knows")
|
||||||
|
hates = iri("http://example.com/hates")
|
||||||
|
charlie = iri("http://example.com/charlie")
|
||||||
|
|
||||||
|
left_triple = make_triple(alice, knows, bob)
|
||||||
|
right_triple2 = make_triple(alice, hates, charlie)
|
||||||
|
|
||||||
|
async def mock_query(**kwargs):
|
||||||
|
pred = kwargs.get("p")
|
||||||
|
if pred and pred.iri == "http://example.com/knows":
|
||||||
|
return [left_triple]
|
||||||
|
elif pred and pred.iri == "http://example.com/hates":
|
||||||
|
return [right_triple2]
|
||||||
|
return []
|
||||||
|
|
||||||
|
tc = make_tc(query_side_effect=mock_query)
|
||||||
|
|
||||||
|
left_bgp = make_bgp(
|
||||||
|
(Variable("s"), URIRef("http://example.com/knows"), Variable("o"))
|
||||||
|
)
|
||||||
|
right_bgp = make_bgp(
|
||||||
|
(Variable("s"), URIRef("http://example.com/hates"), Variable("r"))
|
||||||
|
)
|
||||||
|
|
||||||
|
tree = make_select(
|
||||||
|
make_project(
|
||||||
|
make_minus(left_bgp, right_bgp),
|
||||||
|
["s", "o"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
solutions = await materialise(tree, tc, collection="default")
|
||||||
|
|
||||||
|
assert len(solutions) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_minus_no_shared_vars_preserves_all(self):
|
||||||
|
alice = iri("http://example.com/alice")
|
||||||
|
bob = iri("http://example.com/bob")
|
||||||
|
|
||||||
|
left_triple = make_triple(alice, iri("http://example.com/p"), bob)
|
||||||
|
|
||||||
|
async def mock_query(**kwargs):
|
||||||
|
pred = kwargs.get("p")
|
||||||
|
if pred and pred.iri == "http://example.com/p":
|
||||||
|
return [left_triple]
|
||||||
|
return []
|
||||||
|
|
||||||
|
tc = make_tc(query_side_effect=mock_query)
|
||||||
|
|
||||||
|
left_bgp = make_bgp(
|
||||||
|
(Variable("s"), URIRef("http://example.com/p"), Variable("o"))
|
||||||
|
)
|
||||||
|
right_bgp = make_bgp(
|
||||||
|
(Variable("x"), URIRef("http://example.com/q"), Variable("y"))
|
||||||
|
)
|
||||||
|
|
||||||
|
tree = make_select(
|
||||||
|
make_project(
|
||||||
|
make_minus(left_bgp, right_bgp),
|
||||||
|
["s", "o"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
solutions = await materialise(tree, tc, collection="default")
|
||||||
|
|
||||||
|
assert len(solutions) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_filter_exists_keeps_matching(self):
|
||||||
|
alice = iri("http://example.com/alice")
|
||||||
|
bob = iri("http://example.com/bob")
|
||||||
|
charlie = iri("http://example.com/charlie")
|
||||||
|
|
||||||
|
left_triple1 = make_triple(alice, iri("http://example.com/knows"), bob)
|
||||||
|
left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie)
|
||||||
|
exists_triple = make_triple(bob, iri("http://example.com/likes"), alice)
|
||||||
|
|
||||||
|
async def mock_query(**kwargs):
|
||||||
|
pred = kwargs.get("p")
|
||||||
|
if pred and pred.iri == "http://example.com/knows":
|
||||||
|
return [left_triple1, left_triple2]
|
||||||
|
elif pred and pred.iri == "http://example.com/likes":
|
||||||
|
return [exists_triple]
|
||||||
|
return []
|
||||||
|
|
||||||
|
tc = make_tc(query_side_effect=mock_query)
|
||||||
|
|
||||||
|
left_bgp = make_bgp(
|
||||||
|
(Variable("s"), URIRef("http://example.com/knows"), Variable("o"))
|
||||||
|
)
|
||||||
|
exists_bgp = make_bgp(
|
||||||
|
(Variable("o"), URIRef("http://example.com/likes"), Variable("_any"))
|
||||||
|
)
|
||||||
|
|
||||||
|
exists_expr = CompValue("Builtin_EXISTS")
|
||||||
|
exists_expr.graph = exists_bgp
|
||||||
|
|
||||||
|
tree = make_select(
|
||||||
|
make_project(
|
||||||
|
make_filter(left_bgp, exists_expr),
|
||||||
|
["s", "o"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
solutions = await materialise(tree, tc, collection="default")
|
||||||
|
|
||||||
|
result_objects = [s["o"].iri for s in solutions]
|
||||||
|
assert "http://example.com/bob" in result_objects
|
||||||
|
assert "http://example.com/charlie" not in result_objects
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_filter_not_exists_removes_matching(self):
|
||||||
|
alice = iri("http://example.com/alice")
|
||||||
|
bob = iri("http://example.com/bob")
|
||||||
|
charlie = iri("http://example.com/charlie")
|
||||||
|
|
||||||
|
left_triple1 = make_triple(alice, iri("http://example.com/knows"), bob)
|
||||||
|
left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie)
|
||||||
|
exists_triple = make_triple(bob, iri("http://example.com/likes"), alice)
|
||||||
|
|
||||||
|
async def mock_query(**kwargs):
|
||||||
|
pred = kwargs.get("p")
|
||||||
|
if pred and pred.iri == "http://example.com/knows":
|
||||||
|
return [left_triple1, left_triple2]
|
||||||
|
elif pred and pred.iri == "http://example.com/likes":
|
||||||
|
return [exists_triple]
|
||||||
|
return []
|
||||||
|
|
||||||
|
tc = make_tc(query_side_effect=mock_query)
|
||||||
|
|
||||||
|
left_bgp = make_bgp(
|
||||||
|
(Variable("s"), URIRef("http://example.com/knows"), Variable("o"))
|
||||||
|
)
|
||||||
|
exists_bgp = make_bgp(
|
||||||
|
(Variable("o"), URIRef("http://example.com/likes"), Variable("_any"))
|
||||||
|
)
|
||||||
|
|
||||||
|
not_exists_expr = CompValue("Builtin_NOTEXISTS")
|
||||||
|
not_exists_expr.graph = exists_bgp
|
||||||
|
|
||||||
|
tree = make_select(
|
||||||
|
make_project(
|
||||||
|
make_filter(left_bgp, not_exists_expr),
|
||||||
|
["s", "o"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
solutions = await materialise(tree, tc, collection="default")
|
||||||
|
|
||||||
|
result_objects = [s["o"].iri for s in solutions]
|
||||||
|
assert "http://example.com/charlie" in result_objects
|
||||||
|
assert "http://example.com/bob" not in result_objects
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_join_values_uses_bind_join(self):
|
||||||
|
"""When VALUES is joined with a BGP, the bind join should pass
|
||||||
|
the VALUES bindings into the BGP evaluation so the triple store
|
||||||
|
query is selective (not a wildcard)."""
|
||||||
|
alice = iri("http://example.com/alice")
|
||||||
|
bob = iri("http://example.com/bob")
|
||||||
|
knows = iri("http://example.com/knows")
|
||||||
|
|
||||||
|
queries_issued = []
|
||||||
|
|
||||||
|
async def mock_query(**kwargs):
|
||||||
|
queries_issued.append(kwargs)
|
||||||
|
s, p = kwargs.get("s"), kwargs.get("p")
|
||||||
|
if s and s.iri == "http://example.com/alice" and p and p.iri == "http://example.com/knows":
|
||||||
|
return [make_triple(alice, knows, bob)]
|
||||||
|
return []
|
||||||
|
|
||||||
|
tc = make_tc(query_side_effect=mock_query)
|
||||||
|
|
||||||
|
# VALUES ?s { <alice> }
|
||||||
|
values_node = CompValue("values")
|
||||||
|
values_node.var = [Variable("s")]
|
||||||
|
values_node.value = [[URIRef("http://example.com/alice")]]
|
||||||
|
values_node.res = None
|
||||||
|
|
||||||
|
to_multiset = CompValue("ToMultiSet")
|
||||||
|
to_multiset.p = values_node
|
||||||
|
|
||||||
|
bgp = make_bgp(
|
||||||
|
(Variable("s"), URIRef("http://example.com/knows"), Variable("o")),
|
||||||
|
)
|
||||||
|
|
||||||
|
tree = make_join(to_multiset, bgp)
|
||||||
|
solutions = await materialise(tree, tc, collection="default")
|
||||||
|
|
||||||
|
assert len(solutions) == 1
|
||||||
|
assert solutions[0]["s"].iri == "http://example.com/alice"
|
||||||
|
assert solutions[0]["o"].iri == "http://example.com/bob"
|
||||||
|
|
||||||
|
# The key assertion: the BGP query should have received
|
||||||
|
# s=alice (bound from VALUES), NOT s=None (wildcard)
|
||||||
|
assert len(queries_issued) == 1
|
||||||
|
assert queries_issued[0]["s"] is not None
|
||||||
|
assert queries_issued[0]["s"].iri == "http://example.com/alice"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_join_values_multiple_bindings(self):
|
||||||
|
"""Bind join with multiple VALUES bindings."""
|
||||||
|
alice = iri("http://example.com/alice")
|
||||||
|
bob = iri("http://example.com/bob")
|
||||||
|
knows = iri("http://example.com/knows")
|
||||||
|
charlie = iri("http://example.com/charlie")
|
||||||
|
|
||||||
|
async def mock_query(**kwargs):
|
||||||
|
s = kwargs.get("s")
|
||||||
|
if s and s.iri == "http://example.com/alice":
|
||||||
|
return [make_triple(alice, knows, bob)]
|
||||||
|
elif s and s.iri == "http://example.com/bob":
|
||||||
|
return [make_triple(bob, knows, charlie)]
|
||||||
|
return []
|
||||||
|
|
||||||
|
tc = make_tc(query_side_effect=mock_query)
|
||||||
|
|
||||||
|
values_node = CompValue("values")
|
||||||
|
values_node.var = [Variable("s")]
|
||||||
|
values_node.value = [
|
||||||
|
[URIRef("http://example.com/alice")],
|
||||||
|
[URIRef("http://example.com/bob")],
|
||||||
|
]
|
||||||
|
values_node.res = None
|
||||||
|
|
||||||
|
to_multiset = CompValue("ToMultiSet")
|
||||||
|
to_multiset.p = values_node
|
||||||
|
|
||||||
|
bgp = make_bgp(
|
||||||
|
(Variable("s"), URIRef("http://example.com/knows"), Variable("o")),
|
||||||
|
)
|
||||||
|
|
||||||
|
tree = make_join(to_multiset, bgp)
|
||||||
|
solutions = await materialise(tree, tc, collection="default")
|
||||||
|
|
||||||
|
assert len(solutions) == 2
|
||||||
|
subjects = {s["s"].iri for s in solutions}
|
||||||
|
assert subjects == {
|
||||||
|
"http://example.com/alice",
|
||||||
|
"http://example.com/bob",
|
||||||
|
}
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unsupported_node_returns_empty_solution(self):
|
async def test_unsupported_node_returns_empty_solution(self):
|
||||||
tc = AsyncMock()
|
tc = make_tc()
|
||||||
|
|
||||||
node = CompValue("SomethingUnknown")
|
node = CompValue("SomethingUnknown")
|
||||||
|
|
||||||
solutions = await evaluate(node, tc, collection="default")
|
solutions = await materialise(node, tc, collection="default")
|
||||||
|
|
||||||
assert solutions == [{}]
|
assert solutions == [{}]
|
||||||
tc.query.assert_not_called()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_non_compvalue_returns_empty_solution(self):
|
async def test_non_compvalue_returns_empty_solution(self):
|
||||||
tc = AsyncMock()
|
tc = make_tc()
|
||||||
|
|
||||||
solutions = await evaluate("not a node", tc, collection="default")
|
solutions = await materialise("not a node", tc, collection="default")
|
||||||
|
|
||||||
assert solutions == [{}]
|
assert solutions == [{}]
|
||||||
|
|
|
||||||
|
|
@ -300,6 +300,438 @@ class TestBuiltinFunctions:
|
||||||
flags=None)
|
flags=None)
|
||||||
assert evaluate_expression(expr, {"x": lit("hello")}) is False
|
assert evaluate_expression(expr, {"x": lit("hello")}) is False
|
||||||
|
|
||||||
|
def test_substr_three_args(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("SUBSTR",
|
||||||
|
arg=Variable("x"),
|
||||||
|
start=Literal(1),
|
||||||
|
length=Literal(4))
|
||||||
|
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
|
||||||
|
assert result.type == LITERAL
|
||||||
|
assert result.value == "2024"
|
||||||
|
|
||||||
|
def test_substr_two_args(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("SUBSTR",
|
||||||
|
arg=Variable("x"),
|
||||||
|
start=Literal(6),
|
||||||
|
length=None)
|
||||||
|
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
|
||||||
|
assert result.type == LITERAL
|
||||||
|
assert result.value == "03-15"
|
||||||
|
|
||||||
|
def test_substr_middle(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("SUBSTR",
|
||||||
|
arg=Variable("x"),
|
||||||
|
start=Literal(6),
|
||||||
|
length=Literal(2))
|
||||||
|
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
|
||||||
|
assert result.type == LITERAL
|
||||||
|
assert result.value == "03"
|
||||||
|
|
||||||
|
def test_substr_null_start(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("SUBSTR",
|
||||||
|
arg=Variable("x"),
|
||||||
|
start=Variable("missing"),
|
||||||
|
length=None)
|
||||||
|
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_year(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("YEAR", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(
|
||||||
|
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
|
||||||
|
)
|
||||||
|
assert result == 2024
|
||||||
|
|
||||||
|
def test_month(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("MONTH", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(
|
||||||
|
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
|
||||||
|
)
|
||||||
|
assert result == 3
|
||||||
|
|
||||||
|
def test_day(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("DAY", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(
|
||||||
|
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
|
||||||
|
)
|
||||||
|
assert result == 15
|
||||||
|
|
||||||
|
def test_hours(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("HOURS", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(
|
||||||
|
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
|
||||||
|
)
|
||||||
|
assert result == 10
|
||||||
|
|
||||||
|
def test_minutes(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("MINUTES", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(
|
||||||
|
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
|
||||||
|
)
|
||||||
|
assert result == 30
|
||||||
|
|
||||||
|
def test_seconds(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("SECONDS", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(
|
||||||
|
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
|
||||||
|
)
|
||||||
|
assert result == 45
|
||||||
|
|
||||||
|
def test_year_from_datetime(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("YEAR", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(
|
||||||
|
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
|
||||||
|
)
|
||||||
|
assert result == 2024
|
||||||
|
|
||||||
|
def test_hours_from_date_returns_zero(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("HOURS", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(
|
||||||
|
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
|
||||||
|
)
|
||||||
|
assert result == 0
|
||||||
|
|
||||||
|
def test_year_invalid_date(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("YEAR", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(
|
||||||
|
expr, {"x": lit("not-a-date")}
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_floor(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("FLOOR", arg=Variable("x"))
|
||||||
|
assert evaluate_expression(expr, {"x": lit("3.7")}) == 3
|
||||||
|
|
||||||
|
def test_floor_negative(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("FLOOR", arg=Variable("x"))
|
||||||
|
assert evaluate_expression(expr, {"x": lit("-2.3")}) == -3
|
||||||
|
|
||||||
|
def test_floor_none(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("FLOOR", arg=Variable("x"))
|
||||||
|
assert evaluate_expression(expr, {"x": lit("abc")}) is None
|
||||||
|
|
||||||
|
def test_ceil(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("CEIL", arg=Variable("x"))
|
||||||
|
assert evaluate_expression(expr, {"x": lit("3.2")}) == 4
|
||||||
|
|
||||||
|
def test_ceil_negative(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("CEIL", arg=Variable("x"))
|
||||||
|
assert evaluate_expression(expr, {"x": lit("-2.7")}) == -2
|
||||||
|
|
||||||
|
def test_abs_positive(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("ABS", arg=Variable("x"))
|
||||||
|
assert evaluate_expression(expr, {"x": lit("42")}) == 42
|
||||||
|
|
||||||
|
def test_abs_negative(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("ABS", arg=Variable("x"))
|
||||||
|
assert evaluate_expression(expr, {"x": lit("-42")}) == 42
|
||||||
|
|
||||||
|
def test_abs_none(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("ABS", arg=Variable("x"))
|
||||||
|
assert evaluate_expression(expr, {"x": lit("abc")}) is None
|
||||||
|
|
||||||
|
def test_replace_simple(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("REPLACE",
|
||||||
|
arg=Variable("x"),
|
||||||
|
pattern=Literal(" BC"),
|
||||||
|
replacement=Literal(""),
|
||||||
|
flags=None)
|
||||||
|
result = evaluate_expression(expr, {"x": lit("500 BC")})
|
||||||
|
assert result.type == LITERAL
|
||||||
|
assert result.value == "500"
|
||||||
|
|
||||||
|
def test_replace_regex(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("REPLACE",
|
||||||
|
arg=Variable("x"),
|
||||||
|
pattern=Literal("[0-9]+"),
|
||||||
|
replacement=Literal("X"),
|
||||||
|
flags=None)
|
||||||
|
result = evaluate_expression(expr, {"x": lit("abc123def456")})
|
||||||
|
assert result.value == "abcXdefX"
|
||||||
|
|
||||||
|
def test_replace_case_insensitive(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("REPLACE",
|
||||||
|
arg=Variable("x"),
|
||||||
|
pattern=Literal("hello"),
|
||||||
|
replacement=Literal("world"),
|
||||||
|
flags=Literal("i"))
|
||||||
|
result = evaluate_expression(expr, {"x": lit("HELLO there")})
|
||||||
|
assert result.value == "world there"
|
||||||
|
|
||||||
|
def test_round_up(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("ROUND", arg=Variable("x"))
|
||||||
|
assert evaluate_expression(expr, {"x": lit("3.7")}) == 4
|
||||||
|
|
||||||
|
def test_round_down(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("ROUND", arg=Variable("x"))
|
||||||
|
assert evaluate_expression(expr, {"x": lit("3.2")}) == 3
|
||||||
|
|
||||||
|
def test_round_none(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("ROUND", arg=Variable("x"))
|
||||||
|
assert evaluate_expression(expr, {"x": lit("abc")}) is None
|
||||||
|
|
||||||
|
def test_strbefore(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("STRBEFORE",
|
||||||
|
arg1=Variable("x"), arg2=Literal("-"))
|
||||||
|
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
|
||||||
|
assert result.value == "2024"
|
||||||
|
|
||||||
|
def test_strbefore_not_found(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("STRBEFORE",
|
||||||
|
arg1=Variable("x"), arg2=Literal("/"))
|
||||||
|
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||||
|
assert result.value == ""
|
||||||
|
|
||||||
|
def test_strafter(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("STRAFTER",
|
||||||
|
arg1=Variable("x"), arg2=Literal("-"))
|
||||||
|
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
|
||||||
|
assert result.value == "03-15"
|
||||||
|
|
||||||
|
def test_strafter_not_found(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("STRAFTER",
|
||||||
|
arg1=Variable("x"), arg2=Literal("/"))
|
||||||
|
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||||
|
assert result.value == ""
|
||||||
|
|
||||||
|
def test_encode_for_uri(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(expr, {"x": lit("hello world")})
|
||||||
|
assert result.value == "hello%20world"
|
||||||
|
|
||||||
|
def test_encode_for_uri_special_chars(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(expr, {"x": lit("a/b?c=d&e")})
|
||||||
|
assert result.value == "a%2Fb%3Fc%3Dd%26e"
|
||||||
|
|
||||||
|
def test_langmatches_basic(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("LANGMATCHES",
|
||||||
|
arg1=Literal("en"), arg2=Literal("en"))
|
||||||
|
assert evaluate_expression(expr, {}) is True
|
||||||
|
|
||||||
|
def test_langmatches_subtag(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("LANGMATCHES",
|
||||||
|
arg1=Literal("en-US"), arg2=Literal("en"))
|
||||||
|
assert evaluate_expression(expr, {}) is True
|
||||||
|
|
||||||
|
def test_langmatches_wildcard(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("LANGMATCHES",
|
||||||
|
arg1=Literal("fr"), arg2=Literal("*"))
|
||||||
|
assert evaluate_expression(expr, {}) is True
|
||||||
|
|
||||||
|
def test_langmatches_wildcard_empty(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("LANGMATCHES",
|
||||||
|
arg1=Literal(""), arg2=Literal("*"))
|
||||||
|
assert evaluate_expression(expr, {}) is False
|
||||||
|
|
||||||
|
def test_langmatches_no_match(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("LANGMATCHES",
|
||||||
|
arg1=Literal("fr"), arg2=Literal("en"))
|
||||||
|
assert evaluate_expression(expr, {}) is False
|
||||||
|
|
||||||
|
def test_iri_constructor(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("IRI", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(
|
||||||
|
expr, {"x": lit("http://example.com/test")}
|
||||||
|
)
|
||||||
|
assert result.type == IRI
|
||||||
|
assert result.iri == "http://example.com/test"
|
||||||
|
|
||||||
|
def test_uri_constructor(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("URI", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(
|
||||||
|
expr, {"x": lit("http://example.com/test")}
|
||||||
|
)
|
||||||
|
assert result.type == IRI
|
||||||
|
assert result.iri == "http://example.com/test"
|
||||||
|
|
||||||
|
def test_bnode_no_arg(self):
|
||||||
|
expr = self._make_builtin("BNODE")
|
||||||
|
result = evaluate_expression(expr, {})
|
||||||
|
assert result.type == BLANK
|
||||||
|
assert len(result.id) > 0
|
||||||
|
|
||||||
|
def test_bnode_with_label(self):
|
||||||
|
from rdflib import Literal
|
||||||
|
expr = self._make_builtin("BNODE", arg=Literal("mynode"))
|
||||||
|
result = evaluate_expression(expr, {})
|
||||||
|
assert result.type == BLANK
|
||||||
|
assert result.id == "mynode"
|
||||||
|
|
||||||
|
def test_now(self):
|
||||||
|
import re as re_mod
|
||||||
|
expr = self._make_builtin("NOW")
|
||||||
|
result = evaluate_expression(expr, {})
|
||||||
|
assert result.type == LITERAL
|
||||||
|
assert result.datatype == XSD + "dateTime"
|
||||||
|
assert re_mod.match(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}", result.value)
|
||||||
|
|
||||||
|
def test_tz_with_utc(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("TZ", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(
|
||||||
|
expr, {"x": lit("2024-03-15T10:30:45+0000",
|
||||||
|
datatype=XSD + "dateTime")}
|
||||||
|
)
|
||||||
|
assert result.type == LITERAL
|
||||||
|
assert result.value == "+00:00"
|
||||||
|
|
||||||
|
def test_tz_no_timezone(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("TZ", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(
|
||||||
|
expr, {"x": lit("2024-03-15T10:30:45",
|
||||||
|
datatype=XSD + "dateTime")}
|
||||||
|
)
|
||||||
|
assert result.value == ""
|
||||||
|
|
||||||
|
def test_rand(self):
|
||||||
|
expr = self._make_builtin("RAND")
|
||||||
|
result = evaluate_expression(expr, {})
|
||||||
|
assert isinstance(result, float)
|
||||||
|
assert 0.0 <= result < 1.0
|
||||||
|
|
||||||
|
def test_uuid(self):
|
||||||
|
import re as re_mod
|
||||||
|
expr = self._make_builtin("UUID")
|
||||||
|
result = evaluate_expression(expr, {})
|
||||||
|
assert result.type == IRI
|
||||||
|
assert result.iri.startswith("urn:uuid:")
|
||||||
|
uuid_part = result.iri[len("urn:uuid:"):]
|
||||||
|
assert re_mod.match(
|
||||||
|
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
|
||||||
|
uuid_part
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_struuid(self):
|
||||||
|
import re as re_mod
|
||||||
|
expr = self._make_builtin("STRUUID")
|
||||||
|
result = evaluate_expression(expr, {})
|
||||||
|
assert result.type == LITERAL
|
||||||
|
assert re_mod.match(
|
||||||
|
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
|
||||||
|
result.value
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_md5(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("MD5", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||||
|
assert result.type == LITERAL
|
||||||
|
assert result.value == "5d41402abc4b2a76b9719d911017c592"
|
||||||
|
|
||||||
|
def test_sha1(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("SHA1", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||||
|
assert result.type == LITERAL
|
||||||
|
assert result.value == "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d"
|
||||||
|
|
||||||
|
def test_sha256(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("SHA256", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||||
|
assert result.type == LITERAL
|
||||||
|
assert result.value == (
|
||||||
|
"2cf24dba5fb0a30e26e83b2ac5b9e29e"
|
||||||
|
"1b161e5c1fa7425e73043362938b9824"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_sha512(self):
|
||||||
|
from rdflib.term import Variable
|
||||||
|
expr = self._make_builtin("SHA512", arg=Variable("x"))
|
||||||
|
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||||
|
assert result.type == LITERAL
|
||||||
|
assert len(result.value) == 128
|
||||||
|
|
||||||
|
def test_exists_with_callback(self):
|
||||||
|
from rdflib.plugins.sparql.parserutils import CompValue
|
||||||
|
graph = CompValue("BGP")
|
||||||
|
expr = self._make_builtin("EXISTS", graph=graph)
|
||||||
|
cb = lambda g, s: True
|
||||||
|
result = evaluate_expression(expr, {}, exists_cb=cb)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
def test_exists_callback_false(self):
|
||||||
|
from rdflib.plugins.sparql.parserutils import CompValue
|
||||||
|
graph = CompValue("BGP")
|
||||||
|
expr = self._make_builtin("EXISTS", graph=graph)
|
||||||
|
cb = lambda g, s: False
|
||||||
|
result = evaluate_expression(expr, {}, exists_cb=cb)
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
def test_notexists_with_callback(self):
|
||||||
|
from rdflib.plugins.sparql.parserutils import CompValue
|
||||||
|
graph = CompValue("BGP")
|
||||||
|
expr = self._make_builtin("NOTEXISTS", graph=graph)
|
||||||
|
cb = lambda g, s: True
|
||||||
|
result = evaluate_expression(expr, {}, exists_cb=cb)
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
def test_notexists_callback_false(self):
|
||||||
|
from rdflib.plugins.sparql.parserutils import CompValue
|
||||||
|
graph = CompValue("BGP")
|
||||||
|
expr = self._make_builtin("NOTEXISTS", graph=graph)
|
||||||
|
cb = lambda g, s: False
|
||||||
|
result = evaluate_expression(expr, {}, exists_cb=cb)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
class TestEffectiveBoolean:
|
class TestEffectiveBoolean:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ Tests for SPARQL solution sequence operations.
|
||||||
import pytest
|
import pytest
|
||||||
from trustgraph.schema import Term, IRI, LITERAL
|
from trustgraph.schema import Term, IRI, LITERAL
|
||||||
from trustgraph.query.sparql.solutions import (
|
from trustgraph.query.sparql.solutions import (
|
||||||
hash_join, left_join, union, project, distinct,
|
hash_join, left_join, minus, union, project, distinct,
|
||||||
order_by, slice_solutions, _terms_equal, _compatible,
|
order_by, slice_solutions, _terms_equal, _compatible,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -311,6 +311,30 @@ class TestOrderBy:
|
||||||
result = order_by(solutions, [])
|
result = order_by(solutions, [])
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
|
|
||||||
|
def test_order_by_numeric_literals(self):
|
||||||
|
solutions = [
|
||||||
|
{"year": lit("1950")},
|
||||||
|
{"year": lit("700")},
|
||||||
|
{"year": lit("2000")},
|
||||||
|
{"year": lit("450")},
|
||||||
|
{"year": lit("1200")},
|
||||||
|
]
|
||||||
|
key_fns = [(lambda sol: sol.get("year"), True)]
|
||||||
|
result = order_by(solutions, key_fns)
|
||||||
|
values = [s["year"].value for s in result]
|
||||||
|
assert values == ["450", "700", "1200", "1950", "2000"]
|
||||||
|
|
||||||
|
def test_order_by_numeric_descending(self):
|
||||||
|
solutions = [
|
||||||
|
{"year": lit("1950")},
|
||||||
|
{"year": lit("700")},
|
||||||
|
{"year": lit("2000")},
|
||||||
|
]
|
||||||
|
key_fns = [(lambda sol: sol.get("year"), False)]
|
||||||
|
result = order_by(solutions, key_fns)
|
||||||
|
values = [s["year"].value for s in result]
|
||||||
|
assert values == ["2000", "1950", "700"]
|
||||||
|
|
||||||
|
|
||||||
class TestSlice:
|
class TestSlice:
|
||||||
|
|
||||||
|
|
@ -343,3 +367,37 @@ class TestSlice:
|
||||||
solutions = [{"s": alice}, {"s": bob}]
|
solutions = [{"s": alice}, {"s": bob}]
|
||||||
result = slice_solutions(solutions)
|
result = slice_solutions(solutions)
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestMinus:
|
||||||
|
|
||||||
|
def test_removes_compatible(self, alice, bob):
|
||||||
|
left = [{"s": alice}, {"s": bob}]
|
||||||
|
right = [{"s": alice}]
|
||||||
|
result = minus(left, right)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["s"].iri == "http://example.com/bob"
|
||||||
|
|
||||||
|
def test_empty_right_preserves_all(self, alice, bob):
|
||||||
|
left = [{"s": alice}, {"s": bob}]
|
||||||
|
result = minus(left, [])
|
||||||
|
assert len(result) == 2
|
||||||
|
|
||||||
|
def test_no_shared_variables_preserves_all(self, alice, bob):
|
||||||
|
left = [{"s": alice}]
|
||||||
|
right = [{"t": bob}]
|
||||||
|
result = minus(left, right)
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
def test_all_removed(self, alice):
|
||||||
|
left = [{"s": alice}]
|
||||||
|
right = [{"s": alice}]
|
||||||
|
result = minus(left, right)
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
def test_partial_shared_variables(self, alice, bob):
|
||||||
|
left = [{"s": alice, "p": lit("x")}, {"s": bob, "p": lit("y")}]
|
||||||
|
right = [{"s": alice}]
|
||||||
|
result = minus(left, right)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["s"].iri == "http://example.com/bob"
|
||||||
|
|
|
||||||
0
tests/unit/test_rev_gateway/__init__.py
Normal file
0
tests/unit/test_rev_gateway/__init__.py
Normal file
|
|
@ -3,275 +3,279 @@ Tests for Reverse Gateway Dispatcher
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import MagicMock, AsyncMock, patch
|
import asyncio
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch, ANY
|
||||||
|
|
||||||
from trustgraph.rev_gateway.dispatcher import WebSocketResponder, MessageDispatcher
|
from trustgraph.rev_gateway.dispatcher import MessageDispatcher
|
||||||
|
|
||||||
|
|
||||||
class TestWebSocketResponder:
|
|
||||||
"""Test cases for WebSocketResponder class"""
|
|
||||||
|
|
||||||
def test_websocket_responder_initialization(self):
|
|
||||||
"""Test WebSocketResponder initialization"""
|
|
||||||
responder = WebSocketResponder()
|
|
||||||
|
|
||||||
assert responder.response is None
|
|
||||||
assert responder.completed is False
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_websocket_responder_send_method(self):
|
|
||||||
"""Test WebSocketResponder send method"""
|
|
||||||
responder = WebSocketResponder()
|
|
||||||
|
|
||||||
test_response = {"data": "test response"}
|
|
||||||
|
|
||||||
# Call send method
|
|
||||||
await responder.send(test_response)
|
|
||||||
|
|
||||||
# Verify response was stored
|
|
||||||
assert responder.response == test_response
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_websocket_responder_call_method(self):
|
|
||||||
"""Test WebSocketResponder __call__ method"""
|
|
||||||
responder = WebSocketResponder()
|
|
||||||
|
|
||||||
test_response = {"result": "success"}
|
|
||||||
test_completed = True
|
|
||||||
|
|
||||||
# Call the responder
|
|
||||||
await responder(test_response, test_completed)
|
|
||||||
|
|
||||||
# Verify response and completed status were set
|
|
||||||
assert responder.response == test_response
|
|
||||||
assert responder.completed == test_completed
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_websocket_responder_call_method_with_false_completion(self):
|
|
||||||
"""Test WebSocketResponder __call__ method with incomplete response"""
|
|
||||||
responder = WebSocketResponder()
|
|
||||||
|
|
||||||
test_response = {"partial": "data"}
|
|
||||||
test_completed = False
|
|
||||||
|
|
||||||
# Call the responder
|
|
||||||
await responder(test_response, test_completed)
|
|
||||||
|
|
||||||
# Verify response was set and completed is True (since send() always sets completed=True)
|
|
||||||
assert responder.response == test_response
|
|
||||||
assert responder.completed is True
|
|
||||||
|
|
||||||
|
|
||||||
class TestMessageDispatcher:
|
class TestMessageDispatcher:
|
||||||
"""Test cases for MessageDispatcher class"""
|
"""Test cases for MessageDispatcher class"""
|
||||||
|
|
||||||
def test_message_dispatcher_initialization_with_defaults(self):
|
def test_message_dispatcher_initialization_with_defaults(self):
|
||||||
"""Test MessageDispatcher initialization with default parameters"""
|
|
||||||
dispatcher = MessageDispatcher()
|
dispatcher = MessageDispatcher()
|
||||||
|
|
||||||
assert dispatcher.max_workers == 10
|
assert dispatcher.max_workers == 10
|
||||||
assert dispatcher.semaphore._value == 10
|
assert dispatcher.semaphore._value == 10
|
||||||
assert dispatcher.active_tasks == set()
|
assert dispatcher.active_tasks == set()
|
||||||
assert dispatcher.backend is None
|
assert dispatcher.backend is None
|
||||||
|
assert dispatcher.auth is None
|
||||||
assert dispatcher.dispatcher_manager is None
|
assert dispatcher.dispatcher_manager is None
|
||||||
assert len(dispatcher.service_mapping) > 0
|
assert len(dispatcher.service_mapping) > 0
|
||||||
|
|
||||||
def test_message_dispatcher_initialization_with_custom_workers(self):
|
def test_message_dispatcher_initialization_with_custom_workers(self):
|
||||||
"""Test MessageDispatcher initialization with custom max_workers"""
|
|
||||||
dispatcher = MessageDispatcher(max_workers=5)
|
dispatcher = MessageDispatcher(max_workers=5)
|
||||||
|
|
||||||
assert dispatcher.max_workers == 5
|
assert dispatcher.max_workers == 5
|
||||||
assert dispatcher.semaphore._value == 5
|
assert dispatcher.semaphore._value == 5
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.dispatcher.DispatcherManager')
|
@patch('trustgraph.rev_gateway.dispatcher.DispatcherManager')
|
||||||
def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager):
|
def test_message_dispatcher_initialization_with_backend(
|
||||||
"""Test MessageDispatcher initialization with pulsar_client and config_receiver"""
|
self, mock_dispatcher_manager,
|
||||||
|
):
|
||||||
mock_backend = MagicMock()
|
mock_backend = MagicMock()
|
||||||
mock_config_receiver = MagicMock()
|
mock_config_receiver = MagicMock()
|
||||||
|
mock_auth = MagicMock()
|
||||||
mock_dispatcher_instance = MagicMock()
|
mock_dispatcher_instance = MagicMock()
|
||||||
mock_dispatcher_manager.return_value = mock_dispatcher_instance
|
mock_dispatcher_manager.return_value = mock_dispatcher_instance
|
||||||
|
|
||||||
dispatcher = MessageDispatcher(
|
dispatcher = MessageDispatcher(
|
||||||
max_workers=8,
|
max_workers=8,
|
||||||
config_receiver=mock_config_receiver,
|
config_receiver=mock_config_receiver,
|
||||||
backend=mock_backend
|
backend=mock_backend,
|
||||||
|
auth=mock_auth,
|
||||||
|
timeout=300,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert dispatcher.max_workers == 8
|
assert dispatcher.max_workers == 8
|
||||||
assert dispatcher.backend == mock_backend
|
assert dispatcher.backend == mock_backend
|
||||||
|
assert dispatcher.auth == mock_auth
|
||||||
assert dispatcher.dispatcher_manager == mock_dispatcher_instance
|
assert dispatcher.dispatcher_manager == mock_dispatcher_instance
|
||||||
mock_dispatcher_manager.assert_called_once_with(
|
mock_dispatcher_manager.assert_called_once_with(
|
||||||
mock_backend, mock_config_receiver, prefix="rev-gateway"
|
mock_backend, mock_config_receiver,
|
||||||
|
auth=mock_auth, prefix="rev-gateway", timeout=300,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_message_dispatcher_service_mapping(self):
|
def test_message_dispatcher_service_mapping(self):
|
||||||
"""Test MessageDispatcher service mapping contains expected services"""
|
|
||||||
dispatcher = MessageDispatcher()
|
dispatcher = MessageDispatcher()
|
||||||
|
|
||||||
expected_services = [
|
expected_services = [
|
||||||
"text-completion", "graph-rag", "agent", "embeddings",
|
"text-completion", "graph-rag", "agent", "embeddings",
|
||||||
"graph-embeddings", "triples", "document-load", "text-load",
|
"graph-embeddings", "triples", "document-load", "text-load",
|
||||||
"flow", "knowledge", "config", "librarian", "document-rag"
|
"flow", "knowledge", "config", "librarian", "document-rag",
|
||||||
]
|
]
|
||||||
|
|
||||||
for service in expected_services:
|
for service in expected_services:
|
||||||
assert service in dispatcher.service_mapping
|
assert service in dispatcher.service_mapping
|
||||||
|
|
||||||
# Test specific mappings
|
|
||||||
assert dispatcher.service_mapping["text-completion"] == "text-completion"
|
|
||||||
assert dispatcher.service_mapping["document-load"] == "document"
|
assert dispatcher.service_mapping["document-load"] == "document"
|
||||||
assert dispatcher.service_mapping["text-load"] == "text-document"
|
assert dispatcher.service_mapping["text-load"] == "text-document"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_dispatcher_handle_message_without_dispatcher_manager(self):
|
async def test_handle_message_without_dispatcher_manager(self):
|
||||||
"""Test MessageDispatcher handle_message without dispatcher manager"""
|
|
||||||
dispatcher = MessageDispatcher()
|
dispatcher = MessageDispatcher()
|
||||||
|
dispatcher.auth = MagicMock()
|
||||||
|
dispatcher.auth.authenticate = AsyncMock(
|
||||||
|
return_value=MagicMock(workspace="default")
|
||||||
|
)
|
||||||
|
|
||||||
test_message = {
|
sender = AsyncMock()
|
||||||
"id": "test-123",
|
|
||||||
"service": "test-service",
|
|
||||||
"request": {"data": "test"}
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await dispatcher.handle_message(test_message)
|
await dispatcher.handle_message(
|
||||||
|
{"id": "test-1", "service": "test", "request": {}},
|
||||||
|
sender,
|
||||||
|
)
|
||||||
|
|
||||||
assert result["id"] == "test-123"
|
sender.assert_called_once()
|
||||||
assert "error" in result["response"]
|
sent = sender.call_args[0][0]
|
||||||
assert "DispatcherManager not available" in result["response"]["error"]
|
assert sent["id"] == "test-1"
|
||||||
|
assert sent["error"]["message"] == "DispatcherManager not available"
|
||||||
|
assert sent["error"]["type"] == "error"
|
||||||
|
assert sent["complete"] is True
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_dispatcher_handle_message_with_exception(self):
|
async def test_handle_message_auth_failure(self):
|
||||||
"""Test MessageDispatcher handle_message with exception during processing"""
|
dispatcher = MessageDispatcher()
|
||||||
mock_dispatcher_manager = MagicMock()
|
dispatcher.auth = MagicMock()
|
||||||
mock_dispatcher_manager.invoke_global_service = AsyncMock(side_effect=Exception("Test error"))
|
dispatcher.auth.authenticate = AsyncMock(
|
||||||
|
side_effect=Exception("auth failure")
|
||||||
|
)
|
||||||
|
dispatcher.dispatcher_manager = MagicMock()
|
||||||
|
|
||||||
|
sender = AsyncMock()
|
||||||
|
|
||||||
|
await dispatcher.handle_message(
|
||||||
|
{"id": "test-2", "token": "bad", "service": "test", "request": {}},
|
||||||
|
sender,
|
||||||
|
)
|
||||||
|
|
||||||
|
sender.assert_called_once()
|
||||||
|
sent = sender.call_args[0][0]
|
||||||
|
assert sent["id"] == "test-2"
|
||||||
|
assert "auth failure" in sent["error"]["message"]
|
||||||
|
assert sent["complete"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_message_global_service(self):
|
||||||
|
mock_dm = MagicMock()
|
||||||
|
mock_dm.invoke_global_service = AsyncMock()
|
||||||
|
|
||||||
dispatcher = MessageDispatcher()
|
dispatcher = MessageDispatcher()
|
||||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
dispatcher.dispatcher_manager = mock_dm
|
||||||
|
dispatcher.auth = MagicMock()
|
||||||
|
dispatcher.auth.authenticate = AsyncMock(
|
||||||
|
return_value=MagicMock(workspace="ws1")
|
||||||
|
)
|
||||||
|
|
||||||
test_message = {
|
sender = AsyncMock()
|
||||||
"id": "test-456",
|
|
||||||
|
with patch(
|
||||||
|
'trustgraph.gateway.dispatch.manager.global_dispatchers',
|
||||||
|
{"text-completion": True},
|
||||||
|
):
|
||||||
|
await dispatcher.handle_message(
|
||||||
|
{
|
||||||
|
"id": "test-3",
|
||||||
|
"token": "tg_key",
|
||||||
"service": "text-completion",
|
"service": "text-completion",
|
||||||
"request": {"prompt": "test"}
|
"request": {"prompt": "hello"},
|
||||||
}
|
},
|
||||||
|
sender,
|
||||||
|
)
|
||||||
|
|
||||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
|
mock_dm.invoke_global_service.assert_called_once()
|
||||||
result = await dispatcher.handle_message(test_message)
|
args, kwargs = mock_dm.invoke_global_service.call_args
|
||||||
|
assert args[0] == {"prompt": "hello"}
|
||||||
assert result["id"] == "test-456"
|
assert args[2] == "text-completion"
|
||||||
assert "error" in result["response"]
|
assert kwargs["workspace"] == "ws1"
|
||||||
assert "Test error" in result["response"]["error"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_dispatcher_handle_message_global_service(self):
|
async def test_handle_message_flow_service(self):
|
||||||
"""Test MessageDispatcher handle_message with global service"""
|
mock_dm = MagicMock()
|
||||||
mock_dispatcher_manager = MagicMock()
|
mock_dm.invoke_flow_service = AsyncMock()
|
||||||
mock_dispatcher_manager.invoke_global_service = AsyncMock()
|
|
||||||
mock_responder = MagicMock()
|
|
||||||
mock_responder.completed = True
|
|
||||||
mock_responder.response = {"result": "success"}
|
|
||||||
|
|
||||||
dispatcher = MessageDispatcher()
|
dispatcher = MessageDispatcher()
|
||||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
dispatcher.dispatcher_manager = mock_dm
|
||||||
|
dispatcher.auth = MagicMock()
|
||||||
|
dispatcher.auth.authenticate = AsyncMock(
|
||||||
|
return_value=MagicMock(workspace="ws2")
|
||||||
|
)
|
||||||
|
|
||||||
test_message = {
|
sender = AsyncMock()
|
||||||
"id": "test-789",
|
|
||||||
"service": "text-completion",
|
|
||||||
"request": {"prompt": "hello"}
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
|
with patch(
|
||||||
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
|
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
|
||||||
result = await dispatcher.handle_message(test_message)
|
):
|
||||||
|
await dispatcher.handle_message(
|
||||||
assert result["id"] == "test-789"
|
{
|
||||||
assert result["response"] == {"result": "success"}
|
"id": "test-4",
|
||||||
mock_dispatcher_manager.invoke_global_service.assert_called_once()
|
"token": "tg_key",
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_message_dispatcher_handle_message_flow_service(self):
|
|
||||||
"""Test MessageDispatcher handle_message with flow service"""
|
|
||||||
mock_dispatcher_manager = MagicMock()
|
|
||||||
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
|
|
||||||
mock_responder = MagicMock()
|
|
||||||
mock_responder.completed = True
|
|
||||||
mock_responder.response = {"data": "flow_result"}
|
|
||||||
|
|
||||||
dispatcher = MessageDispatcher()
|
|
||||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
|
||||||
|
|
||||||
test_message = {
|
|
||||||
"id": "test-flow-123",
|
|
||||||
"service": "document-rag",
|
"service": "document-rag",
|
||||||
"request": {"query": "test"},
|
"request": {"query": "test"},
|
||||||
"flow": "custom-flow"
|
"flow": "my-flow",
|
||||||
}
|
},
|
||||||
|
sender,
|
||||||
|
)
|
||||||
|
|
||||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
|
mock_dm.invoke_flow_service.assert_called_once_with(
|
||||||
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
|
{"query": "test"}, ANY, "ws2", "my-flow", "document-rag",
|
||||||
result = await dispatcher.handle_message(test_message)
|
|
||||||
|
|
||||||
assert result["id"] == "test-flow-123"
|
|
||||||
assert result["response"] == {"data": "flow_result"}
|
|
||||||
mock_dispatcher_manager.invoke_flow_service.assert_called_once_with(
|
|
||||||
{"query": "test"}, mock_responder, "custom-flow", "document-rag"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_dispatcher_handle_message_incomplete_response(self):
|
async def test_handle_message_responder_sends_frames(self):
|
||||||
"""Test MessageDispatcher handle_message with incomplete response"""
|
mock_dm = MagicMock()
|
||||||
mock_dispatcher_manager = MagicMock()
|
|
||||||
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
|
async def fake_invoke(data, responder, svc, workspace=None):
|
||||||
mock_responder = MagicMock()
|
await responder({"partial": 1}, False)
|
||||||
mock_responder.completed = False
|
await responder({"partial": 2}, True)
|
||||||
mock_responder.response = None
|
|
||||||
|
mock_dm.invoke_global_service = AsyncMock(side_effect=fake_invoke)
|
||||||
|
|
||||||
dispatcher = MessageDispatcher()
|
dispatcher = MessageDispatcher()
|
||||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
dispatcher.dispatcher_manager = mock_dm
|
||||||
|
dispatcher.auth = MagicMock()
|
||||||
|
dispatcher.auth.authenticate = AsyncMock(
|
||||||
|
return_value=MagicMock(workspace="ws1")
|
||||||
|
)
|
||||||
|
|
||||||
test_message = {
|
sender = AsyncMock()
|
||||||
"id": "test-incomplete",
|
|
||||||
"service": "agent",
|
with patch(
|
||||||
"request": {"input": "test"}
|
'trustgraph.gateway.dispatch.manager.global_dispatchers',
|
||||||
|
{"text-completion": True},
|
||||||
|
):
|
||||||
|
await dispatcher.handle_message(
|
||||||
|
{
|
||||||
|
"id": "test-5",
|
||||||
|
"token": "tg_key",
|
||||||
|
"service": "text-completion",
|
||||||
|
"request": {"prompt": "hi"},
|
||||||
|
},
|
||||||
|
sender,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert sender.call_count == 2
|
||||||
|
first = sender.call_args_list[0][0][0]
|
||||||
|
second = sender.call_args_list[1][0][0]
|
||||||
|
|
||||||
|
assert first == {
|
||||||
|
"id": "test-5", "response": {"partial": 1}, "complete": False,
|
||||||
|
}
|
||||||
|
assert second == {
|
||||||
|
"id": "test-5", "response": {"partial": 2}, "complete": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
|
|
||||||
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
|
|
||||||
result = await dispatcher.handle_message(test_message)
|
|
||||||
|
|
||||||
assert result["id"] == "test-incomplete"
|
|
||||||
assert result["response"] == {"error": "No response received"}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_dispatcher_shutdown(self):
|
async def test_handle_message_workspace_from_identity(self):
|
||||||
"""Test MessageDispatcher shutdown method"""
|
mock_dm = MagicMock()
|
||||||
import asyncio
|
mock_dm.invoke_flow_service = AsyncMock()
|
||||||
|
|
||||||
|
dispatcher = MessageDispatcher()
|
||||||
|
dispatcher.dispatcher_manager = mock_dm
|
||||||
|
dispatcher.auth = MagicMock()
|
||||||
|
dispatcher.auth.authenticate = AsyncMock(
|
||||||
|
return_value=MagicMock(workspace="derived-ws")
|
||||||
|
)
|
||||||
|
|
||||||
|
sender = AsyncMock()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
|
||||||
|
):
|
||||||
|
await dispatcher.handle_message(
|
||||||
|
{
|
||||||
|
"id": "test-6",
|
||||||
|
"token": "tg_key",
|
||||||
|
"service": "agent",
|
||||||
|
"request": {"question": "test"},
|
||||||
|
"flow": "default",
|
||||||
|
},
|
||||||
|
sender,
|
||||||
|
)
|
||||||
|
|
||||||
|
args = mock_dm.invoke_flow_service.call_args[0]
|
||||||
|
assert args[2] == "derived-ws"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_shutdown(self):
|
||||||
dispatcher = MessageDispatcher()
|
dispatcher = MessageDispatcher()
|
||||||
|
|
||||||
# Create actual async tasks
|
|
||||||
async def dummy_task():
|
async def dummy_task():
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
return "done"
|
|
||||||
|
|
||||||
task1 = asyncio.create_task(dummy_task())
|
task1 = asyncio.create_task(dummy_task())
|
||||||
task2 = asyncio.create_task(dummy_task())
|
task2 = asyncio.create_task(dummy_task())
|
||||||
dispatcher.active_tasks = {task1, task2}
|
dispatcher.active_tasks = {task1, task2}
|
||||||
|
|
||||||
# Call shutdown
|
|
||||||
await dispatcher.shutdown()
|
await dispatcher.shutdown()
|
||||||
|
|
||||||
# Verify tasks were completed
|
|
||||||
assert task1.done()
|
assert task1.done()
|
||||||
assert task2.done()
|
assert task2.done()
|
||||||
assert len(dispatcher.active_tasks) == 2 # Tasks remain in set but are completed
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_dispatcher_shutdown_with_no_tasks(self):
|
async def test_shutdown_with_no_tasks(self):
|
||||||
"""Test MessageDispatcher shutdown with no active tasks"""
|
|
||||||
dispatcher = MessageDispatcher()
|
dispatcher = MessageDispatcher()
|
||||||
|
|
||||||
# Call shutdown with no active tasks
|
|
||||||
await dispatcher.shutdown()
|
await dispatcher.shutdown()
|
||||||
|
|
||||||
# Should complete without error
|
|
||||||
assert dispatcher.active_tasks == set()
|
assert dispatcher.active_tasks == set()
|
||||||
|
|
@ -8,21 +8,37 @@ from unittest.mock import MagicMock, AsyncMock, patch, Mock
|
||||||
from aiohttp import WSMsgType, ClientWebSocketResponse
|
from aiohttp import WSMsgType, ClientWebSocketResponse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from trustgraph.rev_gateway.service import ReverseGateway, parse_args, run
|
from trustgraph.rev_gateway.service import ReverseGateway, run
|
||||||
|
|
||||||
|
|
||||||
|
MOCK_PATCHES = [
|
||||||
|
'trustgraph.rev_gateway.service.IamAuth',
|
||||||
|
'trustgraph.rev_gateway.service.ConfigReceiver',
|
||||||
|
'trustgraph.rev_gateway.service.MessageDispatcher',
|
||||||
|
'trustgraph.rev_gateway.service.get_pubsub',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def make_gateway(**overrides):
|
||||||
|
config = {"websocket_uri": "ws://localhost:7650/out"}
|
||||||
|
config.update(overrides)
|
||||||
|
return ReverseGateway(**config)
|
||||||
|
|
||||||
|
|
||||||
class TestReverseGateway:
|
class TestReverseGateway:
|
||||||
"""Test cases for ReverseGateway class"""
|
"""Test cases for ReverseGateway class"""
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
"""Test ReverseGateway initialization with default parameters"""
|
def test_reverse_gateway_initialization_defaults(
|
||||||
mock_backend = MagicMock()
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
mock_get_pubsub.return_value = mock_backend
|
mock_config_receiver, mock_iam_auth,
|
||||||
|
):
|
||||||
|
mock_get_pubsub.return_value = MagicMock()
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = make_gateway()
|
||||||
|
|
||||||
assert gateway.websocket_uri == "ws://localhost:7650/out"
|
assert gateway.websocket_uri == "ws://localhost:7650/out"
|
||||||
assert gateway.host == "localhost"
|
assert gateway.host == "localhost"
|
||||||
|
|
@ -33,23 +49,20 @@ class TestReverseGateway:
|
||||||
assert gateway.max_workers == 10
|
assert gateway.max_workers == 10
|
||||||
assert gateway.running is False
|
assert gateway.running is False
|
||||||
assert gateway.reconnect_delay == 3.0
|
assert gateway.reconnect_delay == 3.0
|
||||||
assert gateway.pulsar_host == "pulsar://pulsar:6650"
|
|
||||||
assert gateway.pulsar_api_key is None
|
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
"""Test ReverseGateway initialization with custom parameters"""
|
def test_reverse_gateway_initialization_custom_params(
|
||||||
mock_backend = MagicMock()
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
mock_get_pubsub.return_value = mock_backend
|
mock_config_receiver, mock_iam_auth,
|
||||||
|
):
|
||||||
|
mock_get_pubsub.return_value = MagicMock()
|
||||||
|
|
||||||
gateway = ReverseGateway(
|
gateway = make_gateway(
|
||||||
websocket_uri="wss://example.com:8080/websocket",
|
websocket_uri="wss://example.com:8080/websocket",
|
||||||
max_workers=20,
|
max_workers=20,
|
||||||
pulsar_host="pulsar://custom:6650",
|
|
||||||
pulsar_api_key="test-key",
|
|
||||||
pulsar_listener="test-listener"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert gateway.websocket_uri == "wss://example.com:8080/websocket"
|
assert gateway.websocket_uri == "wss://example.com:8080/websocket"
|
||||||
|
|
@ -59,75 +72,99 @@ class TestReverseGateway:
|
||||||
assert gateway.path == "/websocket"
|
assert gateway.path == "/websocket"
|
||||||
assert gateway.url == "wss://example.com:8080/websocket"
|
assert gateway.url == "wss://example.com:8080/websocket"
|
||||||
assert gateway.max_workers == 20
|
assert gateway.max_workers == 20
|
||||||
assert gateway.pulsar_host == "pulsar://custom:6650"
|
|
||||||
assert gateway.pulsar_api_key == "test-key"
|
|
||||||
assert gateway.pulsar_listener == "test-listener"
|
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
def test_reverse_gateway_initialization_with_missing_path(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
"""Test ReverseGateway initialization with WebSocket URI missing path"""
|
def test_reverse_gateway_initialization_with_missing_path(
|
||||||
mock_backend = MagicMock()
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
mock_get_pubsub.return_value = mock_backend
|
mock_config_receiver, mock_iam_auth,
|
||||||
|
):
|
||||||
|
mock_get_pubsub.return_value = MagicMock()
|
||||||
|
|
||||||
gateway = ReverseGateway(websocket_uri="ws://example.com")
|
gateway = make_gateway(websocket_uri="ws://example.com")
|
||||||
|
|
||||||
assert gateway.path == "/ws"
|
assert gateway.path == "/ws"
|
||||||
assert gateway.url == "ws://example.com/ws"
|
assert gateway.url == "ws://example.com/ws"
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
"""Test ReverseGateway initialization with invalid WebSocket scheme"""
|
def test_reverse_gateway_initialization_invalid_scheme(
|
||||||
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
|
mock_config_receiver, mock_iam_auth,
|
||||||
|
):
|
||||||
with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"):
|
with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"):
|
||||||
ReverseGateway(websocket_uri="http://example.com")
|
make_gateway(websocket_uri="http://example.com")
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
"""Test ReverseGateway initialization with missing hostname"""
|
def test_reverse_gateway_initialization_missing_hostname(
|
||||||
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
|
mock_config_receiver, mock_iam_auth,
|
||||||
|
):
|
||||||
with pytest.raises(ValueError, match="WebSocket URI must include hostname"):
|
with pytest.raises(ValueError, match="WebSocket URI must include hostname"):
|
||||||
ReverseGateway(websocket_uri="ws://")
|
make_gateway(websocket_uri="ws://")
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
"""Test ReverseGateway creates backend with authentication"""
|
def test_reverse_gateway_iam_auth_created(
|
||||||
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
|
mock_config_receiver, mock_iam_auth,
|
||||||
|
):
|
||||||
mock_backend = MagicMock()
|
mock_backend = MagicMock()
|
||||||
mock_get_pubsub.return_value = mock_backend
|
mock_get_pubsub.return_value = mock_backend
|
||||||
|
|
||||||
gateway = ReverseGateway(
|
gateway = make_gateway(id="test-rev-gw")
|
||||||
pulsar_api_key="test-key",
|
|
||||||
pulsar_listener="test-listener"
|
mock_iam_auth.assert_called_once_with(
|
||||||
|
backend=mock_backend,
|
||||||
|
id="test-rev-gw",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify get_pubsub was called with the correct parameters
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
mock_get_pubsub.assert_called_once_with(
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
pulsar_host="pulsar://pulsar:6650",
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
pulsar_api_key="test-key",
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
pulsar_listener="test-listener"
|
def test_reverse_gateway_config_receiver_gets_auth(
|
||||||
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
|
mock_config_receiver, mock_iam_auth,
|
||||||
|
):
|
||||||
|
mock_backend = MagicMock()
|
||||||
|
mock_get_pubsub.return_value = mock_backend
|
||||||
|
mock_auth_instance = MagicMock()
|
||||||
|
mock_iam_auth.return_value = mock_auth_instance
|
||||||
|
|
||||||
|
gateway = make_gateway()
|
||||||
|
|
||||||
|
mock_config_receiver.assert_called_once_with(
|
||||||
|
mock_backend, auth=mock_auth_instance,
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
@patch('trustgraph.rev_gateway.service.ClientSession')
|
@patch('trustgraph.rev_gateway.service.ClientSession')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_connect_success(
|
||||||
"""Test ReverseGateway successful connection"""
|
self, mock_session_class, mock_get_pubsub,
|
||||||
mock_backend = MagicMock()
|
mock_dispatcher, mock_config_receiver, mock_iam_auth,
|
||||||
mock_get_pubsub.return_value = mock_backend
|
):
|
||||||
|
mock_get_pubsub.return_value = MagicMock()
|
||||||
|
|
||||||
mock_session = AsyncMock()
|
mock_session = AsyncMock()
|
||||||
mock_ws = AsyncMock()
|
mock_ws = AsyncMock()
|
||||||
mock_session.ws_connect.return_value = mock_ws
|
mock_session.ws_connect.return_value = mock_ws
|
||||||
mock_session_class.return_value = mock_session
|
mock_session_class.return_value = mock_session
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = make_gateway()
|
||||||
|
|
||||||
result = await gateway.connect()
|
result = await gateway.connect()
|
||||||
|
|
||||||
|
|
@ -136,38 +173,41 @@ class TestReverseGateway:
|
||||||
assert gateway.ws == mock_ws
|
assert gateway.ws == mock_ws
|
||||||
mock_session.ws_connect.assert_called_once_with(gateway.url)
|
mock_session.ws_connect.assert_called_once_with(gateway.url)
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
@patch('trustgraph.rev_gateway.service.ClientSession')
|
@patch('trustgraph.rev_gateway.service.ClientSession')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_connect_failure(
|
||||||
"""Test ReverseGateway connection failure"""
|
self, mock_session_class, mock_get_pubsub,
|
||||||
mock_backend = MagicMock()
|
mock_dispatcher, mock_config_receiver, mock_iam_auth,
|
||||||
mock_get_pubsub.return_value = mock_backend
|
):
|
||||||
|
mock_get_pubsub.return_value = MagicMock()
|
||||||
|
|
||||||
mock_session = AsyncMock()
|
mock_session = AsyncMock()
|
||||||
mock_session.ws_connect.side_effect = Exception("Connection failed")
|
mock_session.ws_connect.side_effect = Exception("Connection failed")
|
||||||
mock_session_class.return_value = mock_session
|
mock_session_class.return_value = mock_session
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = make_gateway()
|
||||||
|
|
||||||
result = await gateway.connect()
|
result = await gateway.connect()
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_disconnect(
|
||||||
"""Test ReverseGateway disconnect"""
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
mock_backend = MagicMock()
|
mock_config_receiver, mock_iam_auth,
|
||||||
mock_get_pubsub.return_value = mock_backend
|
):
|
||||||
|
mock_get_pubsub.return_value = MagicMock()
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = make_gateway()
|
||||||
|
|
||||||
# Mock websocket and session
|
|
||||||
mock_ws = AsyncMock()
|
mock_ws = AsyncMock()
|
||||||
mock_ws.closed = False
|
mock_ws.closed = False
|
||||||
mock_session = AsyncMock()
|
mock_session = AsyncMock()
|
||||||
|
|
@ -183,18 +223,19 @@ class TestReverseGateway:
|
||||||
assert gateway.ws is None
|
assert gateway.ws is None
|
||||||
assert gateway.session is None
|
assert gateway.session is None
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_send_message(
|
||||||
"""Test ReverseGateway send message"""
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
mock_backend = MagicMock()
|
mock_config_receiver, mock_iam_auth,
|
||||||
mock_get_pubsub.return_value = mock_backend
|
):
|
||||||
|
mock_get_pubsub.return_value = MagicMock()
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = make_gateway()
|
||||||
|
|
||||||
# Mock websocket
|
|
||||||
mock_ws = AsyncMock()
|
mock_ws = AsyncMock()
|
||||||
mock_ws.closed = False
|
mock_ws.closed = False
|
||||||
gateway.ws = mock_ws
|
gateway.ws = mock_ws
|
||||||
|
|
@ -205,18 +246,19 @@ class TestReverseGateway:
|
||||||
|
|
||||||
mock_ws.send_str.assert_called_once_with(json.dumps(test_message))
|
mock_ws.send_str.assert_called_once_with(json.dumps(test_message))
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_send_message_closed_connection(
|
||||||
"""Test ReverseGateway send message with closed connection"""
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
mock_backend = MagicMock()
|
mock_config_receiver, mock_iam_auth,
|
||||||
mock_get_pubsub.return_value = mock_backend
|
):
|
||||||
|
mock_get_pubsub.return_value = MagicMock()
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = make_gateway()
|
||||||
|
|
||||||
# Mock closed websocket
|
|
||||||
mock_ws = AsyncMock()
|
mock_ws = AsyncMock()
|
||||||
mock_ws.closed = True
|
mock_ws.closed = True
|
||||||
gateway.ws = mock_ws
|
gateway.ws = mock_ws
|
||||||
|
|
@ -225,174 +267,165 @@ class TestReverseGateway:
|
||||||
|
|
||||||
await gateway.send_message(test_message)
|
await gateway.send_message(test_message)
|
||||||
|
|
||||||
# Should not call send_str on closed connection
|
|
||||||
mock_ws.send_str.assert_not_called()
|
mock_ws.send_str.assert_not_called()
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_handle_message(
|
||||||
"""Test ReverseGateway handle message"""
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
mock_backend = MagicMock()
|
mock_config_receiver, mock_iam_auth,
|
||||||
mock_get_pubsub.return_value = mock_backend
|
):
|
||||||
|
mock_get_pubsub.return_value = MagicMock()
|
||||||
|
|
||||||
mock_dispatcher_instance = AsyncMock()
|
mock_dispatcher_instance = AsyncMock()
|
||||||
mock_dispatcher_instance.handle_message.return_value = {"response": "success"}
|
|
||||||
mock_dispatcher.return_value = mock_dispatcher_instance
|
mock_dispatcher.return_value = mock_dispatcher_instance
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = make_gateway()
|
||||||
|
|
||||||
# Mock send_message
|
|
||||||
gateway.send_message = AsyncMock()
|
gateway.send_message = AsyncMock()
|
||||||
|
|
||||||
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
|
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
|
||||||
|
|
||||||
await gateway.handle_message(test_message)
|
await gateway.handle_message(test_message)
|
||||||
|
|
||||||
mock_dispatcher_instance.handle_message.assert_called_once_with({
|
mock_dispatcher_instance.handle_message.assert_called_once_with(
|
||||||
|
{
|
||||||
"id": "test",
|
"id": "test",
|
||||||
"service": "test-service",
|
"service": "test-service",
|
||||||
"request": {"data": "test"}
|
"request": {"data": "test"},
|
||||||
})
|
},
|
||||||
gateway.send_message.assert_called_once_with({"response": "success"})
|
gateway.send_message,
|
||||||
|
)
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_handle_message_invalid_json(
|
||||||
"""Test ReverseGateway handle message with invalid JSON"""
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
mock_backend = MagicMock()
|
mock_config_receiver, mock_iam_auth,
|
||||||
mock_get_pubsub.return_value = mock_backend
|
):
|
||||||
|
mock_get_pubsub.return_value = MagicMock()
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = make_gateway()
|
||||||
|
|
||||||
# Mock send_message
|
|
||||||
gateway.send_message = AsyncMock()
|
gateway.send_message = AsyncMock()
|
||||||
|
|
||||||
test_message = 'invalid json'
|
await gateway.handle_message('invalid json')
|
||||||
|
|
||||||
# Should not raise exception
|
|
||||||
await gateway.handle_message(test_message)
|
|
||||||
|
|
||||||
# Should not call send_message due to error
|
|
||||||
gateway.send_message.assert_not_called()
|
gateway.send_message.assert_not_called()
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_listen_text_message(
|
||||||
"""Test ReverseGateway listen with text message"""
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
mock_backend = MagicMock()
|
mock_config_receiver, mock_iam_auth,
|
||||||
mock_get_pubsub.return_value = mock_backend
|
):
|
||||||
|
mock_get_pubsub.return_value = MagicMock()
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = make_gateway()
|
||||||
gateway.running = True
|
gateway.running = True
|
||||||
|
|
||||||
# Mock websocket
|
|
||||||
mock_ws = AsyncMock()
|
mock_ws = AsyncMock()
|
||||||
mock_ws.closed = False
|
mock_ws.closed = False
|
||||||
gateway.ws = mock_ws
|
gateway.ws = mock_ws
|
||||||
|
|
||||||
# Mock handle_message
|
|
||||||
gateway.handle_message = AsyncMock()
|
gateway.handle_message = AsyncMock()
|
||||||
|
|
||||||
# Mock message
|
|
||||||
mock_msg = MagicMock()
|
mock_msg = MagicMock()
|
||||||
mock_msg.type = WSMsgType.TEXT
|
mock_msg.type = WSMsgType.TEXT
|
||||||
mock_msg.data = '{"test": "message"}'
|
mock_msg.data = '{"test": "message"}'
|
||||||
|
|
||||||
# Mock receive to return message once, then raise exception to stop loop
|
|
||||||
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
|
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
|
||||||
|
|
||||||
# listen() catches exceptions and breaks, so no exception should be raised
|
|
||||||
await gateway.listen()
|
await gateway.listen()
|
||||||
|
|
||||||
gateway.handle_message.assert_called_once_with('{"test": "message"}')
|
gateway.handle_message.assert_called_once_with('{"test": "message"}')
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_listen_binary_message(
|
||||||
"""Test ReverseGateway listen with binary message"""
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
mock_backend = MagicMock()
|
mock_config_receiver, mock_iam_auth,
|
||||||
mock_get_pubsub.return_value = mock_backend
|
):
|
||||||
|
mock_get_pubsub.return_value = MagicMock()
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = make_gateway()
|
||||||
gateway.running = True
|
gateway.running = True
|
||||||
|
|
||||||
# Mock websocket
|
|
||||||
mock_ws = AsyncMock()
|
mock_ws = AsyncMock()
|
||||||
mock_ws.closed = False
|
mock_ws.closed = False
|
||||||
gateway.ws = mock_ws
|
gateway.ws = mock_ws
|
||||||
|
|
||||||
# Mock handle_message
|
|
||||||
gateway.handle_message = AsyncMock()
|
gateway.handle_message = AsyncMock()
|
||||||
|
|
||||||
# Mock message
|
|
||||||
mock_msg = MagicMock()
|
mock_msg = MagicMock()
|
||||||
mock_msg.type = WSMsgType.BINARY
|
mock_msg.type = WSMsgType.BINARY
|
||||||
mock_msg.data = b'{"test": "binary"}'
|
mock_msg.data = b'{"test": "binary"}'
|
||||||
|
|
||||||
# Mock receive to return message once, then raise exception to stop loop
|
|
||||||
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
|
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
|
||||||
|
|
||||||
# listen() catches exceptions and breaks, so no exception should be raised
|
|
||||||
await gateway.listen()
|
await gateway.listen()
|
||||||
|
|
||||||
gateway.handle_message.assert_called_once_with('{"test": "binary"}')
|
gateway.handle_message.assert_called_once_with('{"test": "binary"}')
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_listen_close_message(
|
||||||
"""Test ReverseGateway listen with close message"""
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
mock_backend = MagicMock()
|
mock_config_receiver, mock_iam_auth,
|
||||||
mock_get_pubsub.return_value = mock_backend
|
):
|
||||||
|
mock_get_pubsub.return_value = MagicMock()
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = make_gateway()
|
||||||
gateway.running = True
|
gateway.running = True
|
||||||
|
|
||||||
# Mock websocket
|
|
||||||
mock_ws = AsyncMock()
|
mock_ws = AsyncMock()
|
||||||
mock_ws.closed = False
|
mock_ws.closed = False
|
||||||
gateway.ws = mock_ws
|
gateway.ws = mock_ws
|
||||||
|
|
||||||
# Mock handle_message
|
|
||||||
gateway.handle_message = AsyncMock()
|
gateway.handle_message = AsyncMock()
|
||||||
|
|
||||||
# Mock message
|
|
||||||
mock_msg = MagicMock()
|
mock_msg = MagicMock()
|
||||||
mock_msg.type = WSMsgType.CLOSE
|
mock_msg.type = WSMsgType.CLOSE
|
||||||
|
|
||||||
# Mock receive to return close message
|
|
||||||
mock_ws.receive.return_value = mock_msg
|
mock_ws.receive.return_value = mock_msg
|
||||||
|
|
||||||
await gateway.listen()
|
await gateway.listen()
|
||||||
|
|
||||||
# Should not call handle_message for close message
|
|
||||||
gateway.handle_message.assert_not_called()
|
gateway.handle_message.assert_not_called()
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_shutdown(
|
||||||
"""Test ReverseGateway shutdown"""
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
|
mock_config_receiver, mock_iam_auth,
|
||||||
|
):
|
||||||
mock_backend = MagicMock()
|
mock_backend = MagicMock()
|
||||||
mock_get_pubsub.return_value = mock_backend
|
mock_get_pubsub.return_value = mock_backend
|
||||||
|
|
||||||
mock_dispatcher_instance = AsyncMock()
|
mock_dispatcher_instance = AsyncMock()
|
||||||
mock_dispatcher.return_value = mock_dispatcher_instance
|
mock_dispatcher.return_value = mock_dispatcher_instance
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = make_gateway()
|
||||||
gateway.running = True
|
gateway.running = True
|
||||||
|
|
||||||
# Mock disconnect
|
|
||||||
gateway.disconnect = AsyncMock()
|
gateway.disconnect = AsyncMock()
|
||||||
|
|
||||||
await gateway.shutdown()
|
await gateway.shutdown()
|
||||||
|
|
@ -402,15 +435,17 @@ class TestReverseGateway:
|
||||||
gateway.disconnect.assert_called_once()
|
gateway.disconnect.assert_called_once()
|
||||||
mock_backend.close.assert_called_once()
|
mock_backend.close.assert_called_once()
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
"""Test ReverseGateway stop"""
|
def test_reverse_gateway_stop(
|
||||||
mock_backend = MagicMock()
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
mock_get_pubsub.return_value = mock_backend
|
mock_config_receiver, mock_iam_auth,
|
||||||
|
):
|
||||||
|
mock_get_pubsub.return_value = MagicMock()
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = make_gateway()
|
||||||
gateway.running = True
|
gateway.running = True
|
||||||
|
|
||||||
gateway.stop()
|
gateway.stop()
|
||||||
|
|
@ -421,27 +456,29 @@ class TestReverseGateway:
|
||||||
class TestReverseGatewayRun:
|
class TestReverseGatewayRun:
|
||||||
"""Test cases for ReverseGateway run method"""
|
"""Test cases for ReverseGateway run method"""
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch(*MOCK_PATCHES[0:1])
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch(*MOCK_PATCHES[1:2])
|
||||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
@patch(*MOCK_PATCHES[2:3])
|
||||||
|
@patch(*MOCK_PATCHES[3:4])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_run_successful_cycle(
|
||||||
"""Test ReverseGateway run method with successful connect/listen cycle"""
|
self, mock_get_pubsub, mock_dispatcher,
|
||||||
mock_backend = MagicMock()
|
mock_config_receiver, mock_iam_auth,
|
||||||
mock_get_pubsub.return_value = mock_backend
|
):
|
||||||
|
mock_get_pubsub.return_value = MagicMock()
|
||||||
|
|
||||||
|
mock_auth_instance = AsyncMock()
|
||||||
|
mock_iam_auth.return_value = mock_auth_instance
|
||||||
|
|
||||||
mock_config_receiver_instance = AsyncMock()
|
mock_config_receiver_instance = AsyncMock()
|
||||||
mock_config_receiver.return_value = mock_config_receiver_instance
|
mock_config_receiver.return_value = mock_config_receiver_instance
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = make_gateway()
|
||||||
|
|
||||||
# Mock methods
|
|
||||||
gateway.connect = AsyncMock(return_value=True)
|
|
||||||
gateway.listen = AsyncMock()
|
gateway.listen = AsyncMock()
|
||||||
gateway.disconnect = AsyncMock()
|
gateway.disconnect = AsyncMock()
|
||||||
gateway.shutdown = AsyncMock()
|
gateway.shutdown = AsyncMock()
|
||||||
|
|
||||||
# Stop after one iteration
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
async def mock_connect():
|
async def mock_connect():
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
|
|
@ -456,86 +493,8 @@ class TestReverseGatewayRun:
|
||||||
|
|
||||||
await gateway.run()
|
await gateway.run()
|
||||||
|
|
||||||
|
mock_auth_instance.start.assert_called_once()
|
||||||
mock_config_receiver_instance.start.assert_called_once()
|
mock_config_receiver_instance.start.assert_called_once()
|
||||||
gateway.listen.assert_called_once()
|
gateway.listen.assert_called_once()
|
||||||
# disconnect is called twice: once in the main loop, once in shutdown
|
|
||||||
assert gateway.disconnect.call_count == 2
|
assert gateway.disconnect.call_count == 2
|
||||||
gateway.shutdown.assert_called_once()
|
gateway.shutdown.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
class TestReverseGatewayArgs:
|
|
||||||
"""Test cases for argument parsing and run function"""
|
|
||||||
|
|
||||||
def test_parse_args_defaults(self):
|
|
||||||
"""Test parse_args with default values"""
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# Mock sys.argv
|
|
||||||
original_argv = sys.argv
|
|
||||||
sys.argv = ['reverse-gateway']
|
|
||||||
|
|
||||||
try:
|
|
||||||
args = parse_args()
|
|
||||||
|
|
||||||
assert args.websocket_uri is None
|
|
||||||
assert args.max_workers == 10
|
|
||||||
assert args.pulsar_host is None
|
|
||||||
assert args.pulsar_api_key is None
|
|
||||||
assert args.pulsar_listener is None
|
|
||||||
finally:
|
|
||||||
sys.argv = original_argv
|
|
||||||
|
|
||||||
def test_parse_args_custom_values(self):
|
|
||||||
"""Test parse_args with custom values"""
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# Mock sys.argv
|
|
||||||
original_argv = sys.argv
|
|
||||||
sys.argv = [
|
|
||||||
'reverse-gateway',
|
|
||||||
'--websocket-uri', 'ws://custom:8080/ws',
|
|
||||||
'--max-workers', '20',
|
|
||||||
'--pulsar-host', 'pulsar://custom:6650',
|
|
||||||
'--pulsar-api-key', 'test-key',
|
|
||||||
'--pulsar-listener', 'test-listener'
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
|
||||||
args = parse_args()
|
|
||||||
|
|
||||||
assert args.websocket_uri == 'ws://custom:8080/ws'
|
|
||||||
assert args.max_workers == 20
|
|
||||||
assert args.pulsar_host == 'pulsar://custom:6650'
|
|
||||||
assert args.pulsar_api_key == 'test-key'
|
|
||||||
assert args.pulsar_listener == 'test-listener'
|
|
||||||
finally:
|
|
||||||
sys.argv = original_argv
|
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ReverseGateway')
|
|
||||||
@patch('asyncio.run')
|
|
||||||
def test_run_function(self, mock_asyncio_run, mock_gateway_class):
|
|
||||||
"""Test run function"""
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# Mock sys.argv
|
|
||||||
original_argv = sys.argv
|
|
||||||
sys.argv = ['reverse-gateway', '--max-workers', '15']
|
|
||||||
|
|
||||||
try:
|
|
||||||
mock_gateway_instance = MagicMock()
|
|
||||||
mock_gateway_instance.url = "ws://localhost:7650/out"
|
|
||||||
mock_gateway_instance.pulsar_host = "pulsar://pulsar:6650"
|
|
||||||
mock_gateway_class.return_value = mock_gateway_instance
|
|
||||||
|
|
||||||
run()
|
|
||||||
|
|
||||||
mock_gateway_class.assert_called_once_with(
|
|
||||||
websocket_uri=None,
|
|
||||||
max_workers=15,
|
|
||||||
pulsar_host=None,
|
|
||||||
pulsar_api_key=None,
|
|
||||||
pulsar_listener=None
|
|
||||||
)
|
|
||||||
mock_asyncio_run.assert_called_once_with(mock_gateway_instance.run())
|
|
||||||
finally:
|
|
||||||
sys.argv = original_argv
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||||
|
|
@ -44,6 +45,60 @@ def from_value(x: Any) -> Any:
|
||||||
return Term(type=LITERAL, value=str(x))
|
return Term(type=LITERAL, value=str(x))
|
||||||
|
|
||||||
class TriplesClient(RequestResponse):
|
class TriplesClient(RequestResponse):
|
||||||
|
|
||||||
|
async def query_gen(self, s=None, p=None, o=None, limit=20,
|
||||||
|
collection="default",
|
||||||
|
batch_size=20, timeout=30, g=None):
|
||||||
|
"""Async generator yielding Triple objects as batches arrive."""
|
||||||
|
queue = asyncio.Queue()
|
||||||
|
done = False
|
||||||
|
|
||||||
|
async def recipient(resp):
|
||||||
|
if resp.error:
|
||||||
|
raise RuntimeError(resp.error.message)
|
||||||
|
|
||||||
|
batch = [
|
||||||
|
Triple(to_value(v.s), to_value(v.p), to_value(v.o))
|
||||||
|
for v in resp.triples
|
||||||
|
]
|
||||||
|
await queue.put(batch)
|
||||||
|
|
||||||
|
if resp.is_final:
|
||||||
|
await queue.put(None)
|
||||||
|
|
||||||
|
return resp.is_final
|
||||||
|
|
||||||
|
# Launch the streaming request as a background task
|
||||||
|
task = asyncio.ensure_future(self.request(
|
||||||
|
TriplesQueryRequest(
|
||||||
|
s=from_value(s),
|
||||||
|
p=from_value(p),
|
||||||
|
o=from_value(o),
|
||||||
|
limit=limit,
|
||||||
|
collection=collection,
|
||||||
|
streaming=True,
|
||||||
|
batch_size=batch_size,
|
||||||
|
g=g,
|
||||||
|
),
|
||||||
|
timeout=timeout,
|
||||||
|
recipient=recipient,
|
||||||
|
))
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
batch = await queue.get()
|
||||||
|
if batch is None:
|
||||||
|
break
|
||||||
|
for triple in batch:
|
||||||
|
yield triple
|
||||||
|
finally:
|
||||||
|
if not task.done():
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except (asyncio.CancelledError, Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
async def query(self, s=None, p=None, o=None, limit=20,
|
async def query(self, s=None, p=None, o=None, limit=20,
|
||||||
collection="default",
|
collection="default",
|
||||||
timeout=30, g=None):
|
timeout=30, g=None):
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ Provides semantic query understanding, ontology matching, and answer generation.
|
||||||
|
|
||||||
from .query_service import OntoRAGQueryService, QueryRequest, QueryResponse
|
from .query_service import OntoRAGQueryService, QueryRequest, QueryResponse
|
||||||
from .question_analyzer import QuestionAnalyzer, QuestionComponents, QuestionType
|
from .question_analyzer import QuestionAnalyzer, QuestionComponents, QuestionType
|
||||||
from .ontology_matcher import OntologyMatcher, QueryOntologySubset
|
from .ontology_matcher import OntologyMatcherForQueries, QueryOntologySubset
|
||||||
from .backend_router import BackendRouter, BackendType, QueryRoute
|
from .backend_router import BackendRouter, BackendType, QueryRoute
|
||||||
from .sparql_generator import SPARQLGenerator, SPARQLQuery
|
from .sparql_generator import SPARQLGenerator, SPARQLQuery
|
||||||
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
|
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
|
||||||
|
|
@ -27,7 +27,7 @@ __all__ = [
|
||||||
'QuestionType',
|
'QuestionType',
|
||||||
|
|
||||||
# Ontology matching
|
# Ontology matching
|
||||||
'OntologyMatcher',
|
'OntologyMatcherForQueries',
|
||||||
'QueryOntologySubset',
|
'QueryOntologySubset',
|
||||||
|
|
||||||
# Backend routing
|
# Backend routing
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ Provides comprehensive monitoring of system performance, query patterns, and res
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
|
|
@ -276,6 +277,26 @@ class MetricsCollector:
|
||||||
return f"{name}{{{label_str}}}"
|
return f"{name}{{{label_str}}}"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_metric_label(metric_name: str, label: str) -> Optional[str]:
|
||||||
|
"""Extract a label value from an internal metric key."""
|
||||||
|
labels_start = metric_name.find('{')
|
||||||
|
labels_end = metric_name.find('}', labels_start + 1)
|
||||||
|
|
||||||
|
if labels_start == -1 or labels_end == -1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
labels = metric_name[labels_start + 1:labels_end]
|
||||||
|
label_match = re.search(
|
||||||
|
rf'(?:^|,){re.escape(label)}=(?:"([^"]*)"|([^,]*))',
|
||||||
|
labels,
|
||||||
|
)
|
||||||
|
if not label_match:
|
||||||
|
return None
|
||||||
|
|
||||||
|
quoted_value, unquoted_value = label_match.groups()
|
||||||
|
return quoted_value if quoted_value is not None else unquoted_value
|
||||||
|
|
||||||
|
|
||||||
class PerformanceMonitor:
|
class PerformanceMonitor:
|
||||||
"""Monitors system performance and component health."""
|
"""Monitors system performance and component health."""
|
||||||
|
|
||||||
|
|
@ -474,8 +495,8 @@ class PerformanceMonitor:
|
||||||
# Cache performance
|
# Cache performance
|
||||||
cache_types = set()
|
cache_types = set()
|
||||||
for metric_name in self.metrics_collector.counters.keys():
|
for metric_name in self.metrics_collector.counters.keys():
|
||||||
if 'cache_type=' in metric_name:
|
cache_type = _extract_metric_label(metric_name, 'cache_type')
|
||||||
cache_type = metric_name.split('cache_type=')[1].split(',')[0].split('}')[0]
|
if cache_type is not None:
|
||||||
cache_types.add(cache_type)
|
cache_types.add(cache_type)
|
||||||
|
|
||||||
for cache_type in cache_types:
|
for cache_type in cache_types:
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,10 @@ import logging
|
||||||
from typing import List, Dict, Any, Set, Optional
|
from typing import List, Dict, Any, Set, Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from ...extract.kg.ontology.ontology_loader import Ontology, OntologyLoader
|
from trustgraph.extract.kg.ontology.ontology_loader import Ontology, OntologyLoader
|
||||||
from ...extract.kg.ontology.ontology_embedder import OntologyEmbedder
|
from trustgraph.extract.kg.ontology.ontology_embedder import OntologyEmbedder
|
||||||
from ...extract.kg.ontology.text_processor import TextSegment
|
from trustgraph.extract.kg.ontology.text_processor import TextSegment
|
||||||
from ...extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset
|
from trustgraph.extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset
|
||||||
from .question_analyzer import QuestionComponents, QuestionType
|
from .question_analyzer import QuestionComponents, QuestionType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
|
||||||
|
|
@ -8,13 +8,13 @@ from typing import Dict, Any, List, Optional, Union
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from ....flow.flow_processor import FlowProcessor
|
from trustgraph.base.flow_processor import FlowProcessor
|
||||||
from ....tables.config import ConfigTableStore
|
from trustgraph.tables.config import ConfigTableStore
|
||||||
from ...extract.kg.ontology.ontology_loader import OntologyLoader
|
from trustgraph.extract.kg.ontology.ontology_loader import OntologyLoader
|
||||||
from ...extract.kg.ontology.vector_store import InMemoryVectorStore
|
from trustgraph.extract.kg.ontology.vector_store import InMemoryVectorStore
|
||||||
|
|
||||||
from .question_analyzer import QuestionAnalyzer, QuestionComponents
|
from .question_analyzer import QuestionAnalyzer, QuestionComponents
|
||||||
from .ontology_matcher import OntologyMatcher, QueryOntologySubset
|
from .ontology_matcher import OntologyMatcherForQueries, QueryOntologySubset
|
||||||
from .backend_router import BackendRouter, QueryRoute, BackendType
|
from .backend_router import BackendRouter, QueryRoute, BackendType
|
||||||
from .sparql_generator import SPARQLGenerator, SPARQLQuery
|
from .sparql_generator import SPARQLGenerator, SPARQLQuery
|
||||||
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
|
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
|
||||||
|
|
@ -105,7 +105,7 @@ class OntoRAGQueryService(FlowProcessor):
|
||||||
|
|
||||||
# Initialize ontology matcher
|
# Initialize ontology matcher
|
||||||
matcher_config = self.config.get('ontology_matcher', {})
|
matcher_config = self.config.get('ontology_matcher', {})
|
||||||
self.ontology_matcher = OntologyMatcher(
|
self.ontology_matcher = OntologyMatcherForQueries(
|
||||||
vector_store=self.vector_store,
|
vector_store=self.vector_store,
|
||||||
embedding_service=self.embedding_service,
|
embedding_service=self.embedding_service,
|
||||||
config=matcher_config
|
config=matcher_config
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
CASSANDRA_AVAILABLE = False
|
CASSANDRA_AVAILABLE = False
|
||||||
|
|
||||||
from ....tables.config import ConfigTableStore
|
from trustgraph.tables.config import ConfigTableStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,10 @@ SPARQL algebra evaluator.
|
||||||
Recursively evaluates an rdflib SPARQL algebra tree by issuing triple
|
Recursively evaluates an rdflib SPARQL algebra tree by issuing triple
|
||||||
pattern queries via TriplesClient (streaming) and performing in-memory
|
pattern queries via TriplesClient (streaming) and performing in-memory
|
||||||
joins, filters, and projections.
|
joins, filters, and projections.
|
||||||
|
|
||||||
|
Handlers are async generators that yield solutions incrementally.
|
||||||
|
Blocking operators (joins, sort, group, distinct) materialise their
|
||||||
|
upstream into a list at the boundary, then yield results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -17,7 +21,7 @@ from ... knowledge import Uri
|
||||||
from ... knowledge import Literal as KgLiteral
|
from ... knowledge import Literal as KgLiteral
|
||||||
from . parser import rdflib_term_to_term
|
from . parser import rdflib_term_to_term
|
||||||
from . solutions import (
|
from . solutions import (
|
||||||
hash_join, left_join, union, project, distinct,
|
hash_join, left_join, minus, union, project, distinct,
|
||||||
order_by, slice_solutions, _term_key,
|
order_by, slice_solutions, _term_key,
|
||||||
)
|
)
|
||||||
from . expressions import evaluate_expression, _effective_boolean
|
from . expressions import evaluate_expression, _effective_boolean
|
||||||
|
|
@ -34,56 +38,56 @@ async def evaluate(node, triples_client, collection, limit=10000):
|
||||||
"""
|
"""
|
||||||
Evaluate a SPARQL algebra node.
|
Evaluate a SPARQL algebra node.
|
||||||
|
|
||||||
Args:
|
Yields solutions (dicts mapping variable names to Term values)
|
||||||
node: rdflib CompValue algebra node
|
incrementally as an async generator.
|
||||||
triples_client: TriplesClient instance for triple pattern queries
|
|
||||||
collection: collection identifier
|
|
||||||
limit: safety limit on results
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list of solutions (dicts mapping variable names to Term values)
|
|
||||||
"""
|
"""
|
||||||
if not isinstance(node, CompValue):
|
if not isinstance(node, CompValue):
|
||||||
logger.warning(f"Expected CompValue, got {type(node)}: {node}")
|
logger.warning(f"Expected CompValue, got {type(node)}: {node}")
|
||||||
return [{}]
|
yield {}
|
||||||
|
return
|
||||||
|
|
||||||
name = node.name
|
name = node.name
|
||||||
handler = _HANDLERS.get(name)
|
handler = _HANDLERS.get(name)
|
||||||
|
|
||||||
if handler is None:
|
if handler is None:
|
||||||
logger.warning(f"Unsupported algebra node: {name}")
|
logger.warning(f"Unsupported algebra node: {name}")
|
||||||
return [{}]
|
yield {}
|
||||||
|
return
|
||||||
|
|
||||||
return await handler(node, triples_client, collection, limit)
|
async for sol in handler(node, triples_client, collection, limit):
|
||||||
|
yield sol
|
||||||
|
|
||||||
|
|
||||||
# --- Node handlers ---
|
async def materialise(node, triples_client, collection, limit=10000):
|
||||||
|
"""Collect all solutions from evaluate() into a list."""
|
||||||
|
return [sol async for sol in evaluate(node, triples_client, collection, limit)]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Node handlers (async generators) ---
|
||||||
|
|
||||||
async def _eval_select_query(node, tc, collection, limit):
|
async def _eval_select_query(node, tc, collection, limit):
|
||||||
"""Evaluate a SelectQuery node."""
|
async for sol in evaluate(node.p, tc, collection, limit):
|
||||||
return await evaluate(node.p, tc, collection, limit)
|
yield sol
|
||||||
|
|
||||||
|
|
||||||
async def _eval_project(node, tc, collection, limit):
|
async def _eval_project(node, tc, collection, limit):
|
||||||
"""Evaluate a Project node (SELECT variable projection)."""
|
|
||||||
solutions = await evaluate(node.p, tc, collection, limit)
|
|
||||||
variables = [str(v) for v in node.PV]
|
variables = [str(v) for v in node.PV]
|
||||||
return project(solutions, variables)
|
async for sol in evaluate(node.p, tc, collection, limit):
|
||||||
|
yield {v: sol[v] for v in variables if v in sol}
|
||||||
|
|
||||||
|
|
||||||
async def _eval_bgp(node, tc, collection, limit):
|
async def _eval_bgp(node, tc, collection, limit):
|
||||||
"""
|
"""
|
||||||
Evaluate a Basic Graph Pattern.
|
Evaluate a Basic Graph Pattern.
|
||||||
|
|
||||||
Issues streaming triple pattern queries and joins results. Patterns
|
Patterns are ordered by selectivity and evaluated sequentially.
|
||||||
are ordered by selectivity (more bound terms first) and evaluated
|
For the final pattern, results stream directly from the triple store.
|
||||||
sequentially with bound-variable substitution.
|
|
||||||
"""
|
"""
|
||||||
triples = node.triples
|
triples = node.triples
|
||||||
if not triples:
|
if not triples:
|
||||||
return [{}]
|
yield {}
|
||||||
|
return
|
||||||
|
|
||||||
# Sort patterns by selectivity: more bound terms = more selective
|
|
||||||
def selectivity(pattern):
|
def selectivity(pattern):
|
||||||
return sum(1 for t in pattern if not isinstance(t, Variable))
|
return sum(1 for t in pattern if not isinstance(t, Variable))
|
||||||
|
|
||||||
|
|
@ -91,27 +95,50 @@ async def _eval_bgp(node, tc, collection, limit):
|
||||||
enumerate(triples), key=lambda x: -selectivity(x[1])
|
enumerate(triples), key=lambda x: -selectivity(x[1])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# For all patterns except the last, we must materialise intermediate
|
||||||
|
# solutions because each pattern depends on bindings from prior ones.
|
||||||
|
# The last pattern streams directly.
|
||||||
solutions = [{}]
|
solutions = [{}]
|
||||||
|
|
||||||
for _, pattern in sorted_patterns:
|
for pattern_idx, (_, pattern) in enumerate(sorted_patterns):
|
||||||
s_tmpl, p_tmpl, o_tmpl = pattern
|
s_tmpl, p_tmpl, o_tmpl = pattern
|
||||||
|
is_last = (pattern_idx == len(sorted_patterns) - 1)
|
||||||
|
|
||||||
new_solutions = []
|
if is_last:
|
||||||
|
# Stream the final pattern — yield as triples arrive
|
||||||
|
count = 0
|
||||||
for sol in solutions:
|
for sol in solutions:
|
||||||
# Substitute known bindings into the pattern
|
|
||||||
s_val = _resolve_term(s_tmpl, sol)
|
s_val = _resolve_term(s_tmpl, sol)
|
||||||
p_val = _resolve_term(p_tmpl, sol)
|
p_val = _resolve_term(p_tmpl, sol)
|
||||||
o_val = _resolve_term(o_tmpl, sol)
|
o_val = _resolve_term(o_tmpl, sol)
|
||||||
|
|
||||||
# Query the triples store
|
async for triple in tc.query_gen(
|
||||||
results = await _query_pattern(
|
s=s_val, p=p_val, o=o_val,
|
||||||
tc, s_val, p_val, o_val, collection, limit
|
limit=limit, collection=collection,
|
||||||
)
|
):
|
||||||
|
binding = dict(sol)
|
||||||
|
if isinstance(s_tmpl, Variable):
|
||||||
|
binding[str(s_tmpl)] = _to_term(triple.s)
|
||||||
|
if isinstance(p_tmpl, Variable):
|
||||||
|
binding[str(p_tmpl)] = _to_term(triple.p)
|
||||||
|
if isinstance(o_tmpl, Variable):
|
||||||
|
binding[str(o_tmpl)] = _to_term(triple.o)
|
||||||
|
yield binding
|
||||||
|
count += 1
|
||||||
|
if count >= limit:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Materialise intermediate patterns
|
||||||
|
new_solutions = []
|
||||||
|
for sol in solutions:
|
||||||
|
s_val = _resolve_term(s_tmpl, sol)
|
||||||
|
p_val = _resolve_term(p_tmpl, sol)
|
||||||
|
o_val = _resolve_term(o_tmpl, sol)
|
||||||
|
|
||||||
# Map results back to variable bindings,
|
async for triple in tc.query_gen(
|
||||||
# converting Uri/Literal to Term objects
|
s=s_val, p=p_val, o=o_val,
|
||||||
for triple in results:
|
limit=limit, collection=collection,
|
||||||
|
):
|
||||||
binding = dict(sol)
|
binding = dict(sol)
|
||||||
if isinstance(s_tmpl, Variable):
|
if isinstance(s_tmpl, Variable):
|
||||||
binding[str(s_tmpl)] = _to_term(triple.s)
|
binding[str(s_tmpl)] = _to_term(triple.s)
|
||||||
|
|
@ -122,24 +149,168 @@ async def _eval_bgp(node, tc, collection, limit):
|
||||||
new_solutions.append(binding)
|
new_solutions.append(binding)
|
||||||
|
|
||||||
solutions = new_solutions
|
solutions = new_solutions
|
||||||
|
|
||||||
if not solutions:
|
if not solutions:
|
||||||
break
|
return
|
||||||
|
|
||||||
return solutions[:limit]
|
|
||||||
|
# --- Blocking operators: materialise upstream, then yield ---
|
||||||
|
|
||||||
|
def _is_small_node(node):
|
||||||
|
"""Check if a node is likely to produce a small number of solutions."""
|
||||||
|
if not isinstance(node, CompValue):
|
||||||
|
return False
|
||||||
|
if node.name in ("values", "ToMultiSet"):
|
||||||
|
return True
|
||||||
|
if node.name == "Extend" and hasattr(node, "p"):
|
||||||
|
return _is_small_node(node.p)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def _eval_join(node, tc, collection, limit):
|
async def _eval_join(node, tc, collection, limit):
|
||||||
"""Evaluate a Join node."""
|
# Bind join: if one side is small (e.g. VALUES), materialise it and
|
||||||
left = await evaluate(node.p1, tc, collection, limit)
|
# substitute its bindings into the other side's evaluation. This
|
||||||
right = await evaluate(node.p2, tc, collection, limit)
|
# turns wildcard BGP queries into selective ones.
|
||||||
return hash_join(left, right)[:limit]
|
if _is_small_node(node.p1):
|
||||||
|
yield_from = _bind_join(node.p1, node.p2, tc, collection, limit)
|
||||||
|
elif _is_small_node(node.p2):
|
||||||
|
yield_from = _bind_join(node.p2, node.p1, tc, collection, limit)
|
||||||
|
else:
|
||||||
|
yield_from = _hash_join(node, tc, collection, limit)
|
||||||
|
|
||||||
|
async for sol in yield_from:
|
||||||
|
yield sol
|
||||||
|
|
||||||
|
|
||||||
|
async def _hash_join(node, tc, collection, limit):
|
||||||
|
left = await materialise(node.p1, tc, collection, limit)
|
||||||
|
right = await materialise(node.p2, tc, collection, limit)
|
||||||
|
for sol in hash_join(left, right)[:limit]:
|
||||||
|
yield sol
|
||||||
|
|
||||||
|
|
||||||
|
async def _bind_join(small_node, big_node, tc, collection, limit):
|
||||||
|
"""Iterate over the small side and inject bindings into the big side."""
|
||||||
|
small_sols = await materialise(small_node, tc, collection, limit)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for binding in small_sols:
|
||||||
|
async for sol in _evaluate_with_bindings(
|
||||||
|
big_node, binding, tc, collection, limit
|
||||||
|
):
|
||||||
|
yield sol
|
||||||
|
count += 1
|
||||||
|
if count >= limit:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_compatible(left, right):
|
||||||
|
"""Merge two solutions if compatible (shared vars have equal values)."""
|
||||||
|
merged = dict(left)
|
||||||
|
for k, v in right.items():
|
||||||
|
if k in merged:
|
||||||
|
if _term_key(merged[k]) != _term_key(v):
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
merged[k] = v
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
async def _evaluate_with_bindings(node, bindings, tc, collection, limit):
|
||||||
|
"""Evaluate a node with pre-seeded variable bindings.
|
||||||
|
|
||||||
|
For BGP nodes, the bindings are injected so _resolve_term sees them,
|
||||||
|
turning wildcard queries into selective ones. For other node types,
|
||||||
|
evaluate normally and merge/filter against the bindings.
|
||||||
|
"""
|
||||||
|
if isinstance(node, CompValue) and node.name == "BGP":
|
||||||
|
async for sol in _eval_bgp_with_bindings(
|
||||||
|
node, bindings, tc, collection, limit
|
||||||
|
):
|
||||||
|
yield sol
|
||||||
|
else:
|
||||||
|
async for sol in evaluate(node, tc, collection, limit):
|
||||||
|
merged = _merge_compatible(bindings, sol)
|
||||||
|
if merged is not None:
|
||||||
|
yield merged
|
||||||
|
|
||||||
|
|
||||||
|
async def _eval_bgp_with_bindings(node, bindings, tc, collection, limit):
|
||||||
|
"""Evaluate a BGP with pre-seeded bindings so variables resolve to terms."""
|
||||||
|
triples = node.triples
|
||||||
|
if not triples:
|
||||||
|
yield dict(bindings)
|
||||||
|
return
|
||||||
|
|
||||||
|
def selectivity(pattern):
|
||||||
|
score = 0
|
||||||
|
for t in pattern:
|
||||||
|
if not isinstance(t, Variable):
|
||||||
|
score += 1
|
||||||
|
elif str(t) in bindings:
|
||||||
|
score += 1
|
||||||
|
return score
|
||||||
|
|
||||||
|
sorted_patterns = sorted(
|
||||||
|
enumerate(triples), key=lambda x: -selectivity(x[1])
|
||||||
|
)
|
||||||
|
|
||||||
|
solutions = [dict(bindings)]
|
||||||
|
|
||||||
|
for pattern_idx, (_, pattern) in enumerate(sorted_patterns):
|
||||||
|
s_tmpl, p_tmpl, o_tmpl = pattern
|
||||||
|
is_last = (pattern_idx == len(sorted_patterns) - 1)
|
||||||
|
|
||||||
|
if is_last:
|
||||||
|
count = 0
|
||||||
|
for sol in solutions:
|
||||||
|
s_val = _resolve_term(s_tmpl, sol)
|
||||||
|
p_val = _resolve_term(p_tmpl, sol)
|
||||||
|
o_val = _resolve_term(o_tmpl, sol)
|
||||||
|
|
||||||
|
async for triple in tc.query_gen(
|
||||||
|
s=s_val, p=p_val, o=o_val,
|
||||||
|
limit=limit, collection=collection,
|
||||||
|
):
|
||||||
|
binding = dict(sol)
|
||||||
|
if isinstance(s_tmpl, Variable):
|
||||||
|
binding[str(s_tmpl)] = _to_term(triple.s)
|
||||||
|
if isinstance(p_tmpl, Variable):
|
||||||
|
binding[str(p_tmpl)] = _to_term(triple.p)
|
||||||
|
if isinstance(o_tmpl, Variable):
|
||||||
|
binding[str(o_tmpl)] = _to_term(triple.o)
|
||||||
|
yield binding
|
||||||
|
count += 1
|
||||||
|
if count >= limit:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
new_solutions = []
|
||||||
|
for sol in solutions:
|
||||||
|
s_val = _resolve_term(s_tmpl, sol)
|
||||||
|
p_val = _resolve_term(p_tmpl, sol)
|
||||||
|
o_val = _resolve_term(o_tmpl, sol)
|
||||||
|
|
||||||
|
async for triple in tc.query_gen(
|
||||||
|
s=s_val, p=p_val, o=o_val,
|
||||||
|
limit=limit, collection=collection,
|
||||||
|
):
|
||||||
|
binding = dict(sol)
|
||||||
|
if isinstance(s_tmpl, Variable):
|
||||||
|
binding[str(s_tmpl)] = _to_term(triple.s)
|
||||||
|
if isinstance(p_tmpl, Variable):
|
||||||
|
binding[str(p_tmpl)] = _to_term(triple.p)
|
||||||
|
if isinstance(o_tmpl, Variable):
|
||||||
|
binding[str(o_tmpl)] = _to_term(triple.o)
|
||||||
|
new_solutions.append(binding)
|
||||||
|
|
||||||
|
solutions = new_solutions
|
||||||
|
if not solutions:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
async def _eval_left_join(node, tc, collection, limit):
|
async def _eval_left_join(node, tc, collection, limit):
|
||||||
"""Evaluate a LeftJoin node (OPTIONAL)."""
|
# Buffer right side for hash index; stream left through probe
|
||||||
left_sols = await evaluate(node.p1, tc, collection, limit)
|
left_sols = await materialise(node.p1, tc, collection, limit)
|
||||||
right_sols = await evaluate(node.p2, tc, collection, limit)
|
right_sols = await materialise(node.p2, tc, collection, limit)
|
||||||
|
|
||||||
filter_fn = None
|
filter_fn = None
|
||||||
if hasattr(node, "expr") and node.expr is not None:
|
if hasattr(node, "expr") and node.expr is not None:
|
||||||
|
|
@ -149,42 +320,35 @@ async def _eval_left_join(node, tc, collection, limit):
|
||||||
evaluate_expression(expr, sol)
|
evaluate_expression(expr, sol)
|
||||||
)
|
)
|
||||||
|
|
||||||
return left_join(left_sols, right_sols, filter_fn)[:limit]
|
for sol in left_join(left_sols, right_sols, filter_fn)[:limit]:
|
||||||
|
yield sol
|
||||||
|
|
||||||
|
|
||||||
async def _eval_union(node, tc, collection, limit):
|
async def _eval_minus(node, tc, collection, limit):
|
||||||
"""Evaluate a Union node."""
|
left = await materialise(node.p1, tc, collection, limit)
|
||||||
left = await evaluate(node.p1, tc, collection, limit)
|
right = await materialise(node.p2, tc, collection, limit)
|
||||||
right = await evaluate(node.p2, tc, collection, limit)
|
for sol in minus(left, right):
|
||||||
return union(left, right)[:limit]
|
yield sol
|
||||||
|
|
||||||
|
|
||||||
async def _eval_filter(node, tc, collection, limit):
|
|
||||||
"""Evaluate a Filter node."""
|
|
||||||
solutions = await evaluate(node.p, tc, collection, limit)
|
|
||||||
expr = node.expr
|
|
||||||
return [
|
|
||||||
sol for sol in solutions
|
|
||||||
if _effective_boolean(evaluate_expression(expr, sol))
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async def _eval_distinct(node, tc, collection, limit):
|
async def _eval_distinct(node, tc, collection, limit):
|
||||||
"""Evaluate a Distinct node."""
|
seen = set()
|
||||||
solutions = await evaluate(node.p, tc, collection, limit)
|
async for sol in evaluate(node.p, tc, collection, limit):
|
||||||
return distinct(solutions)
|
key = tuple(sorted(
|
||||||
|
(k, _term_key(v)) for k, v in sol.items()
|
||||||
|
))
|
||||||
|
if key not in seen:
|
||||||
|
seen.add(key)
|
||||||
|
yield sol
|
||||||
|
|
||||||
|
|
||||||
async def _eval_reduced(node, tc, collection, limit):
|
async def _eval_reduced(node, tc, collection, limit):
|
||||||
"""Evaluate a Reduced node (like Distinct but implementation-defined)."""
|
async for sol in _eval_distinct(node, tc, collection, limit):
|
||||||
# Treat same as Distinct
|
yield sol
|
||||||
solutions = await evaluate(node.p, tc, collection, limit)
|
|
||||||
return distinct(solutions)
|
|
||||||
|
|
||||||
|
|
||||||
async def _eval_order_by(node, tc, collection, limit):
|
async def _eval_order_by(node, tc, collection, limit):
|
||||||
"""Evaluate an OrderBy node."""
|
solutions = await materialise(node.p, tc, collection, limit)
|
||||||
solutions = await evaluate(node.p, tc, collection, limit)
|
|
||||||
|
|
||||||
key_fns = []
|
key_fns = []
|
||||||
for cond in node.expr:
|
for cond in node.expr:
|
||||||
|
|
@ -196,36 +360,104 @@ async def _eval_order_by(node, tc, collection, limit):
|
||||||
ascending,
|
ascending,
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
# Simple variable or expression
|
|
||||||
key_fns.append((
|
key_fns.append((
|
||||||
lambda sol, e=cond: evaluate_expression(e, sol),
|
lambda sol, e=cond: evaluate_expression(e, sol),
|
||||||
True,
|
True,
|
||||||
))
|
))
|
||||||
|
|
||||||
return order_by(solutions, key_fns)
|
for sol in order_by(solutions, key_fns):
|
||||||
|
yield sol
|
||||||
|
|
||||||
|
|
||||||
|
# --- Streamable operators ---
|
||||||
|
|
||||||
async def _eval_slice(node, tc, collection, limit):
|
async def _eval_slice(node, tc, collection, limit):
|
||||||
"""Evaluate a Slice node (LIMIT/OFFSET)."""
|
|
||||||
# Pass tighter limit downstream if possible
|
|
||||||
inner_limit = limit
|
|
||||||
if node.length is not None:
|
|
||||||
offset = node.start or 0
|
offset = node.start or 0
|
||||||
inner_limit = min(limit, offset + node.length)
|
length = node.length
|
||||||
|
skipped = 0
|
||||||
|
emitted = 0
|
||||||
|
|
||||||
solutions = await evaluate(node.p, tc, collection, inner_limit)
|
async for sol in evaluate(node.p, tc, collection, limit):
|
||||||
return slice_solutions(solutions, node.start or 0, node.length)
|
if skipped < offset:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
yield sol
|
||||||
|
emitted += 1
|
||||||
|
if length is not None and emitted >= length:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
async def _eval_union(node, tc, collection, limit):
|
||||||
|
async for sol in evaluate(node.p1, tc, collection, limit):
|
||||||
|
yield sol
|
||||||
|
async for sol in evaluate(node.p2, tc, collection, limit):
|
||||||
|
yield sol
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_exists(graph_node, sol, tc, collection, limit):
|
||||||
|
"""Evaluate an EXISTS graph pattern against a solution."""
|
||||||
|
async for r in evaluate(graph_node, tc, collection, limit):
|
||||||
|
shared = set(sol.keys()) & set(r.keys())
|
||||||
|
if all(
|
||||||
|
_term_key(sol[v]) == _term_key(r[v])
|
||||||
|
for v in shared
|
||||||
|
if sol.get(v) is not None and r.get(v) is not None
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def _pre_eval_exists(expr, sol, tc, collection, limit, cache):
|
||||||
|
"""Walk an expression tree, pre-evaluate EXISTS/NOT EXISTS, cache results."""
|
||||||
|
if not isinstance(expr, CompValue):
|
||||||
|
return
|
||||||
|
if expr.name in ("Builtin_EXISTS", "Builtin_NOTEXISTS"):
|
||||||
|
key = id(expr.graph), id(sol)
|
||||||
|
if key not in cache:
|
||||||
|
cache[key] = await _check_exists(
|
||||||
|
expr.graph, sol, tc, collection, limit
|
||||||
|
)
|
||||||
|
return
|
||||||
|
for attr in ("expr", "other", "arg", "arg1", "arg2", "arg3"):
|
||||||
|
child = getattr(expr, attr, None)
|
||||||
|
if child is None:
|
||||||
|
continue
|
||||||
|
if isinstance(child, CompValue):
|
||||||
|
await _pre_eval_exists(child, sol, tc, collection, limit, cache)
|
||||||
|
elif isinstance(child, (list, tuple)):
|
||||||
|
for item in child:
|
||||||
|
if isinstance(item, CompValue):
|
||||||
|
await _pre_eval_exists(
|
||||||
|
item, sol, tc, collection, limit, cache
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _eval_filter(node, tc, collection, limit):
|
||||||
|
expr = node.expr
|
||||||
|
exists_cache = {}
|
||||||
|
|
||||||
|
def exists_cb(graph_node, sol):
|
||||||
|
key = id(graph_node), id(sol)
|
||||||
|
return exists_cache.get(key, False)
|
||||||
|
|
||||||
|
async for sol in evaluate(node.p, tc, collection, limit):
|
||||||
|
await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache)
|
||||||
|
if _effective_boolean(evaluate_expression(expr, sol, exists_cb=exists_cb)):
|
||||||
|
yield sol
|
||||||
|
|
||||||
|
|
||||||
async def _eval_extend(node, tc, collection, limit):
|
async def _eval_extend(node, tc, collection, limit):
|
||||||
"""Evaluate an Extend node (BIND)."""
|
|
||||||
solutions = await evaluate(node.p, tc, collection, limit)
|
|
||||||
var_name = str(node.var)
|
var_name = str(node.var)
|
||||||
expr = node.expr
|
expr = node.expr
|
||||||
|
exists_cache = {}
|
||||||
|
|
||||||
result = []
|
def exists_cb(graph_node, sol):
|
||||||
for sol in solutions:
|
key = id(graph_node), id(sol)
|
||||||
val = evaluate_expression(expr, sol)
|
return exists_cache.get(key, False)
|
||||||
|
|
||||||
|
async for sol in evaluate(node.p, tc, collection, limit):
|
||||||
|
await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache)
|
||||||
|
val = evaluate_expression(expr, sol, exists_cb=exists_cb)
|
||||||
new_sol = dict(sol)
|
new_sol = dict(sol)
|
||||||
if isinstance(val, Term):
|
if isinstance(val, Term):
|
||||||
new_sol[var_name] = val
|
new_sol[var_name] = val
|
||||||
|
|
@ -240,16 +472,14 @@ async def _eval_extend(node, tc, collection, limit):
|
||||||
)
|
)
|
||||||
elif val is not None:
|
elif val is not None:
|
||||||
new_sol[var_name] = Term(type=LITERAL, value=str(val))
|
new_sol[var_name] = Term(type=LITERAL, value=str(val))
|
||||||
result.append(new_sol)
|
yield new_sol
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
# --- Aggregation (blocking) ---
|
||||||
|
|
||||||
async def _eval_group(node, tc, collection, limit):
|
async def _eval_group(node, tc, collection, limit):
|
||||||
"""Evaluate a Group node (GROUP BY with aggregation)."""
|
solutions = await materialise(node.p, tc, collection, limit)
|
||||||
solutions = await evaluate(node.p, tc, collection, limit)
|
|
||||||
|
|
||||||
# Extract grouping expressions
|
|
||||||
group_exprs = []
|
group_exprs = []
|
||||||
if hasattr(node, "expr") and node.expr:
|
if hasattr(node, "expr") and node.expr:
|
||||||
for expr in node.expr:
|
for expr in node.expr:
|
||||||
|
|
@ -260,7 +490,6 @@ async def _eval_group(node, tc, collection, limit):
|
||||||
else:
|
else:
|
||||||
group_exprs.append((expr, None))
|
group_exprs.append((expr, None))
|
||||||
|
|
||||||
# Group solutions
|
|
||||||
groups = defaultdict(list)
|
groups = defaultdict(list)
|
||||||
for sol in solutions:
|
for sol in solutions:
|
||||||
key_parts = []
|
key_parts = []
|
||||||
|
|
@ -270,81 +499,72 @@ async def _eval_group(node, tc, collection, limit):
|
||||||
groups[tuple(key_parts)].append(sol)
|
groups[tuple(key_parts)].append(sol)
|
||||||
|
|
||||||
if not group_exprs:
|
if not group_exprs:
|
||||||
# No GROUP BY - entire result is one group
|
|
||||||
groups[()].extend(solutions)
|
groups[()].extend(solutions)
|
||||||
|
|
||||||
# Build grouped solutions (one per group)
|
|
||||||
result = []
|
|
||||||
for key, group_sols in groups.items():
|
for key, group_sols in groups.items():
|
||||||
sol = {}
|
sol = {}
|
||||||
# Include group key variables
|
|
||||||
if group_sols:
|
if group_sols:
|
||||||
for (expr, var_name), k in zip(group_exprs, key):
|
for (expr, var_name), k in zip(group_exprs, key):
|
||||||
if var_name and group_sols:
|
if var_name and group_sols:
|
||||||
sol[var_name] = evaluate_expression(expr, group_sols[0])
|
sol[var_name] = evaluate_expression(expr, group_sols[0])
|
||||||
sol["__group__"] = group_sols
|
sol["__group__"] = group_sols
|
||||||
result.append(sol)
|
yield sol
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
async def _eval_aggregate_join(node, tc, collection, limit):
|
async def _eval_aggregate_join(node, tc, collection, limit):
|
||||||
"""Evaluate an AggregateJoin (aggregation functions after GROUP BY)."""
|
async for sol in evaluate(node.p, tc, collection, limit):
|
||||||
solutions = await evaluate(node.p, tc, collection, limit)
|
|
||||||
|
|
||||||
result = []
|
|
||||||
for sol in solutions:
|
|
||||||
group = sol.get("__group__", [sol])
|
group = sol.get("__group__", [sol])
|
||||||
new_sol = {k: v for k, v in sol.items() if k != "__group__"}
|
new_sol = {k: v for k, v in sol.items() if k != "__group__"}
|
||||||
|
|
||||||
# Apply aggregate functions
|
|
||||||
if hasattr(node, "A") and node.A:
|
if hasattr(node, "A") and node.A:
|
||||||
for agg in node.A:
|
for agg in node.A:
|
||||||
var_name = str(agg.res)
|
var_name = str(agg.res)
|
||||||
agg_val = _compute_aggregate(agg, group)
|
agg_val = _compute_aggregate(agg, group)
|
||||||
new_sol[var_name] = agg_val
|
new_sol[var_name] = agg_val
|
||||||
|
|
||||||
result.append(new_sol)
|
yield new_sol
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
async def _eval_graph(node, tc, collection, limit):
|
async def _eval_graph(node, tc, collection, limit):
|
||||||
"""Evaluate a Graph node (GRAPH clause)."""
|
|
||||||
term = node.term
|
term = node.term
|
||||||
|
|
||||||
if isinstance(term, URIRef):
|
if isinstance(term, URIRef):
|
||||||
# GRAPH <uri> { ... } — fixed graph
|
|
||||||
# We'd need to pass graph to triples queries
|
|
||||||
# For now, evaluate inner pattern normally
|
|
||||||
logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired")
|
logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired")
|
||||||
return await evaluate(node.p, tc, collection, limit)
|
|
||||||
elif isinstance(term, Variable):
|
elif isinstance(term, Variable):
|
||||||
# GRAPH ?g { ... } — variable graph
|
|
||||||
logger.info(f"GRAPH ?{term} clause - variable graph not yet wired")
|
logger.info(f"GRAPH ?{term} clause - variable graph not yet wired")
|
||||||
return await evaluate(node.p, tc, collection, limit)
|
|
||||||
else:
|
async for sol in evaluate(node.p, tc, collection, limit):
|
||||||
return await evaluate(node.p, tc, collection, limit)
|
yield sol
|
||||||
|
|
||||||
|
|
||||||
async def _eval_values(node, tc, collection, limit):
|
async def _eval_values(node, tc, collection, limit):
|
||||||
"""Evaluate a VALUES clause (inline data)."""
|
# rdflib has two representations for VALUES:
|
||||||
variables = [str(v) for v in node.var]
|
# 1. var=[Variable...], value=[[val, ...], ...] — positional
|
||||||
solutions = []
|
# 2. var=None, res=[{Variable: val, ...}, ...] — dict-based
|
||||||
|
if hasattr(node, "res") and node.res:
|
||||||
|
for row in node.res:
|
||||||
|
sol = {}
|
||||||
|
for var, val in row.items():
|
||||||
|
if val is not None and str(val) != "UNDEF":
|
||||||
|
sol[str(var)] = rdflib_term_to_term(val)
|
||||||
|
yield sol
|
||||||
|
return
|
||||||
|
|
||||||
|
if not node.var or not node.value:
|
||||||
|
yield {}
|
||||||
|
return
|
||||||
|
variables = [str(v) for v in node.var]
|
||||||
for row in node.value:
|
for row in node.value:
|
||||||
sol = {}
|
sol = {}
|
||||||
for var_name, val in zip(variables, row):
|
for var_name, val in zip(variables, row):
|
||||||
if val is not None and str(val) != "UNDEF":
|
if val is not None and str(val) != "UNDEF":
|
||||||
sol[var_name] = rdflib_term_to_term(val)
|
sol[var_name] = rdflib_term_to_term(val)
|
||||||
solutions.append(sol)
|
yield sol
|
||||||
|
|
||||||
return solutions
|
|
||||||
|
|
||||||
|
|
||||||
async def _eval_to_multiset(node, tc, collection, limit):
|
async def _eval_to_multiset(node, tc, collection, limit):
|
||||||
"""Evaluate a ToMultiSet node (subquery)."""
|
async for sol in evaluate(node.p, tc, collection, limit):
|
||||||
return await evaluate(node.p, tc, collection, limit)
|
yield sol
|
||||||
|
|
||||||
|
|
||||||
# --- Aggregate computation ---
|
# --- Aggregate computation ---
|
||||||
|
|
@ -353,7 +573,6 @@ def _compute_aggregate(agg, group):
|
||||||
"""Compute a single aggregate function over a group of solutions."""
|
"""Compute a single aggregate function over a group of solutions."""
|
||||||
agg_name = agg.name if hasattr(agg, "name") else ""
|
agg_name = agg.name if hasattr(agg, "name") else ""
|
||||||
|
|
||||||
# Get the expression to aggregate
|
|
||||||
expr = agg.vars if hasattr(agg, "vars") else None
|
expr = agg.vars if hasattr(agg, "vars") else None
|
||||||
|
|
||||||
if agg_name == "Aggregate_Count":
|
if agg_name == "Aggregate_Count":
|
||||||
|
|
@ -525,6 +744,7 @@ _HANDLERS = {
|
||||||
"Join": _eval_join,
|
"Join": _eval_join,
|
||||||
"LeftJoin": _eval_left_join,
|
"LeftJoin": _eval_left_join,
|
||||||
"Union": _eval_union,
|
"Union": _eval_union,
|
||||||
|
"Minus": _eval_minus,
|
||||||
"Filter": _eval_filter,
|
"Filter": _eval_filter,
|
||||||
"Distinct": _eval_distinct,
|
"Distinct": _eval_distinct,
|
||||||
"Reduced": _eval_reduced,
|
"Reduced": _eval_reduced,
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,15 @@ Evaluates rdflib algebra expression nodes against a solution (variable
|
||||||
binding) to produce a value or boolean result.
|
binding) to produce a value or boolean result.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import math
|
||||||
|
import random
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
import operator
|
import operator
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, date, timezone
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
from rdflib.term import Variable, URIRef, Literal, BNode
|
from rdflib.term import Variable, URIRef, Literal, BNode
|
||||||
from rdflib.plugins.sparql.parserutils import CompValue
|
from rdflib.plugins.sparql.parserutils import CompValue
|
||||||
|
|
@ -17,23 +23,31 @@ from . parser import rdflib_term_to_term
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_exists_callback = None
|
||||||
|
|
||||||
|
|
||||||
class ExpressionError(Exception):
|
class ExpressionError(Exception):
|
||||||
"""Raised when a SPARQL expression cannot be evaluated."""
|
"""Raised when a SPARQL expression cannot be evaluated."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def evaluate_expression(expr, solution):
|
def evaluate_expression(expr, solution, exists_cb=None):
|
||||||
"""
|
"""
|
||||||
Evaluate a SPARQL expression against a solution binding.
|
Evaluate a SPARQL expression against a solution binding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
expr: rdflib algebra expression node
|
expr: rdflib algebra expression node
|
||||||
solution: dict mapping variable names to Term values
|
solution: dict mapping variable names to Term values
|
||||||
|
exists_cb: optional callback(graph_node, solution) -> bool for
|
||||||
|
EXISTS/NOT EXISTS evaluation; provided by algebra.py
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The result value (Term, bool, number, string, or None)
|
The result value (Term, bool, number, string, or None)
|
||||||
"""
|
"""
|
||||||
|
global _exists_callback
|
||||||
|
if exists_cb is not None:
|
||||||
|
_exists_callback = exists_cb
|
||||||
|
|
||||||
if expr is None:
|
if expr is None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
@ -111,6 +125,13 @@ def _evaluate_comp_value(node, solution):
|
||||||
if name == "MultiplicativeExpression":
|
if name == "MultiplicativeExpression":
|
||||||
return _eval_multiplicative(node, solution)
|
return _eval_multiplicative(node, solution)
|
||||||
|
|
||||||
|
# IN / NOT IN — must be checked before the generic Builtin_ dispatch
|
||||||
|
if name == "Builtin_IN":
|
||||||
|
return _eval_in(node, solution)
|
||||||
|
|
||||||
|
if name == "Builtin_NOTIN":
|
||||||
|
return not _eval_in(node, solution)
|
||||||
|
|
||||||
# SPARQL built-in functions
|
# SPARQL built-in functions
|
||||||
if name.startswith("Builtin_"):
|
if name.startswith("Builtin_"):
|
||||||
return _eval_builtin(name, node, solution)
|
return _eval_builtin(name, node, solution)
|
||||||
|
|
@ -119,27 +140,10 @@ def _evaluate_comp_value(node, solution):
|
||||||
if name == "Function":
|
if name == "Function":
|
||||||
return _eval_function(node, solution)
|
return _eval_function(node, solution)
|
||||||
|
|
||||||
# Exists / NotExists
|
|
||||||
if name == "Builtin_EXISTS":
|
|
||||||
# EXISTS requires graph pattern evaluation - not handled here
|
|
||||||
logger.warning("EXISTS not supported in filter expressions")
|
|
||||||
return True
|
|
||||||
|
|
||||||
if name == "Builtin_NOTEXISTS":
|
|
||||||
logger.warning("NOT EXISTS not supported in filter expressions")
|
|
||||||
return True
|
|
||||||
|
|
||||||
# TrueFilter (used with OPTIONAL)
|
# TrueFilter (used with OPTIONAL)
|
||||||
if name == "TrueFilter":
|
if name == "TrueFilter":
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# IN / NOT IN
|
|
||||||
if name == "Builtin_IN":
|
|
||||||
return _eval_in(node, solution)
|
|
||||||
|
|
||||||
if name == "Builtin_NOTIN":
|
|
||||||
return not _eval_in(node, solution)
|
|
||||||
|
|
||||||
logger.warning(f"Unknown CompValue expression: {name}")
|
logger.warning(f"Unknown CompValue expression: {name}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -165,6 +169,22 @@ def _eval_relational(node, solution):
|
||||||
">=": operator.ge,
|
">=": operator.ge,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if str(op) == "IN":
|
||||||
|
items = node.other if isinstance(node.other, list) else [node.other]
|
||||||
|
for item in items:
|
||||||
|
other_val = evaluate_expression(item, solution)
|
||||||
|
if _comparable_value(left) == _comparable_value(other_val):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
if str(op) == "NOT IN":
|
||||||
|
items = node.other if isinstance(node.other, list) else [node.other]
|
||||||
|
for item in items:
|
||||||
|
other_val = evaluate_expression(item, solution)
|
||||||
|
if _comparable_value(left) == _comparable_value(other_val):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
op_fn = ops.get(str(op))
|
op_fn = ops.get(str(op))
|
||||||
if op_fn is None:
|
if op_fn is None:
|
||||||
logger.warning(f"Unknown relational operator: {op}")
|
logger.warning(f"Unknown relational operator: {op}")
|
||||||
|
|
@ -335,6 +355,197 @@ def _eval_builtin(name, node, solution):
|
||||||
return val
|
return val
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if builtin == "YEAR":
|
||||||
|
dt = _to_datetime(evaluate_expression(node.arg, solution))
|
||||||
|
return dt.year if dt is not None else None
|
||||||
|
|
||||||
|
if builtin == "MONTH":
|
||||||
|
dt = _to_datetime(evaluate_expression(node.arg, solution))
|
||||||
|
return dt.month if dt is not None else None
|
||||||
|
|
||||||
|
if builtin == "DAY":
|
||||||
|
dt = _to_datetime(evaluate_expression(node.arg, solution))
|
||||||
|
return dt.day if dt is not None else None
|
||||||
|
|
||||||
|
if builtin == "HOURS":
|
||||||
|
dt = _to_datetime(evaluate_expression(node.arg, solution))
|
||||||
|
if dt is None:
|
||||||
|
return None
|
||||||
|
return dt.hour if isinstance(dt, datetime) else 0
|
||||||
|
|
||||||
|
if builtin == "MINUTES":
|
||||||
|
dt = _to_datetime(evaluate_expression(node.arg, solution))
|
||||||
|
if dt is None:
|
||||||
|
return None
|
||||||
|
return dt.minute if isinstance(dt, datetime) else 0
|
||||||
|
|
||||||
|
if builtin == "SECONDS":
|
||||||
|
dt = _to_datetime(evaluate_expression(node.arg, solution))
|
||||||
|
if dt is None:
|
||||||
|
return None
|
||||||
|
return dt.second if isinstance(dt, datetime) else 0
|
||||||
|
|
||||||
|
if builtin == "FLOOR":
|
||||||
|
val = _to_numeric(evaluate_expression(node.arg, solution))
|
||||||
|
if val is None:
|
||||||
|
return None
|
||||||
|
return int(math.floor(val))
|
||||||
|
|
||||||
|
if builtin == "CEIL":
|
||||||
|
val = _to_numeric(evaluate_expression(node.arg, solution))
|
||||||
|
if val is None:
|
||||||
|
return None
|
||||||
|
return int(math.ceil(val))
|
||||||
|
|
||||||
|
if builtin == "ABS":
|
||||||
|
val = _to_numeric(evaluate_expression(node.arg, solution))
|
||||||
|
if val is None:
|
||||||
|
return None
|
||||||
|
return abs(val)
|
||||||
|
|
||||||
|
if builtin == "ROUND":
|
||||||
|
val = _to_numeric(evaluate_expression(node.arg, solution))
|
||||||
|
if val is None:
|
||||||
|
return None
|
||||||
|
return round(val)
|
||||||
|
|
||||||
|
if builtin == "STRBEFORE":
|
||||||
|
string = _to_string(evaluate_expression(node.arg1, solution))
|
||||||
|
sep = _to_string(evaluate_expression(node.arg2, solution))
|
||||||
|
idx = string.find(sep)
|
||||||
|
if idx < 0:
|
||||||
|
return Term(type=LITERAL, value="")
|
||||||
|
return Term(type=LITERAL, value=string[:idx])
|
||||||
|
|
||||||
|
if builtin == "STRAFTER":
|
||||||
|
string = _to_string(evaluate_expression(node.arg1, solution))
|
||||||
|
sep = _to_string(evaluate_expression(node.arg2, solution))
|
||||||
|
idx = string.find(sep)
|
||||||
|
if idx < 0:
|
||||||
|
return Term(type=LITERAL, value="")
|
||||||
|
return Term(type=LITERAL, value=string[idx + len(sep):])
|
||||||
|
|
||||||
|
if builtin == "ENCODE_FOR_URI":
|
||||||
|
val = _to_string(evaluate_expression(node.arg, solution))
|
||||||
|
return Term(type=LITERAL, value=quote(val, safe=""))
|
||||||
|
|
||||||
|
if builtin == "REPLACE":
|
||||||
|
string = _to_string(evaluate_expression(node.arg, solution))
|
||||||
|
pattern = _to_string(evaluate_expression(node.pattern, solution))
|
||||||
|
replacement = _to_string(
|
||||||
|
evaluate_expression(node.replacement, solution)
|
||||||
|
)
|
||||||
|
flags_str = ""
|
||||||
|
if hasattr(node, "flags") and node.flags is not None:
|
||||||
|
flags_str = _to_string(evaluate_expression(node.flags, solution))
|
||||||
|
re_flags = 0
|
||||||
|
if "i" in flags_str:
|
||||||
|
re_flags |= re.IGNORECASE
|
||||||
|
if "m" in flags_str:
|
||||||
|
re_flags |= re.MULTILINE
|
||||||
|
if "s" in flags_str:
|
||||||
|
re_flags |= re.DOTALL
|
||||||
|
try:
|
||||||
|
result = re.sub(pattern, replacement, string, flags=re_flags)
|
||||||
|
return Term(type=LITERAL, value=result)
|
||||||
|
except re.error:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if builtin == "SUBSTR":
|
||||||
|
string = _to_string(evaluate_expression(node.arg, solution))
|
||||||
|
start = _to_numeric(evaluate_expression(node.start, solution))
|
||||||
|
if start is None:
|
||||||
|
return None
|
||||||
|
start_idx = max(int(start) - 1, 0)
|
||||||
|
if hasattr(node, "length") and node.length is not None:
|
||||||
|
length = _to_numeric(evaluate_expression(node.length, solution))
|
||||||
|
if length is None:
|
||||||
|
return None
|
||||||
|
return Term(
|
||||||
|
type=LITERAL, value=string[start_idx:start_idx + int(length)]
|
||||||
|
)
|
||||||
|
return Term(type=LITERAL, value=string[start_idx:])
|
||||||
|
|
||||||
|
if builtin == "EXISTS":
|
||||||
|
if _exists_callback is not None:
|
||||||
|
return _exists_callback(node.graph, solution)
|
||||||
|
logger.warning("EXISTS requires an exists_cb; not available")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if builtin == "NOTEXISTS":
|
||||||
|
if _exists_callback is not None:
|
||||||
|
return not _exists_callback(node.graph, solution)
|
||||||
|
logger.warning("NOT EXISTS requires an exists_cb; not available")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if builtin == "LANGMATCHES":
|
||||||
|
tag = _to_string(evaluate_expression(node.arg1, solution))
|
||||||
|
rng = _to_string(evaluate_expression(node.arg2, solution))
|
||||||
|
if rng == "*":
|
||||||
|
return len(tag) > 0
|
||||||
|
return tag.lower().startswith(rng.lower())
|
||||||
|
|
||||||
|
if builtin == "IRI" or builtin == "URI":
|
||||||
|
val = _to_string(evaluate_expression(node.arg, solution))
|
||||||
|
return Term(type=IRI, iri=val)
|
||||||
|
|
||||||
|
if builtin == "BNODE":
|
||||||
|
if hasattr(node, "arg") and node.arg is not None:
|
||||||
|
label = _to_string(evaluate_expression(node.arg, solution))
|
||||||
|
return Term(type=BLANK, id=label)
|
||||||
|
return Term(type=BLANK, id=str(uuid.uuid4()))
|
||||||
|
|
||||||
|
if builtin == "NOW":
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
return Term(
|
||||||
|
type=LITERAL,
|
||||||
|
value=now.strftime("%Y-%m-%dT%H:%M:%S%z"),
|
||||||
|
datatype="http://www.w3.org/2001/XMLSchema#dateTime",
|
||||||
|
)
|
||||||
|
|
||||||
|
if builtin == "TZ":
|
||||||
|
dt = _to_datetime(evaluate_expression(node.arg, solution))
|
||||||
|
if dt is None:
|
||||||
|
return Term(type=LITERAL, value="")
|
||||||
|
if dt.tzinfo is not None:
|
||||||
|
offset = dt.strftime("%z")
|
||||||
|
if offset:
|
||||||
|
return Term(type=LITERAL, value=offset[:3] + ":" + offset[3:])
|
||||||
|
return Term(type=LITERAL, value="")
|
||||||
|
|
||||||
|
if builtin == "RAND":
|
||||||
|
return random.random()
|
||||||
|
|
||||||
|
if builtin == "UUID":
|
||||||
|
return Term(type=IRI, iri="urn:uuid:" + str(uuid.uuid4()))
|
||||||
|
|
||||||
|
if builtin == "STRUUID":
|
||||||
|
return Term(type=LITERAL, value=str(uuid.uuid4()))
|
||||||
|
|
||||||
|
if builtin == "MD5":
|
||||||
|
val = _to_string(evaluate_expression(node.arg, solution))
|
||||||
|
return Term(
|
||||||
|
type=LITERAL, value=hashlib.md5(val.encode()).hexdigest()
|
||||||
|
)
|
||||||
|
|
||||||
|
if builtin == "SHA1":
|
||||||
|
val = _to_string(evaluate_expression(node.arg, solution))
|
||||||
|
return Term(
|
||||||
|
type=LITERAL, value=hashlib.sha1(val.encode()).hexdigest()
|
||||||
|
)
|
||||||
|
|
||||||
|
if builtin == "SHA256":
|
||||||
|
val = _to_string(evaluate_expression(node.arg, solution))
|
||||||
|
return Term(
|
||||||
|
type=LITERAL, value=hashlib.sha256(val.encode()).hexdigest()
|
||||||
|
)
|
||||||
|
|
||||||
|
if builtin == "SHA512":
|
||||||
|
val = _to_string(evaluate_expression(node.arg, solution))
|
||||||
|
return Term(
|
||||||
|
type=LITERAL, value=hashlib.sha512(val.encode()).hexdigest()
|
||||||
|
)
|
||||||
|
|
||||||
if builtin == "sameTerm":
|
if builtin == "sameTerm":
|
||||||
left = evaluate_expression(node.arg1, solution)
|
left = evaluate_expression(node.arg1, solution)
|
||||||
right = evaluate_expression(node.arg2, solution)
|
right = evaluate_expression(node.arg2, solution)
|
||||||
|
|
@ -454,6 +665,27 @@ def _to_numeric(val):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _to_datetime(val):
|
||||||
|
"""Convert a value to a date or datetime object."""
|
||||||
|
if val is None:
|
||||||
|
return None
|
||||||
|
s = _to_string(val)
|
||||||
|
for fmt in (
|
||||||
|
"%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%S%z",
|
||||||
|
"%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M",
|
||||||
|
"%Y-%m-%d",
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
return datetime.strptime(s, fmt)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
return datetime.fromisoformat(s)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _comparable_value(val):
|
def _comparable_value(val):
|
||||||
"""
|
"""
|
||||||
Convert a value to a form suitable for comparison.
|
Convert a value to a form suitable for comparison.
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||||
from ... base import TriplesClientSpec
|
from ... base import TriplesClientSpec
|
||||||
|
|
||||||
from . parser import parse_sparql, ParseError
|
from . parser import parse_sparql, ParseError
|
||||||
from . algebra import evaluate, EvaluationError
|
from . algebra import evaluate, materialise, EvaluationError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -66,11 +66,10 @@ class Processor(FlowProcessor):
|
||||||
|
|
||||||
logger.debug(f"Handling SPARQL query request {id}...")
|
logger.debug(f"Handling SPARQL query request {id}...")
|
||||||
|
|
||||||
response = await self.execute_sparql(request, flow)
|
if request.streaming:
|
||||||
|
await self.execute_sparql_streaming(request, flow, id)
|
||||||
if request.streaming and response.query_type == "select":
|
|
||||||
await self.send_streaming(response, flow, id, request)
|
|
||||||
else:
|
else:
|
||||||
|
response = await self.execute_sparql(request, flow)
|
||||||
await flow("response").send(
|
await flow("response").send(
|
||||||
response, properties={"id": id}
|
response, properties={"id": id}
|
||||||
)
|
)
|
||||||
|
|
@ -92,37 +91,77 @@ class Processor(FlowProcessor):
|
||||||
|
|
||||||
await flow("response").send(r, properties={"id": id})
|
await flow("response").send(r, properties={"id": id})
|
||||||
|
|
||||||
async def send_streaming(self, response, flow, id, request):
|
async def execute_sparql_streaming(self, request, flow, id):
|
||||||
"""Send SELECT results in batches."""
|
"""Execute a SPARQL query and stream results as they arrive."""
|
||||||
|
|
||||||
bindings = response.bindings
|
try:
|
||||||
|
parsed = parse_sparql(request.query)
|
||||||
|
except ParseError as e:
|
||||||
|
await flow("response").send(
|
||||||
|
SparqlQueryResponse(
|
||||||
|
error=Error(
|
||||||
|
type="sparql-parse-error",
|
||||||
|
message=str(e),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
properties={"id": id}
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if parsed.query_type != "select":
|
||||||
|
response = await self._execute_non_select(parsed, request, flow)
|
||||||
|
await flow("response").send(response, properties={"id": id})
|
||||||
|
return
|
||||||
|
|
||||||
|
triples_client = flow("triples-request")
|
||||||
|
variables = parsed.variables
|
||||||
batch_size = request.batch_size if request.batch_size > 0 else 20
|
batch_size = request.batch_size if request.batch_size > 0 else 20
|
||||||
|
batch = []
|
||||||
|
|
||||||
for i in range(0, len(bindings), batch_size):
|
try:
|
||||||
batch = bindings[i:i + batch_size]
|
async for sol in evaluate(
|
||||||
is_final = (i + batch_size >= len(bindings))
|
parsed.algebra,
|
||||||
|
triples_client,
|
||||||
|
collection=request.collection or "default",
|
||||||
|
limit=request.limit or 10000,
|
||||||
|
):
|
||||||
|
values = [sol.get(v) for v in variables]
|
||||||
|
batch.append(SparqlBinding(values=values))
|
||||||
|
|
||||||
|
if len(batch) >= batch_size:
|
||||||
r = SparqlQueryResponse(
|
r = SparqlQueryResponse(
|
||||||
query_type=response.query_type,
|
query_type="select",
|
||||||
variables=response.variables,
|
variables=variables,
|
||||||
bindings=batch,
|
bindings=batch,
|
||||||
is_final=is_final,
|
is_final=False,
|
||||||
)
|
)
|
||||||
await flow("response").send(r, properties={"id": id})
|
await flow("response").send(r, properties={"id": id})
|
||||||
|
batch = []
|
||||||
|
|
||||||
# Handle empty results
|
except EvaluationError as e:
|
||||||
if len(bindings) == 0:
|
await flow("response").send(
|
||||||
|
SparqlQueryResponse(
|
||||||
|
error=Error(
|
||||||
|
type="sparql-evaluation-error",
|
||||||
|
message=str(e),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
properties={"id": id}
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Final batch (may be empty for zero results)
|
||||||
r = SparqlQueryResponse(
|
r = SparqlQueryResponse(
|
||||||
query_type=response.query_type,
|
query_type="select",
|
||||||
variables=response.variables,
|
variables=variables,
|
||||||
bindings=[],
|
bindings=batch,
|
||||||
is_final=True,
|
is_final=True,
|
||||||
)
|
)
|
||||||
await flow("response").send(r, properties={"id": id})
|
await flow("response").send(r, properties={"id": id})
|
||||||
|
|
||||||
async def execute_sparql(self, request, flow):
|
async def execute_sparql(self, request, flow):
|
||||||
"""Parse and evaluate a SPARQL query."""
|
"""Parse and evaluate a SPARQL query (non-streaming)."""
|
||||||
|
|
||||||
# Parse the SPARQL query
|
|
||||||
try:
|
try:
|
||||||
parsed = parse_sparql(request.query)
|
parsed = parse_sparql(request.query)
|
||||||
except ParseError as e:
|
except ParseError as e:
|
||||||
|
|
@ -133,12 +172,31 @@ class Processor(FlowProcessor):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the triples client from the flow
|
if parsed.query_type == "select":
|
||||||
triples_client = flow("triples-request")
|
triples_client = flow("triples-request")
|
||||||
|
|
||||||
# Evaluate the algebra
|
|
||||||
try:
|
try:
|
||||||
solutions = await evaluate(
|
solutions = await materialise(
|
||||||
|
parsed.algebra,
|
||||||
|
triples_client,
|
||||||
|
collection=request.collection or "default",
|
||||||
|
limit=request.limit or 10000,
|
||||||
|
)
|
||||||
|
except EvaluationError as e:
|
||||||
|
return SparqlQueryResponse(
|
||||||
|
error=Error(
|
||||||
|
type="sparql-evaluation-error",
|
||||||
|
message=str(e),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return self._build_select_response(parsed, solutions)
|
||||||
|
|
||||||
|
return await self._execute_non_select(parsed, request, flow)
|
||||||
|
|
||||||
|
async def _execute_non_select(self, parsed, request, flow):
|
||||||
|
"""Execute ASK, CONSTRUCT, or DESCRIBE queries."""
|
||||||
|
triples_client = flow("triples-request")
|
||||||
|
try:
|
||||||
|
solutions = await materialise(
|
||||||
parsed.algebra,
|
parsed.algebra,
|
||||||
triples_client,
|
triples_client,
|
||||||
collection=request.collection or "default",
|
collection=request.collection or "default",
|
||||||
|
|
@ -152,10 +210,7 @@ class Processor(FlowProcessor):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build response based on query type
|
if parsed.query_type == "ask":
|
||||||
if parsed.query_type == "select":
|
|
||||||
return self._build_select_response(parsed, solutions)
|
|
||||||
elif parsed.query_type == "ask":
|
|
||||||
return self._build_ask_response(solutions)
|
return self._build_ask_response(solutions)
|
||||||
elif parsed.query_type == "construct":
|
elif parsed.query_type == "construct":
|
||||||
return self._build_construct_response(parsed, solutions)
|
return self._build_construct_response(parsed, solutions)
|
||||||
|
|
|
||||||
|
|
@ -150,6 +150,30 @@ def left_join(left, right, filter_fn=None):
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def minus(left, right):
|
||||||
|
"""
|
||||||
|
MINUS operation: remove left solutions that are compatible with any
|
||||||
|
right solution sharing at least one variable.
|
||||||
|
"""
|
||||||
|
if not right:
|
||||||
|
return list(left)
|
||||||
|
|
||||||
|
right_vars = set()
|
||||||
|
for sol in right:
|
||||||
|
right_vars.update(sol.keys())
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for sol_l in left:
|
||||||
|
shared = set(sol_l.keys()) & right_vars
|
||||||
|
if not shared:
|
||||||
|
results.append(sol_l)
|
||||||
|
continue
|
||||||
|
if not any(_compatible(sol_l, sol_r) for sol_r in right):
|
||||||
|
results.append(sol_l)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
def union(left, right):
|
def union(left, right):
|
||||||
"""Union two solution sequences (concatenation)."""
|
"""Union two solution sequences (concatenation)."""
|
||||||
return list(left) + list(right)
|
return list(left) + list(right)
|
||||||
|
|
@ -177,6 +201,28 @@ def distinct(solutions):
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _sort_comparable(val):
|
||||||
|
"""Convert a value to a form suitable for sort ordering."""
|
||||||
|
if val is None:
|
||||||
|
return (0, "")
|
||||||
|
if isinstance(val, (int, float)):
|
||||||
|
return (2, val)
|
||||||
|
if isinstance(val, Term):
|
||||||
|
if val.type == LITERAL:
|
||||||
|
try:
|
||||||
|
if "." in val.value:
|
||||||
|
return (2, float(val.value))
|
||||||
|
return (2, int(val.value))
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass
|
||||||
|
return (3, val.value)
|
||||||
|
elif val.type == IRI:
|
||||||
|
return (4, val.iri)
|
||||||
|
elif val.type == BLANK:
|
||||||
|
return (5, val.id)
|
||||||
|
return (6, str(val))
|
||||||
|
|
||||||
|
|
||||||
def order_by(solutions, key_fns):
|
def order_by(solutions, key_fns):
|
||||||
"""
|
"""
|
||||||
Sort solutions by the given key functions.
|
Sort solutions by the given key functions.
|
||||||
|
|
@ -191,14 +237,7 @@ def order_by(solutions, key_fns):
|
||||||
keys = []
|
keys = []
|
||||||
for fn, ascending in key_fns:
|
for fn, ascending in key_fns:
|
||||||
val = fn(sol)
|
val = fn(sol)
|
||||||
# Convert to comparable form
|
keys.append(_sort_comparable(val))
|
||||||
if val is None:
|
|
||||||
comparable = ("", "")
|
|
||||||
elif isinstance(val, Term):
|
|
||||||
comparable = _term_key(val)
|
|
||||||
else:
|
|
||||||
comparable = ("v", str(val))
|
|
||||||
keys.append(comparable)
|
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
# Handle ascending/descending
|
# Handle ascending/descending
|
||||||
|
|
@ -224,10 +263,8 @@ def _mixed_sort(solutions, key_fns):
|
||||||
|
|
||||||
def compare(a, b):
|
def compare(a, b):
|
||||||
for fn, ascending in key_fns:
|
for fn, ascending in key_fns:
|
||||||
va = fn(a)
|
ka = _sort_comparable(fn(a))
|
||||||
vb = fn(b)
|
kb = _sort_comparable(fn(b))
|
||||||
ka = _term_key(va) if isinstance(va, Term) else ("v", str(va)) if va is not None else ("", "")
|
|
||||||
kb = _term_key(vb) if isinstance(vb, Term) else ("v", str(vb)) if vb is not None else ("", "")
|
|
||||||
|
|
||||||
if ka < kb:
|
if ka < kb:
|
||||||
return -1 if ascending else 1
|
return -1 if ascending else 1
|
||||||
|
|
|
||||||
|
|
@ -1,46 +1,43 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional, Callable, Awaitable
|
||||||
from trustgraph.messaging import TranslatorRegistry
|
|
||||||
from ..gateway.dispatch.manager import DispatcherManager
|
from ..gateway.dispatch.manager import DispatcherManager
|
||||||
|
|
||||||
logger = logging.getLogger("dispatcher")
|
logger = logging.getLogger("dispatcher")
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
class WebSocketResponder:
|
|
||||||
"""Simple responder that captures response for websocket return"""
|
|
||||||
def __init__(self):
|
|
||||||
self.response = None
|
|
||||||
self.completed = False
|
|
||||||
|
|
||||||
async def send(self, data):
|
class _TokenShim:
|
||||||
"""Capture the response data"""
|
def __init__(self, token):
|
||||||
self.response = data
|
self.headers = (
|
||||||
self.completed = True
|
{"Authorization": f"Bearer {token}"} if token else {}
|
||||||
|
)
|
||||||
|
|
||||||
async def __call__(self, data, final=False):
|
|
||||||
"""Make the responder callable for compatibility with requestor"""
|
|
||||||
await self.send(data)
|
|
||||||
if final:
|
|
||||||
self.completed = True
|
|
||||||
|
|
||||||
class MessageDispatcher:
|
class MessageDispatcher:
|
||||||
|
|
||||||
def __init__(self, max_workers: int = 10, config_receiver=None, backend=None):
|
def __init__(self, max_workers=10, config_receiver=None, backend=None,
|
||||||
|
auth=None, timeout=120):
|
||||||
self.max_workers = max_workers
|
self.max_workers = max_workers
|
||||||
self.semaphore = asyncio.Semaphore(max_workers)
|
self.semaphore = asyncio.Semaphore(max_workers)
|
||||||
self.active_tasks = set()
|
self.active_tasks = set()
|
||||||
self.backend = backend
|
self.backend = backend
|
||||||
|
self.auth = auth
|
||||||
|
|
||||||
# Use DispatcherManager for flow and service management
|
if backend and config_receiver and auth:
|
||||||
if backend and config_receiver:
|
self.dispatcher_manager = DispatcherManager(
|
||||||
self.dispatcher_manager = DispatcherManager(backend, config_receiver, prefix="rev-gateway")
|
backend, config_receiver,
|
||||||
|
auth=auth,
|
||||||
|
prefix="rev-gateway",
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.dispatcher_manager = None
|
self.dispatcher_manager = None
|
||||||
logger.warning("No backend or config_receiver provided - using fallback mode")
|
logger.warning(
|
||||||
|
"Missing backend, config_receiver, or auth "
|
||||||
|
"— using fallback mode"
|
||||||
|
)
|
||||||
|
|
||||||
# Service name mapping from websocket protocol to translator registry
|
|
||||||
self.service_mapping = {
|
self.service_mapping = {
|
||||||
"text-completion": "text-completion",
|
"text-completion": "text-completion",
|
||||||
"graph-rag": "graph-rag",
|
"graph-rag": "graph-rag",
|
||||||
|
|
@ -54,77 +51,90 @@ class MessageDispatcher:
|
||||||
"knowledge": "knowledge",
|
"knowledge": "knowledge",
|
||||||
"config": "config",
|
"config": "config",
|
||||||
"librarian": "librarian",
|
"librarian": "librarian",
|
||||||
"document-rag": "document-rag"
|
"document-rag": "document-rag",
|
||||||
}
|
}
|
||||||
|
|
||||||
async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]:
|
async def handle_message(
|
||||||
|
self, message: Dict[Any, Any],
|
||||||
|
sender: Callable[[dict], Awaitable[None]],
|
||||||
|
):
|
||||||
async with self.semaphore:
|
async with self.semaphore:
|
||||||
task = asyncio.create_task(self._process_message(message))
|
task = asyncio.create_task(
|
||||||
|
self._process_message(message, sender)
|
||||||
|
)
|
||||||
self.active_tasks.add(task)
|
self.active_tasks.add(task)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await task
|
await task
|
||||||
return result
|
|
||||||
finally:
|
finally:
|
||||||
self.active_tasks.discard(task)
|
self.active_tasks.discard(task)
|
||||||
|
|
||||||
async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]:
|
async def _authenticate(self, token):
|
||||||
|
if not self.auth:
|
||||||
|
raise RuntimeError("Auth not configured")
|
||||||
|
return await self.auth.authenticate(_TokenShim(token))
|
||||||
|
|
||||||
|
async def _process_message(
|
||||||
|
self, message: Dict[Any, Any],
|
||||||
|
sender: Callable[[dict], Awaitable[None]],
|
||||||
|
):
|
||||||
request_id = message.get('id', str(uuid.uuid4()))
|
request_id = message.get('id', str(uuid.uuid4()))
|
||||||
service = message.get('service')
|
service = message.get('service')
|
||||||
request_data = message.get('request', {})
|
request_data = message.get('request', {})
|
||||||
flow_id = message.get('flow', 'default') # Default flow
|
token = message.get('token', '')
|
||||||
|
flow_id = message.get('flow', 'default')
|
||||||
|
|
||||||
logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}")
|
logger.info(
|
||||||
|
f"Processing message {request_id} for service "
|
||||||
|
f"{service} on flow {flow_id}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not self.dispatcher_manager:
|
if not self.dispatcher_manager:
|
||||||
raise RuntimeError("DispatcherManager not available - backend and config_receiver required")
|
raise RuntimeError(
|
||||||
|
"DispatcherManager not available"
|
||||||
|
)
|
||||||
|
|
||||||
# Use DispatcherManager for flow-based processing
|
identity = await self._authenticate(token)
|
||||||
responder = WebSocketResponder()
|
workspace = identity.workspace
|
||||||
|
|
||||||
|
async def responder(resp, fin):
|
||||||
|
await sender({
|
||||||
|
"id": request_id,
|
||||||
|
"response": resp,
|
||||||
|
"complete": fin,
|
||||||
|
})
|
||||||
|
|
||||||
# Map websocket service name to dispatcher service name
|
|
||||||
dispatcher_service = self.service_mapping.get(service, service)
|
dispatcher_service = self.service_mapping.get(service, service)
|
||||||
|
|
||||||
# Check if this is a global service or flow service
|
|
||||||
from ..gateway.dispatch.manager import global_dispatchers
|
from ..gateway.dispatch.manager import global_dispatchers
|
||||||
if dispatcher_service in global_dispatchers:
|
if dispatcher_service in global_dispatchers:
|
||||||
# Use global service dispatcher
|
|
||||||
await self.dispatcher_manager.invoke_global_service(
|
await self.dispatcher_manager.invoke_global_service(
|
||||||
request_data, responder, dispatcher_service
|
request_data, responder, dispatcher_service,
|
||||||
|
workspace=workspace,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use DispatcherManager to process the request through Pulsar queues
|
|
||||||
await self.dispatcher_manager.invoke_flow_service(
|
await self.dispatcher_manager.invoke_flow_service(
|
||||||
request_data, responder, flow_id, dispatcher_service
|
request_data, responder, workspace, flow_id,
|
||||||
|
dispatcher_service,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the response from the responder
|
|
||||||
if responder.completed:
|
|
||||||
response_data = responder.response
|
|
||||||
else:
|
|
||||||
response_data = {'error': 'No response received'}
|
|
||||||
|
|
||||||
response = {
|
|
||||||
'id': request_id,
|
|
||||||
'response': response_data
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing message {request_id}: {e}")
|
logger.error(f"Error processing message {request_id}: {e}")
|
||||||
response = {
|
await sender({
|
||||||
'id': request_id,
|
"id": request_id,
|
||||||
'response': {'error': str(e)}
|
"error": {"message": str(e), "type": "error"},
|
||||||
}
|
"complete": True,
|
||||||
|
})
|
||||||
|
|
||||||
logger.info(f"Completed processing message {request_id}")
|
logger.info(f"Completed processing message {request_id}")
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
async def shutdown(self):
|
async def shutdown(self):
|
||||||
if self.active_tasks:
|
if self.active_tasks:
|
||||||
logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete")
|
logger.info(
|
||||||
|
f"Waiting for {len(self.active_tasks)} active "
|
||||||
|
f"tasks to complete"
|
||||||
|
)
|
||||||
await asyncio.gather(*self.active_tasks, return_exceptions=True)
|
await asyncio.gather(*self.active_tasks, return_exceptions=True)
|
||||||
|
|
||||||
# DispatcherManager handles its own cleanup
|
|
||||||
logger.info("Dispatcher shutdown complete")
|
logger.info("Dispatcher shutdown complete")
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,9 @@
|
||||||
|
"""
|
||||||
|
Reverse gateway. Initiates outbound WebSocket connections to a remote
|
||||||
|
relay and dispatches incoming requests through the same DispatcherManager
|
||||||
|
pipeline as api-gateway.
|
||||||
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -9,67 +15,68 @@ from typing import Optional
|
||||||
from urllib.parse import urlparse, urlunparse
|
from urllib.parse import urlparse, urlunparse
|
||||||
|
|
||||||
from .dispatcher import MessageDispatcher
|
from .dispatcher import MessageDispatcher
|
||||||
|
from ..gateway.auth import IamAuth
|
||||||
from ..gateway.config.receiver import ConfigReceiver
|
from ..gateway.config.receiver import ConfigReceiver
|
||||||
from ..base import get_pubsub
|
from ..base.pubsub import get_pubsub, add_pubsub_args
|
||||||
|
from ..base.logging import setup_logging, add_logging_args
|
||||||
|
|
||||||
logger = logging.getLogger("rev_gateway")
|
logger = logging.getLogger("rev_gateway")
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
default_websocket = "ws://localhost:7650/out"
|
default_websocket = "ws://localhost:7650/out"
|
||||||
|
default_timeout = 600
|
||||||
|
|
||||||
class ReverseGateway:
|
class ReverseGateway:
|
||||||
|
|
||||||
def __init__(self, websocket_uri: str = None, max_workers: int = 10,
|
def __init__(self, **config):
|
||||||
pulsar_host: str = None, pulsar_api_key: str = None,
|
websocket_uri = config.get("websocket_uri")
|
||||||
pulsar_listener: str = None):
|
|
||||||
# Set default WebSocket URI with environment variable support
|
|
||||||
if websocket_uri is None:
|
if websocket_uri is None:
|
||||||
websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket)
|
websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket)
|
||||||
|
|
||||||
# Parse and validate the WebSocket URI
|
|
||||||
parsed_uri = urlparse(websocket_uri)
|
parsed_uri = urlparse(websocket_uri)
|
||||||
if parsed_uri.scheme not in ('ws', 'wss'):
|
if parsed_uri.scheme not in ('ws', 'wss'):
|
||||||
raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}")
|
raise ValueError(
|
||||||
|
f"WebSocket URI must use ws:// or wss:// scheme, "
|
||||||
|
f"got: {parsed_uri.scheme}"
|
||||||
|
)
|
||||||
if not parsed_uri.netloc:
|
if not parsed_uri.netloc:
|
||||||
raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}")
|
raise ValueError(
|
||||||
|
f"WebSocket URI must include hostname, "
|
||||||
|
f"got: {websocket_uri}"
|
||||||
|
)
|
||||||
|
|
||||||
# Store parsed components for debugging/logging
|
|
||||||
self.websocket_uri = websocket_uri
|
self.websocket_uri = websocket_uri
|
||||||
self.host = parsed_uri.hostname
|
self.host = parsed_uri.hostname
|
||||||
self.port = parsed_uri.port
|
self.port = parsed_uri.port
|
||||||
self.scheme = parsed_uri.scheme
|
self.scheme = parsed_uri.scheme
|
||||||
self.path = parsed_uri.path or "/ws"
|
self.path = parsed_uri.path or "/ws"
|
||||||
|
|
||||||
# Construct the full URL (in case path was missing)
|
|
||||||
if not parsed_uri.path:
|
if not parsed_uri.path:
|
||||||
self.url = f"{self.scheme}://{parsed_uri.netloc}/ws"
|
self.url = f"{self.scheme}://{parsed_uri.netloc}/ws"
|
||||||
else:
|
else:
|
||||||
self.url = websocket_uri
|
self.url = websocket_uri
|
||||||
|
|
||||||
self.max_workers = max_workers
|
self.max_workers = int(config.get("max_workers", 10))
|
||||||
|
self.timeout = int(config.get("timeout", default_timeout))
|
||||||
self.ws: Optional[ClientWebSocketResponse] = None
|
self.ws: Optional[ClientWebSocketResponse] = None
|
||||||
self.session: Optional[ClientSession] = None
|
self.session: Optional[ClientSession] = None
|
||||||
self.running = False
|
self.running = False
|
||||||
self.reconnect_delay = 3.0
|
self.reconnect_delay = 3.0
|
||||||
|
|
||||||
# Pulsar configuration
|
self.backend = get_pubsub(**config)
|
||||||
self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
|
|
||||||
self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None)
|
|
||||||
self.pulsar_listener = pulsar_listener
|
|
||||||
|
|
||||||
# Create backend using factory
|
self.auth = IamAuth(
|
||||||
backend_params = {
|
backend=self.backend,
|
||||||
'pulsar_host': self.pulsar_host,
|
id=config.get("id", "rev-gateway"),
|
||||||
'pulsar_api_key': self.pulsar_api_key,
|
)
|
||||||
'pulsar_listener': self.pulsar_listener,
|
|
||||||
}
|
|
||||||
self.backend = get_pubsub(**backend_params)
|
|
||||||
|
|
||||||
# Initialize config receiver
|
self.config_receiver = ConfigReceiver(
|
||||||
self.config_receiver = ConfigReceiver(self.backend)
|
self.backend, auth=self.auth,
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize dispatcher with config_receiver and backend - must be created after config_receiver
|
self.dispatcher = MessageDispatcher(
|
||||||
self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.backend)
|
self.max_workers, self.config_receiver, self.backend,
|
||||||
|
auth=self.auth, timeout=self.timeout,
|
||||||
|
)
|
||||||
|
|
||||||
async def connect(self) -> bool:
|
async def connect(self) -> bool:
|
||||||
try:
|
try:
|
||||||
|
|
@ -78,7 +85,10 @@ class ReverseGateway:
|
||||||
|
|
||||||
logger.info(f"Connecting to {self.url}")
|
logger.info(f"Connecting to {self.url}")
|
||||||
self.ws = await self.session.ws_connect(self.url)
|
self.ws = await self.session.ws_connect(self.url)
|
||||||
logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}")
|
logger.info(
|
||||||
|
f"WebSocket connection established to "
|
||||||
|
f"{self.host}:{self.port or 'default'}"
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -105,10 +115,9 @@ class ReverseGateway:
|
||||||
logger.debug(f"Received message: {message}")
|
logger.debug(f"Received message: {message}")
|
||||||
|
|
||||||
msg_data = json.loads(message)
|
msg_data = json.loads(message)
|
||||||
response = await self.dispatcher.handle_message(msg_data)
|
await self.dispatcher.handle_message(
|
||||||
|
msg_data, self.send_message,
|
||||||
if response:
|
)
|
||||||
await self.send_message(response)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error handling message: {e}")
|
logger.error(f"Error handling message: {e}")
|
||||||
|
|
@ -134,8 +143,7 @@ class ReverseGateway:
|
||||||
self.running = True
|
self.running = True
|
||||||
logger.info("Starting reverse gateway")
|
logger.info("Starting reverse gateway")
|
||||||
|
|
||||||
# Start config receiver
|
await self.auth.start()
|
||||||
logger.info("Starting config receiver")
|
|
||||||
await self.config_receiver.start()
|
await self.config_receiver.start()
|
||||||
|
|
||||||
while self.running:
|
while self.running:
|
||||||
|
|
@ -143,7 +151,10 @@ class ReverseGateway:
|
||||||
if await self.connect():
|
if await self.connect():
|
||||||
await self.listen()
|
await self.listen()
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds")
|
logger.warning(
|
||||||
|
f"Connection failed, retrying in "
|
||||||
|
f"{self.reconnect_delay} seconds"
|
||||||
|
)
|
||||||
|
|
||||||
await self.disconnect()
|
await self.disconnect()
|
||||||
|
|
||||||
|
|
@ -166,67 +177,59 @@ class ReverseGateway:
|
||||||
await self.dispatcher.shutdown()
|
await self.dispatcher.shutdown()
|
||||||
await self.disconnect()
|
await self.disconnect()
|
||||||
|
|
||||||
# Close backend
|
|
||||||
if hasattr(self, 'backend'):
|
if hasattr(self, 'backend'):
|
||||||
self.backend.close()
|
self.backend.close()
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
|
def run():
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
prog="reverse-gateway",
|
prog="reverse-gateway",
|
||||||
description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge"
|
description=__doc__,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--id',
|
||||||
|
default='rev-gateway',
|
||||||
|
help='Service identifier (default: rev-gateway)',
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--websocket-uri',
|
'--websocket-uri',
|
||||||
default=None,
|
default=None,
|
||||||
help=f'WebSocket URI to connect to (default: {default_websocket} or WEBSOCKET_URI env var)'
|
help=f'WebSocket URI to connect to (default: {default_websocket})',
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--max-workers',
|
'--max-workers',
|
||||||
type=int,
|
type=int,
|
||||||
default=10,
|
default=10,
|
||||||
help='Maximum concurrent message handlers (default: 10)'
|
help='Maximum concurrent message handlers (default: 10)',
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-p', '--pulsar-host',
|
'--timeout',
|
||||||
default=None,
|
type=int,
|
||||||
help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)'
|
default=default_timeout,
|
||||||
|
help=f'Request timeout in seconds (default: {default_timeout})',
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
add_pubsub_args(parser)
|
||||||
'--pulsar-api-key',
|
add_logging_args(parser)
|
||||||
default=None,
|
|
||||||
help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
args = parser.parse_args()
|
||||||
'--pulsar-listener',
|
args = vars(args)
|
||||||
default=None,
|
|
||||||
help='Pulsar listener name'
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser.parse_args()
|
setup_logging(args)
|
||||||
|
|
||||||
def run():
|
gateway = ReverseGateway(**args)
|
||||||
args = parse_args()
|
|
||||||
|
|
||||||
gateway = ReverseGateway(
|
|
||||||
websocket_uri=args.websocket_uri,
|
|
||||||
max_workers=args.max_workers,
|
|
||||||
pulsar_host=args.pulsar_host,
|
|
||||||
pulsar_api_key=args.pulsar_api_key,
|
|
||||||
pulsar_listener=args.pulsar_listener
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Starting reverse gateway:")
|
logger.info(f"Starting reverse gateway:")
|
||||||
logger.info(f" WebSocket URI: {gateway.url}")
|
logger.info(f" WebSocket URI: {gateway.url}")
|
||||||
logger.info(f" Max workers: {args.max_workers}")
|
logger.info(f" Max workers: {gateway.max_workers}")
|
||||||
logger.info(f" Pulsar host: {gateway.pulsar_host}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
asyncio.run(gateway.run())
|
asyncio.run(gateway.run())
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue