Merge pull request #953 from trustgraph-ai/release/v2.5

release/v2.5 -> master
This commit is contained in:
cybermaggedon 2026-05-26 15:01:44 +01:00 committed by GitHub
commit 36eadbda3a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 2632 additions and 1140 deletions

View file

@ -25,7 +25,7 @@ BUCKET_URL = "https://storage.googleapis.com/trustgraph-library"
INDEX_URL = f"{BUCKET_URL}/index.json" INDEX_URL = f"{BUCKET_URL}/index.json"
default_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") default_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/")
default_user = "trustgraph" default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
@ -113,7 +113,7 @@ def convert_metadata(metadata_json):
return triples return triples
def load_document(api, user, doc_entry): def load_document(api, doc_entry):
"""Fetch metadata and content for a document, then load into TrustGraph.""" """Fetch metadata and content for a document, then load into TrustGraph."""
doc_id = doc_entry["id"] doc_id = doc_entry["id"]
title = doc_entry["title"] title = doc_entry["title"]
@ -133,7 +133,6 @@ def load_document(api, user, doc_entry):
api.add_document( api.add_document(
id=doc["id"], id=doc["id"],
metadata=metadata, metadata=metadata,
user=user,
kind=doc["kind"], kind=doc["kind"],
title=doc["title"], title=doc["title"],
comments=doc["comments"], comments=doc["comments"],
@ -144,12 +143,12 @@ def load_document(api, user, doc_entry):
print(f" done.") print(f" done.")
def load_documents(api, user, docs): def load_documents(api, docs):
"""Load a list of documents.""" """Load a list of documents."""
print(f"Loading {len(docs)} document(s)...\n") print(f"Loading {len(docs)} document(s)...\n")
for doc in docs: for doc in docs:
try: try:
load_document(api, user, doc) load_document(api, doc)
except Exception as e: except Exception as e:
print(f" FAILED: {e}", file=sys.stderr) print(f" FAILED: {e}", file=sys.stderr)
print() print()
@ -166,8 +165,8 @@ def main():
help=f"TrustGraph API URL (default: {default_url})", help=f"TrustGraph API URL (default: {default_url})",
) )
parser.add_argument( parser.add_argument(
"-U", "--user", default=default_user, "-w", "--workspace", default=default_workspace,
help=f"User ID (default: {default_user})", help=f"Workspace (default: {default_workspace})",
) )
parser.add_argument( parser.add_argument(
"-t", "--token", default=default_token, "-t", "--token", default=default_token,
@ -212,22 +211,22 @@ def main():
return return
# Load commands need the API # Load commands need the API
api = Api(args.url, token=args.token).library() api = Api(args.url, token=args.token, workspace=args.workspace).library()
if args.command == "load-all": if args.command == "load-all":
load_documents(api, args.user, index) load_documents(api, index)
elif args.command == "load-doc": elif args.command == "load-doc":
matches = [d for d in index if str(d.get("id")) == args.id] matches = [d for d in index if str(d.get("id")) == args.id]
if not matches: if not matches:
print(f"No document with ID '{args.id}' found.", file=sys.stderr) print(f"No document with ID '{args.id}' found.", file=sys.stderr)
sys.exit(1) sys.exit(1)
load_documents(api, args.user, matches) load_documents(api, matches)
elif args.command == "load-match": elif args.command == "load-match":
results = search_index(index, args.query) results = search_index(index, args.query)
if results: if results:
load_documents(api, args.user, results) load_documents(api, results)
else: else:
print("No matches found.", file=sys.stderr) print("No matches found.", file=sys.stderr)
sys.exit(1) sys.exit(1)

View file

@ -3,208 +3,278 @@
WebSocket Relay Test Harness WebSocket Relay Test Harness
This script creates a relay server with two WebSocket endpoints: This script creates a relay server with two WebSocket endpoints:
- /in - for test clients to connect to - /in - for test clients to connect to (speaks api-gateway protocol)
- /out - for reverse gateway to connect to - /out - for reverse gateway to connect to (speaks rev-gateway protocol)
Messages are bidirectionally relayed between the two connections. Clients on /in authenticate with a first-frame auth message:
{"type": "auth", "token": "..."}
The relay stores the token and injects it into each subsequent message
before forwarding to /out. Responses from /out are forwarded back to
the originating /in connection unchanged.
Usage: Usage:
python websocket_relay.py [--port PORT] [--host HOST] python websocket_relay.py [--port PORT] [--host HOST]
""" """
import asyncio import asyncio
import json
import logging import logging
import argparse import argparse
from aiohttp import web, WSMsgType from aiohttp import web, WSMsgType
import weakref from typing import Dict, Optional
from typing import Optional, Set
# Configure logging
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
) )
logger = logging.getLogger("websocket_relay") logger = logging.getLogger("websocket_relay")
class InConnection:
def __init__(self, ws, conn_id):
self.ws = ws
self.conn_id = conn_id
self.token: Optional[str] = None
self.authenticated = False
class WebSocketRelay: class WebSocketRelay:
"""WebSocket relay that forwards messages between 'in' and 'out' connections"""
def __init__(self): def __init__(self):
self.in_connections: Set = weakref.WeakSet() self.in_connections: Dict[str, InConnection] = {}
self.out_connections: Set = weakref.WeakSet() self.out_connections: set = set()
self._conn_counter = 0
def _next_conn_id(self):
self._conn_counter += 1
return f"conn-{self._conn_counter}"
async def handle_in_connection(self, request): async def handle_in_connection(self, request):
"""Handle incoming connections on /in endpoint"""
ws = web.WebSocketResponse() ws = web.WebSocketResponse()
await ws.prepare(request) await ws.prepare(request)
self.in_connections.add(ws) conn_id = self._next_conn_id()
logger.info(f"New 'in' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}") conn = InConnection(ws, conn_id)
self.in_connections[conn_id] = conn
logger.info(
f"New 'in' connection {conn_id}. "
f"Total in: {len(self.in_connections)}, "
f"out: {len(self.out_connections)}"
)
try: try:
async for msg in ws: async for msg in ws:
if msg.type == WSMsgType.TEXT: if msg.type == WSMsgType.TEXT:
data = msg.data await self._handle_in_message(conn, msg.data)
logger.info(f"IN → OUT: {data}")
await self._forward_to_out(data)
elif msg.type == WSMsgType.BINARY:
data = msg.data
logger.info(f"IN → OUT: {len(data)} bytes (binary)")
await self._forward_to_out(data, binary=True)
elif msg.type == WSMsgType.ERROR: elif msg.type == WSMsgType.ERROR:
logger.error(f"WebSocket error on 'in' connection: {ws.exception()}") logger.error(
f"WebSocket error on 'in' connection "
f"{conn_id}: {ws.exception()}"
)
break break
else: else:
break break
except Exception as e: except Exception as e:
logger.error(f"Error in 'in' connection handler: {e}") logger.error(
f"Error in 'in' connection {conn_id}: {e}"
)
finally: finally:
logger.info(f"'in' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}") del self.in_connections[conn_id]
logger.info(
f"'in' connection {conn_id} closed. "
f"Remaining in: {len(self.in_connections)}, "
f"out: {len(self.out_connections)}"
)
return ws return ws
async def handle_out_connection(self, request): async def _handle_in_message(self, conn, data):
"""Handle outgoing connections on /out endpoint"""
ws = web.WebSocketResponse()
await ws.prepare(request)
self.out_connections.add(ws)
logger.info(f"New 'out' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")
try: try:
async for msg in ws: message = json.loads(data)
if msg.type == WSMsgType.TEXT: except json.JSONDecodeError:
data = msg.data logger.warning(
logger.info(f"OUT → IN: {data}") f"{conn.conn_id}: received non-JSON message"
await self._forward_to_in(data) )
elif msg.type == WSMsgType.BINARY:
data = msg.data
logger.info(f"OUT → IN: {len(data)} bytes (binary)")
await self._forward_to_in(data, binary=True)
elif msg.type == WSMsgType.ERROR:
logger.error(f"WebSocket error on 'out' connection: {ws.exception()}")
break
else:
break
except Exception as e:
logger.error(f"Error in 'out' connection handler: {e}")
finally:
logger.info(f"'out' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")
return ws
async def _forward_to_out(self, data, binary=False):
"""Forward message from 'in' to all 'out' connections"""
if not self.out_connections:
logger.warning("No 'out' connections available to forward message")
return return
closed_connections = [] if isinstance(message, dict) and message.get("type") == "auth":
conn.token = message.get("token", "")
conn.authenticated = True
logger.info(f"{conn.conn_id}: authenticated")
await conn.ws.send_json({
"type": "auth-ok",
"workspace": "relayed",
})
return
if not conn.authenticated:
await conn.ws.send_json({
"error": {
"message": "auth required",
"type": "auth-required",
},
"complete": True,
})
return
message["token"] = conn.token
message["_relay_conn"] = conn.conn_id
forwarded = json.dumps(message)
logger.info(f"IN {conn.conn_id} → OUT: {forwarded}")
await self._forward_to_out(forwarded)
async def handle_out_connection(self, request):
ws = web.WebSocketResponse()
await ws.prepare(request)
self.out_connections.add(ws)
logger.info(
f"New 'out' connection. "
f"Total in: {len(self.in_connections)}, "
f"out: {len(self.out_connections)}"
)
try:
async for msg in ws:
if msg.type == WSMsgType.TEXT:
await self._handle_out_message(msg.data)
elif msg.type == WSMsgType.ERROR:
logger.error(
f"WebSocket error on 'out' connection: "
f"{ws.exception()}"
)
break
else:
break
except Exception as e:
logger.error(f"Error in 'out' connection: {e}")
finally:
self.out_connections.discard(ws)
logger.info(
f"'out' connection closed. "
f"Remaining in: {len(self.in_connections)}, "
f"out: {len(self.out_connections)}"
)
return ws
async def _handle_out_message(self, data):
try:
message = json.loads(data)
except json.JSONDecodeError:
logger.warning("OUT: received non-JSON message")
return
conn_id = message.pop("_relay_conn", None)
forwarded = json.dumps(message)
logger.info(f"OUT → IN {conn_id or 'broadcast'}: {forwarded}")
if conn_id and conn_id in self.in_connections:
conn = self.in_connections[conn_id]
try:
if not conn.ws.closed:
await conn.ws.send_str(forwarded)
except Exception as e:
logger.error(
f"Error forwarding to 'in' {conn_id}: {e}"
)
else:
await self._broadcast_to_in(forwarded)
async def _broadcast_to_in(self, data):
closed = []
for conn_id, conn in list(self.in_connections.items()):
try:
if conn.ws.closed:
closed.append(conn_id)
continue
await conn.ws.send_str(data)
except Exception as e:
logger.error(
f"Error broadcasting to 'in' {conn_id}: {e}"
)
closed.append(conn_id)
for conn_id in closed:
self.in_connections.pop(conn_id, None)
async def _forward_to_out(self, data):
closed = []
for ws in list(self.out_connections): for ws in list(self.out_connections):
try: try:
if ws.closed: if ws.closed:
closed_connections.append(ws) closed.append(ws)
continue continue
await ws.send_str(data)
if binary:
await ws.send_bytes(data)
else:
await ws.send_str(data)
except Exception as e: except Exception as e:
logger.error(f"Error forwarding to 'out' connection: {e}") logger.error(f"Error forwarding to 'out': {e}")
closed_connections.append(ws) closed.append(ws)
for ws in closed:
# Clean up closed connections self.out_connections.discard(ws)
for ws in closed_connections:
if ws in self.out_connections:
self.out_connections.discard(ws)
async def _forward_to_in(self, data, binary=False):
"""Forward message from 'out' to all 'in' connections"""
if not self.in_connections:
logger.warning("No 'in' connections available to forward message")
return
closed_connections = []
for ws in list(self.in_connections):
try:
if ws.closed:
closed_connections.append(ws)
continue
if binary:
await ws.send_bytes(data)
else:
await ws.send_str(data)
except Exception as e:
logger.error(f"Error forwarding to 'in' connection: {e}")
closed_connections.append(ws)
# Clean up closed connections
for ws in closed_connections:
if ws in self.in_connections:
self.in_connections.discard(ws)
async def create_app(relay): async def create_app(relay):
"""Create the web application with routes"""
app = web.Application() app = web.Application()
# Add routes app.router.add_get('/in/api/v1/socket', relay.handle_in_connection)
app.router.add_get('/in', relay.handle_in_connection)
app.router.add_get('/out', relay.handle_out_connection) app.router.add_get('/out', relay.handle_out_connection)
# Add a simple status endpoint
async def status(request): async def status(request):
status_info = { return web.json_response({
'in_connections': len(relay.in_connections), 'in_connections': len(relay.in_connections),
'out_connections': len(relay.out_connections), 'out_connections': len(relay.out_connections),
'status': 'running' 'status': 'running',
} })
return web.json_response(status_info)
app.router.add_get('/status', status) app.router.add_get('/status', status)
app.router.add_get('/', status) # Root also shows status app.router.add_get('/', status)
return app return app
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="WebSocket Relay Test Harness" description="WebSocket Relay Test Harness"
) )
parser.add_argument( parser.add_argument(
'--host', '--host',
default='localhost', default='localhost',
help='Host to bind to (default: localhost)' help='Host to bind to (default: localhost)',
) )
parser.add_argument( parser.add_argument(
'--port', '--port',
type=int, type=int,
default=8080, default=8080,
help='Port to bind to (default: 8080)' help='Port to bind to (default: 8080)',
) )
parser.add_argument( parser.add_argument(
'--verbose', '-v', '--verbose', '-v',
action='store_true', action='store_true',
help='Enable verbose logging' help='Enable verbose logging',
) )
args = parser.parse_args() args = parser.parse_args()
if args.verbose: if args.verbose:
logging.getLogger().setLevel(logging.DEBUG) logging.getLogger().setLevel(logging.DEBUG)
relay = WebSocketRelay() relay = WebSocketRelay()
print(f"Starting WebSocket Relay on {args.host}:{args.port}") print(f"Starting WebSocket Relay on {args.host}:{args.port}")
print(f" 'in' endpoint: ws://{args.host}:{args.port}/in") print(f" 'in' endpoint: ws://{args.host}:{args.port}/in/api/v1/socket")
print(f" 'out' endpoint: ws://{args.host}:{args.port}/out") print(f" 'out' endpoint: ws://{args.host}:{args.port}/out")
print(f" Status: http://{args.host}:{args.port}/status") print(f" Status: http://{args.host}:{args.port}/status")
print() print()
print("Usage:") print("Client protocol (same as api-gateway):")
print(f" Test client connects to: ws://{args.host}:{args.port}/in") print(' 1. Connect to /in/api/v1/socket')
print(f" Reverse gateway connects to: ws://{args.host}:{args.port}/out") print(' 2. Send: {"type": "auth", "token": "tg_..."}')
print(' 3. Receive: {"type": "auth-ok", "workspace": "relayed"}')
print(' 4. Send requests as normal')
web.run_app(create_app(relay), host=args.host, port=args.port) web.run_app(create_app(relay), host=args.host, port=args.port)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View file

@ -25,16 +25,17 @@ class TestSemaphoreEnforcement:
max_concurrent = 0 max_concurrent = 0
processing_event = asyncio.Event() processing_event = asyncio.Event()
async def slow_process(message): async def slow_process(message, sender):
nonlocal concurrent_count, max_concurrent nonlocal concurrent_count, max_concurrent
concurrent_count += 1 concurrent_count += 1
max_concurrent = max(max_concurrent, concurrent_count) max_concurrent = max(max_concurrent, concurrent_count)
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
concurrent_count -= 1 concurrent_count -= 1
return {"id": message.get("id"), "response": {"ok": True}}
dispatcher._process_message = slow_process dispatcher._process_message = slow_process
sender = AsyncMock()
# Launch more tasks than max_workers # Launch more tasks than max_workers
messages = [ messages = [
{"id": f"msg-{i}", "service": "test", "request": {}} {"id": f"msg-{i}", "service": "test", "request": {}}
@ -42,7 +43,7 @@ class TestSemaphoreEnforcement:
] ]
tasks = [ tasks = [
asyncio.create_task(dispatcher.handle_message(m)) asyncio.create_task(dispatcher.handle_message(m, sender))
for m in messages for m in messages
] ]
@ -66,17 +67,17 @@ class TestSemaphoreEnforcement:
original_process = dispatcher._process_message original_process = dispatcher._process_message
async def tracking_process(message): async def tracking_process(message, sender):
nonlocal task_was_tracked nonlocal task_was_tracked
# During processing, our task should be in active_tasks # During processing, our task should be in active_tasks
if len(dispatcher.active_tasks) > 0: if len(dispatcher.active_tasks) > 0:
task_was_tracked = True task_was_tracked = True
return {"id": message.get("id"), "response": {"ok": True}}
dispatcher._process_message = tracking_process dispatcher._process_message = tracking_process
await dispatcher.handle_message( await dispatcher.handle_message(
{"id": "test", "service": "test", "request": {}} {"id": "test", "service": "test", "request": {}},
AsyncMock(),
) )
assert task_was_tracked assert task_was_tracked
@ -88,7 +89,7 @@ class TestSemaphoreEnforcement:
"""Semaphore should be released even if processing raises.""" """Semaphore should be released even if processing raises."""
dispatcher = MessageDispatcher(max_workers=2) dispatcher = MessageDispatcher(max_workers=2)
async def failing_process(message): async def failing_process(message, sender):
raise RuntimeError("process failed") raise RuntimeError("process failed")
dispatcher._process_message = failing_process dispatcher._process_message = failing_process
@ -96,7 +97,8 @@ class TestSemaphoreEnforcement:
# Should not deadlock — semaphore must be released on error # Should not deadlock — semaphore must be released on error
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await dispatcher.handle_message( await dispatcher.handle_message(
{"id": "test", "service": "test", "request": {}} {"id": "test", "service": "test", "request": {}},
AsyncMock(),
) )
# Semaphore should be back at max # Semaphore should be back at max
@ -109,17 +111,18 @@ class TestSemaphoreEnforcement:
order = [] order = []
async def ordered_process(message): async def ordered_process(message, sender):
msg_id = message["id"] msg_id = message["id"]
order.append(f"start-{msg_id}") order.append(f"start-{msg_id}")
await asyncio.sleep(0.02) await asyncio.sleep(0.02)
order.append(f"end-{msg_id}") order.append(f"end-{msg_id}")
return {"id": msg_id, "response": {"ok": True}}
dispatcher._process_message = ordered_process dispatcher._process_message = ordered_process
sender = AsyncMock()
messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)] messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)]
tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages] tasks = [asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
# With semaphore=1, each message should complete before next starts # With semaphore=1, each message should complete before next starts

View file

@ -0,0 +1,56 @@
"""
Tests for ontology monitoring metrics.
"""
from trustgraph.query.ontology.monitoring import (
PerformanceMonitor,
_extract_metric_label,
)
def test_extract_metric_label_reads_unquoted_label_value():
metric_name = "cache_requests_total{cache_type=entity,component=ontology}"
assert _extract_metric_label(metric_name, "cache_type") == "entity"
def test_extract_metric_label_reads_quoted_label_value():
metric_name = 'cache_requests_total{cache_type="entity",component="ontology"}'
assert _extract_metric_label(metric_name, "cache_type") == "entity"
def test_extract_metric_label_returns_none_when_label_missing():
metric_name = "cache_requests_total{component=ontology}"
assert _extract_metric_label(metric_name, "cache_type") is None
def test_performance_report_ignores_counters_without_cache_type_label():
monitor = PerformanceMonitor({"enabled": False})
monitor.metrics_collector.increment(
"cache_requests_total",
labels={"component": "ontology"},
)
monitor.metrics_collector.increment(
"cache_type=not_a_label",
labels={"component": "ontology"},
)
monitor.metrics_collector.increment(
"cache_requests_total",
labels={"cache_type": "entity"},
)
monitor.metrics_collector.increment(
"cache_hits_total",
labels={"cache_type": "entity"},
)
report = monitor.get_performance_report()
assert report["cache_performance"] == {
"entity": {
"hit_rate": 1.0,
"total_requests": 1.0,
"total_hits": 1.0,
}
}

View file

@ -14,7 +14,7 @@ from rdflib.plugins.sparql.parserutils import CompValue
from trustgraph.schema import Term, IRI, LITERAL from trustgraph.schema import Term, IRI, LITERAL
from trustgraph.query.sparql.algebra import ( from trustgraph.query.sparql.algebra import (
evaluate, _query_pattern, _eval_bgp, evaluate, materialise, _query_pattern, _eval_bgp,
) )
@ -28,6 +28,32 @@ def lit(v):
return Term(type=LITERAL, value=v) return Term(type=LITERAL, value=v)
def make_tc(query_return=None, query_side_effect=None):
"""Create a mock TriplesClient with both query() and query_gen() support."""
tc = AsyncMock()
if query_side_effect is not None:
tc.query.side_effect = query_side_effect
async def gen_side_effect(**kwargs):
results = await query_side_effect(**kwargs)
for r in results:
yield r
tc.query_gen = gen_side_effect
else:
items = query_return or []
tc.query.return_value = items
async def gen(**kwargs):
for item in items:
yield item
tc.query_gen = gen
return tc
def make_triple(s, p, o): def make_triple(s, p, o):
t = MagicMock() t = MagicMock()
t.s = s t.s = s
@ -84,6 +110,20 @@ def make_distinct(inner):
return node return node
def make_filter(inner, expr):
node = CompValue("Filter")
node.p = inner
node.expr = expr
return node
def make_minus(left, right):
node = CompValue("Minus")
node.p1 = left
node.p2 = right
return node
class TestQueryPattern: class TestQueryPattern:
"""Tests for _query_pattern — the leaf that calls TriplesClient.""" """Tests for _query_pattern — the leaf that calls TriplesClient."""
@ -136,15 +176,14 @@ class TestEvalBgp:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_single_pattern_all_variables(self): async def test_single_pattern_all_variables(self):
tc = AsyncMock()
triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
tc.query.return_value = [triple] tc = make_tc(query_return=[triple])
bgp = make_bgp( bgp = make_bgp(
(Variable("s"), Variable("p"), Variable("o")), (Variable("s"), Variable("p"), Variable("o")),
) )
solutions = await evaluate(bgp, tc, collection="default", limit=100) solutions = await materialise(bgp, tc, collection="default", limit=100)
assert len(solutions) == 1 assert len(solutions) == 1
assert solutions[0]["s"].iri == "http://s" assert solutions[0]["s"].iri == "http://s"
@ -153,43 +192,37 @@ class TestEvalBgp:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_single_pattern_bound_subject(self): async def test_single_pattern_bound_subject(self):
tc = AsyncMock() tc = make_tc(query_return=[
tc.query.return_value = [
make_triple(iri("http://s"), iri("http://p"), lit("val")), make_triple(iri("http://s"), iri("http://p"), lit("val")),
] ])
bgp = make_bgp( bgp = make_bgp(
(URIRef("http://s"), Variable("p"), Variable("o")), (URIRef("http://s"), Variable("p"), Variable("o")),
) )
solutions = await evaluate(bgp, tc, collection="default") solutions = await materialise(bgp, tc, collection="default")
tc.query.assert_called_once() assert len(solutions) == 1
kwargs = tc.query.call_args.kwargs
assert "workspace" not in kwargs
assert kwargs["collection"] == "default"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_empty_bgp_returns_empty_solution(self): async def test_empty_bgp_returns_empty_solution(self):
tc = AsyncMock() tc = make_tc()
bgp = make_bgp() bgp = make_bgp()
solutions = await evaluate(bgp, tc, collection="default") solutions = await materialise(bgp, tc, collection="default")
assert solutions == [{}] assert solutions == [{}]
tc.query.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_results_returns_empty(self): async def test_no_results_returns_empty(self):
tc = AsyncMock() tc = make_tc(query_return=[])
tc.query.return_value = []
bgp = make_bgp( bgp = make_bgp(
(Variable("s"), Variable("p"), Variable("o")), (Variable("s"), Variable("p"), Variable("o")),
) )
solutions = await evaluate(bgp, tc, collection="default") solutions = await materialise(bgp, tc, collection="default")
assert solutions == [] assert solutions == []
@ -199,17 +232,16 @@ class TestEvaluate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_select_query_node(self): async def test_select_query_node(self):
tc = AsyncMock() tc = make_tc(query_return=[
tc.query.return_value = [
make_triple(iri("http://s"), iri("http://p"), lit("o")), make_triple(iri("http://s"), iri("http://p"), lit("o")),
] ])
bgp = make_bgp( bgp = make_bgp(
(Variable("s"), Variable("p"), Variable("o")), (Variable("s"), Variable("p"), Variable("o")),
) )
select = make_select(make_project(bgp, ["s", "p"])) select = make_select(make_project(bgp, ["s", "p"]))
solutions = await evaluate(select, tc, collection="default") solutions = await materialise(select, tc, collection="default")
assert len(solutions) == 1 assert len(solutions) == 1
assert "s" in solutions[0] assert "s" in solutions[0]
@ -220,10 +252,9 @@ class TestEvaluate:
async def test_workspace_never_in_query_calls(self): async def test_workspace_never_in_query_calls(self):
"""Verify that no matter the algebra structure, workspace is never """Verify that no matter the algebra structure, workspace is never
passed to TriplesClient.query().""" passed to TriplesClient.query()."""
tc = AsyncMock() tc = make_tc(query_return=[
tc.query.return_value = [
make_triple(iri("http://s"), iri("http://p"), lit("o")), make_triple(iri("http://s"), iri("http://p"), lit("o")),
] ])
bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o"))) bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o")))
bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c"))) bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c")))
@ -231,72 +262,319 @@ class TestEvaluate:
make_union(bgp1, bgp2), ["s", "p", "o"] make_union(bgp1, bgp2), ["s", "p", "o"]
)) ))
await evaluate(tree, tc, collection="test-coll") await materialise(tree, tc, collection="test-coll")
for c in tc.query.call_args_list:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_join(self): async def test_join(self):
tc = AsyncMock() call_count = 0
tc.query.side_effect = [
[make_triple(iri("http://a"), iri("http://p"), lit("v"))], async def mock_query(**kwargs):
[make_triple(iri("http://a"), iri("http://q"), lit("w"))], nonlocal call_count
] call_count += 1
if call_count == 1:
return [make_triple(iri("http://a"), iri("http://p"), lit("v"))]
else:
return [make_triple(iri("http://a"), iri("http://q"), lit("w"))]
tc = make_tc(query_side_effect=mock_query)
bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1"))) bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1")))
bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2"))) bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2")))
tree = make_join(bgp1, bgp2) tree = make_join(bgp1, bgp2)
solutions = await evaluate(tree, tc, collection="default") solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 1 assert len(solutions) == 1
assert solutions[0]["s"].iri == "http://a" assert solutions[0]["s"].iri == "http://a"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_slice(self): async def test_slice(self):
tc = AsyncMock()
triples = [ triples = [
make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}")) make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}"))
for i in range(5) for i in range(5)
] ]
tc.query.return_value = triples tc = make_tc(query_return=triples)
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
tree = make_slice(bgp, start=1, length=2) tree = make_slice(bgp, start=1, length=2)
solutions = await evaluate(tree, tc, collection="default") solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 2 assert len(solutions) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_distinct(self): async def test_distinct(self):
tc = AsyncMock()
triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
tc.query.return_value = [triple, triple] tc = make_tc(query_return=[triple, triple])
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
tree = make_distinct(bgp) tree = make_distinct(bgp)
solutions = await evaluate(tree, tc, collection="default") solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 1 assert len(solutions) == 1
@pytest.mark.asyncio
async def test_minus_removes_matching(self):
alice = iri("http://example.com/alice")
bob = iri("http://example.com/bob")
knows = iri("http://example.com/knows")
hates = iri("http://example.com/hates")
charlie = iri("http://example.com/charlie")
left_triple = make_triple(alice, knows, bob)
right_triple2 = make_triple(alice, hates, charlie)
async def mock_query(**kwargs):
pred = kwargs.get("p")
if pred and pred.iri == "http://example.com/knows":
return [left_triple]
elif pred and pred.iri == "http://example.com/hates":
return [right_triple2]
return []
tc = make_tc(query_side_effect=mock_query)
left_bgp = make_bgp(
(Variable("s"), URIRef("http://example.com/knows"), Variable("o"))
)
right_bgp = make_bgp(
(Variable("s"), URIRef("http://example.com/hates"), Variable("r"))
)
tree = make_select(
make_project(
make_minus(left_bgp, right_bgp),
["s", "o"]
)
)
solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 0
@pytest.mark.asyncio
async def test_minus_no_shared_vars_preserves_all(self):
alice = iri("http://example.com/alice")
bob = iri("http://example.com/bob")
left_triple = make_triple(alice, iri("http://example.com/p"), bob)
async def mock_query(**kwargs):
pred = kwargs.get("p")
if pred and pred.iri == "http://example.com/p":
return [left_triple]
return []
tc = make_tc(query_side_effect=mock_query)
left_bgp = make_bgp(
(Variable("s"), URIRef("http://example.com/p"), Variable("o"))
)
right_bgp = make_bgp(
(Variable("x"), URIRef("http://example.com/q"), Variable("y"))
)
tree = make_select(
make_project(
make_minus(left_bgp, right_bgp),
["s", "o"]
)
)
solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 1
@pytest.mark.asyncio
async def test_filter_exists_keeps_matching(self):
alice = iri("http://example.com/alice")
bob = iri("http://example.com/bob")
charlie = iri("http://example.com/charlie")
left_triple1 = make_triple(alice, iri("http://example.com/knows"), bob)
left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie)
exists_triple = make_triple(bob, iri("http://example.com/likes"), alice)
async def mock_query(**kwargs):
pred = kwargs.get("p")
if pred and pred.iri == "http://example.com/knows":
return [left_triple1, left_triple2]
elif pred and pred.iri == "http://example.com/likes":
return [exists_triple]
return []
tc = make_tc(query_side_effect=mock_query)
left_bgp = make_bgp(
(Variable("s"), URIRef("http://example.com/knows"), Variable("o"))
)
exists_bgp = make_bgp(
(Variable("o"), URIRef("http://example.com/likes"), Variable("_any"))
)
exists_expr = CompValue("Builtin_EXISTS")
exists_expr.graph = exists_bgp
tree = make_select(
make_project(
make_filter(left_bgp, exists_expr),
["s", "o"]
)
)
solutions = await materialise(tree, tc, collection="default")
result_objects = [s["o"].iri for s in solutions]
assert "http://example.com/bob" in result_objects
assert "http://example.com/charlie" not in result_objects
@pytest.mark.asyncio
async def test_filter_not_exists_removes_matching(self):
alice = iri("http://example.com/alice")
bob = iri("http://example.com/bob")
charlie = iri("http://example.com/charlie")
left_triple1 = make_triple(alice, iri("http://example.com/knows"), bob)
left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie)
exists_triple = make_triple(bob, iri("http://example.com/likes"), alice)
async def mock_query(**kwargs):
pred = kwargs.get("p")
if pred and pred.iri == "http://example.com/knows":
return [left_triple1, left_triple2]
elif pred and pred.iri == "http://example.com/likes":
return [exists_triple]
return []
tc = make_tc(query_side_effect=mock_query)
left_bgp = make_bgp(
(Variable("s"), URIRef("http://example.com/knows"), Variable("o"))
)
exists_bgp = make_bgp(
(Variable("o"), URIRef("http://example.com/likes"), Variable("_any"))
)
not_exists_expr = CompValue("Builtin_NOTEXISTS")
not_exists_expr.graph = exists_bgp
tree = make_select(
make_project(
make_filter(left_bgp, not_exists_expr),
["s", "o"]
)
)
solutions = await materialise(tree, tc, collection="default")
result_objects = [s["o"].iri for s in solutions]
assert "http://example.com/charlie" in result_objects
assert "http://example.com/bob" not in result_objects
@pytest.mark.asyncio
async def test_join_values_uses_bind_join(self):
"""When VALUES is joined with a BGP, the bind join should pass
the VALUES bindings into the BGP evaluation so the triple store
query is selective (not a wildcard)."""
alice = iri("http://example.com/alice")
bob = iri("http://example.com/bob")
knows = iri("http://example.com/knows")
queries_issued = []
async def mock_query(**kwargs):
queries_issued.append(kwargs)
s, p = kwargs.get("s"), kwargs.get("p")
if s and s.iri == "http://example.com/alice" and p and p.iri == "http://example.com/knows":
return [make_triple(alice, knows, bob)]
return []
tc = make_tc(query_side_effect=mock_query)
# VALUES ?s { <alice> }
values_node = CompValue("values")
values_node.var = [Variable("s")]
values_node.value = [[URIRef("http://example.com/alice")]]
values_node.res = None
to_multiset = CompValue("ToMultiSet")
to_multiset.p = values_node
bgp = make_bgp(
(Variable("s"), URIRef("http://example.com/knows"), Variable("o")),
)
tree = make_join(to_multiset, bgp)
solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 1
assert solutions[0]["s"].iri == "http://example.com/alice"
assert solutions[0]["o"].iri == "http://example.com/bob"
# The key assertion: the BGP query should have received
# s=alice (bound from VALUES), NOT s=None (wildcard)
assert len(queries_issued) == 1
assert queries_issued[0]["s"] is not None
assert queries_issued[0]["s"].iri == "http://example.com/alice"
@pytest.mark.asyncio
async def test_join_values_multiple_bindings(self):
"""Bind join with multiple VALUES bindings."""
alice = iri("http://example.com/alice")
bob = iri("http://example.com/bob")
knows = iri("http://example.com/knows")
charlie = iri("http://example.com/charlie")
async def mock_query(**kwargs):
s = kwargs.get("s")
if s and s.iri == "http://example.com/alice":
return [make_triple(alice, knows, bob)]
elif s and s.iri == "http://example.com/bob":
return [make_triple(bob, knows, charlie)]
return []
tc = make_tc(query_side_effect=mock_query)
values_node = CompValue("values")
values_node.var = [Variable("s")]
values_node.value = [
[URIRef("http://example.com/alice")],
[URIRef("http://example.com/bob")],
]
values_node.res = None
to_multiset = CompValue("ToMultiSet")
to_multiset.p = values_node
bgp = make_bgp(
(Variable("s"), URIRef("http://example.com/knows"), Variable("o")),
)
tree = make_join(to_multiset, bgp)
solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 2
subjects = {s["s"].iri for s in solutions}
assert subjects == {
"http://example.com/alice",
"http://example.com/bob",
}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unsupported_node_returns_empty_solution(self): async def test_unsupported_node_returns_empty_solution(self):
tc = AsyncMock() tc = make_tc()
node = CompValue("SomethingUnknown") node = CompValue("SomethingUnknown")
solutions = await evaluate(node, tc, collection="default") solutions = await materialise(node, tc, collection="default")
assert solutions == [{}] assert solutions == [{}]
tc.query.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_non_compvalue_returns_empty_solution(self): async def test_non_compvalue_returns_empty_solution(self):
tc = AsyncMock() tc = make_tc()
solutions = await evaluate("not a node", tc, collection="default") solutions = await materialise("not a node", tc, collection="default")
assert solutions == [{}] assert solutions == [{}]

View file

@ -300,6 +300,438 @@ class TestBuiltinFunctions:
flags=None) flags=None)
assert evaluate_expression(expr, {"x": lit("hello")}) is False assert evaluate_expression(expr, {"x": lit("hello")}) is False
def test_substr_three_args(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("SUBSTR",
arg=Variable("x"),
start=Literal(1),
length=Literal(4))
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
assert result.type == LITERAL
assert result.value == "2024"
def test_substr_two_args(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("SUBSTR",
arg=Variable("x"),
start=Literal(6),
length=None)
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
assert result.type == LITERAL
assert result.value == "03-15"
def test_substr_middle(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("SUBSTR",
arg=Variable("x"),
start=Literal(6),
length=Literal(2))
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
assert result.type == LITERAL
assert result.value == "03"
def test_substr_null_start(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("SUBSTR",
arg=Variable("x"),
start=Variable("missing"),
length=None)
result = evaluate_expression(expr, {"x": lit("hello")})
assert result is None
def test_year(self):
from rdflib.term import Variable
expr = self._make_builtin("YEAR", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
)
assert result == 2024
def test_month(self):
from rdflib.term import Variable
expr = self._make_builtin("MONTH", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
)
assert result == 3
def test_day(self):
from rdflib.term import Variable
expr = self._make_builtin("DAY", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
)
assert result == 15
def test_hours(self):
from rdflib.term import Variable
expr = self._make_builtin("HOURS", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
)
assert result == 10
def test_minutes(self):
from rdflib.term import Variable
expr = self._make_builtin("MINUTES", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
)
assert result == 30
def test_seconds(self):
from rdflib.term import Variable
expr = self._make_builtin("SECONDS", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
)
assert result == 45
def test_year_from_datetime(self):
from rdflib.term import Variable
expr = self._make_builtin("YEAR", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")}
)
assert result == 2024
def test_hours_from_date_returns_zero(self):
from rdflib.term import Variable
expr = self._make_builtin("HOURS", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15", datatype=XSD + "date")}
)
assert result == 0
def test_year_invalid_date(self):
from rdflib.term import Variable
expr = self._make_builtin("YEAR", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("not-a-date")}
)
assert result is None
def test_floor(self):
from rdflib.term import Variable
expr = self._make_builtin("FLOOR", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("3.7")}) == 3
def test_floor_negative(self):
from rdflib.term import Variable
expr = self._make_builtin("FLOOR", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("-2.3")}) == -3
def test_floor_none(self):
from rdflib.term import Variable
expr = self._make_builtin("FLOOR", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("abc")}) is None
def test_ceil(self):
from rdflib.term import Variable
expr = self._make_builtin("CEIL", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("3.2")}) == 4
def test_ceil_negative(self):
from rdflib.term import Variable
expr = self._make_builtin("CEIL", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("-2.7")}) == -2
def test_abs_positive(self):
from rdflib.term import Variable
expr = self._make_builtin("ABS", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("42")}) == 42
def test_abs_negative(self):
from rdflib.term import Variable
expr = self._make_builtin("ABS", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("-42")}) == 42
def test_abs_none(self):
from rdflib.term import Variable
expr = self._make_builtin("ABS", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("abc")}) is None
def test_replace_simple(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("REPLACE",
arg=Variable("x"),
pattern=Literal(" BC"),
replacement=Literal(""),
flags=None)
result = evaluate_expression(expr, {"x": lit("500 BC")})
assert result.type == LITERAL
assert result.value == "500"
def test_replace_regex(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("REPLACE",
arg=Variable("x"),
pattern=Literal("[0-9]+"),
replacement=Literal("X"),
flags=None)
result = evaluate_expression(expr, {"x": lit("abc123def456")})
assert result.value == "abcXdefX"
def test_replace_case_insensitive(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("REPLACE",
arg=Variable("x"),
pattern=Literal("hello"),
replacement=Literal("world"),
flags=Literal("i"))
result = evaluate_expression(expr, {"x": lit("HELLO there")})
assert result.value == "world there"
def test_round_up(self):
from rdflib.term import Variable
expr = self._make_builtin("ROUND", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("3.7")}) == 4
def test_round_down(self):
from rdflib.term import Variable
expr = self._make_builtin("ROUND", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("3.2")}) == 3
def test_round_none(self):
from rdflib.term import Variable
expr = self._make_builtin("ROUND", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("abc")}) is None
def test_strbefore(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("STRBEFORE",
arg1=Variable("x"), arg2=Literal("-"))
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
assert result.value == "2024"
def test_strbefore_not_found(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("STRBEFORE",
arg1=Variable("x"), arg2=Literal("/"))
result = evaluate_expression(expr, {"x": lit("hello")})
assert result.value == ""
def test_strafter(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("STRAFTER",
arg1=Variable("x"), arg2=Literal("-"))
result = evaluate_expression(expr, {"x": lit("2024-03-15")})
assert result.value == "03-15"
def test_strafter_not_found(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("STRAFTER",
arg1=Variable("x"), arg2=Literal("/"))
result = evaluate_expression(expr, {"x": lit("hello")})
assert result.value == ""
def test_encode_for_uri(self):
from rdflib.term import Variable
expr = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("hello world")})
assert result.value == "hello%20world"
def test_encode_for_uri_special_chars(self):
from rdflib.term import Variable
expr = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("a/b?c=d&e")})
assert result.value == "a%2Fb%3Fc%3Dd%26e"
def test_langmatches_basic(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("LANGMATCHES",
arg1=Literal("en"), arg2=Literal("en"))
assert evaluate_expression(expr, {}) is True
def test_langmatches_subtag(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("LANGMATCHES",
arg1=Literal("en-US"), arg2=Literal("en"))
assert evaluate_expression(expr, {}) is True
def test_langmatches_wildcard(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("LANGMATCHES",
arg1=Literal("fr"), arg2=Literal("*"))
assert evaluate_expression(expr, {}) is True
def test_langmatches_wildcard_empty(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("LANGMATCHES",
arg1=Literal(""), arg2=Literal("*"))
assert evaluate_expression(expr, {}) is False
def test_langmatches_no_match(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("LANGMATCHES",
arg1=Literal("fr"), arg2=Literal("en"))
assert evaluate_expression(expr, {}) is False
def test_iri_constructor(self):
from rdflib.term import Variable
expr = self._make_builtin("IRI", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("http://example.com/test")}
)
assert result.type == IRI
assert result.iri == "http://example.com/test"
def test_uri_constructor(self):
from rdflib.term import Variable
expr = self._make_builtin("URI", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("http://example.com/test")}
)
assert result.type == IRI
assert result.iri == "http://example.com/test"
def test_bnode_no_arg(self):
expr = self._make_builtin("BNODE")
result = evaluate_expression(expr, {})
assert result.type == BLANK
assert len(result.id) > 0
def test_bnode_with_label(self):
from rdflib import Literal
expr = self._make_builtin("BNODE", arg=Literal("mynode"))
result = evaluate_expression(expr, {})
assert result.type == BLANK
assert result.id == "mynode"
def test_now(self):
import re as re_mod
expr = self._make_builtin("NOW")
result = evaluate_expression(expr, {})
assert result.type == LITERAL
assert result.datatype == XSD + "dateTime"
assert re_mod.match(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}", result.value)
def test_tz_with_utc(self):
from rdflib.term import Variable
expr = self._make_builtin("TZ", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15T10:30:45+0000",
datatype=XSD + "dateTime")}
)
assert result.type == LITERAL
assert result.value == "+00:00"
def test_tz_no_timezone(self):
from rdflib.term import Variable
expr = self._make_builtin("TZ", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("2024-03-15T10:30:45",
datatype=XSD + "dateTime")}
)
assert result.value == ""
def test_rand(self):
expr = self._make_builtin("RAND")
result = evaluate_expression(expr, {})
assert isinstance(result, float)
assert 0.0 <= result < 1.0
def test_uuid(self):
import re as re_mod
expr = self._make_builtin("UUID")
result = evaluate_expression(expr, {})
assert result.type == IRI
assert result.iri.startswith("urn:uuid:")
uuid_part = result.iri[len("urn:uuid:"):]
assert re_mod.match(
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
uuid_part
)
def test_struuid(self):
import re as re_mod
expr = self._make_builtin("STRUUID")
result = evaluate_expression(expr, {})
assert result.type == LITERAL
assert re_mod.match(
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
result.value
)
def test_md5(self):
from rdflib.term import Variable
expr = self._make_builtin("MD5", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("hello")})
assert result.type == LITERAL
assert result.value == "5d41402abc4b2a76b9719d911017c592"
def test_sha1(self):
from rdflib.term import Variable
expr = self._make_builtin("SHA1", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("hello")})
assert result.type == LITERAL
assert result.value == "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d"
def test_sha256(self):
from rdflib.term import Variable
expr = self._make_builtin("SHA256", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("hello")})
assert result.type == LITERAL
assert result.value == (
"2cf24dba5fb0a30e26e83b2ac5b9e29e"
"1b161e5c1fa7425e73043362938b9824"
)
def test_sha512(self):
from rdflib.term import Variable
expr = self._make_builtin("SHA512", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("hello")})
assert result.type == LITERAL
assert len(result.value) == 128
def test_exists_with_callback(self):
from rdflib.plugins.sparql.parserutils import CompValue
graph = CompValue("BGP")
expr = self._make_builtin("EXISTS", graph=graph)
cb = lambda g, s: True
result = evaluate_expression(expr, {}, exists_cb=cb)
assert result is True
def test_exists_callback_false(self):
from rdflib.plugins.sparql.parserutils import CompValue
graph = CompValue("BGP")
expr = self._make_builtin("EXISTS", graph=graph)
cb = lambda g, s: False
result = evaluate_expression(expr, {}, exists_cb=cb)
assert result is False
def test_notexists_with_callback(self):
from rdflib.plugins.sparql.parserutils import CompValue
graph = CompValue("BGP")
expr = self._make_builtin("NOTEXISTS", graph=graph)
cb = lambda g, s: True
result = evaluate_expression(expr, {}, exists_cb=cb)
assert result is False
def test_notexists_callback_false(self):
from rdflib.plugins.sparql.parserutils import CompValue
graph = CompValue("BGP")
expr = self._make_builtin("NOTEXISTS", graph=graph)
cb = lambda g, s: False
result = evaluate_expression(expr, {}, exists_cb=cb)
assert result is True
class TestEffectiveBoolean: class TestEffectiveBoolean:

View file

@ -5,7 +5,7 @@ Tests for SPARQL solution sequence operations.
import pytest import pytest
from trustgraph.schema import Term, IRI, LITERAL from trustgraph.schema import Term, IRI, LITERAL
from trustgraph.query.sparql.solutions import ( from trustgraph.query.sparql.solutions import (
hash_join, left_join, union, project, distinct, hash_join, left_join, minus, union, project, distinct,
order_by, slice_solutions, _terms_equal, _compatible, order_by, slice_solutions, _terms_equal, _compatible,
) )
@ -311,6 +311,30 @@ class TestOrderBy:
result = order_by(solutions, []) result = order_by(solutions, [])
assert len(result) == 1 assert len(result) == 1
def test_order_by_numeric_literals(self):
solutions = [
{"year": lit("1950")},
{"year": lit("700")},
{"year": lit("2000")},
{"year": lit("450")},
{"year": lit("1200")},
]
key_fns = [(lambda sol: sol.get("year"), True)]
result = order_by(solutions, key_fns)
values = [s["year"].value for s in result]
assert values == ["450", "700", "1200", "1950", "2000"]
def test_order_by_numeric_descending(self):
solutions = [
{"year": lit("1950")},
{"year": lit("700")},
{"year": lit("2000")},
]
key_fns = [(lambda sol: sol.get("year"), False)]
result = order_by(solutions, key_fns)
values = [s["year"].value for s in result]
assert values == ["2000", "1950", "700"]
class TestSlice: class TestSlice:
@ -343,3 +367,37 @@ class TestSlice:
solutions = [{"s": alice}, {"s": bob}] solutions = [{"s": alice}, {"s": bob}]
result = slice_solutions(solutions) result = slice_solutions(solutions)
assert len(result) == 2 assert len(result) == 2
class TestMinus:
def test_removes_compatible(self, alice, bob):
left = [{"s": alice}, {"s": bob}]
right = [{"s": alice}]
result = minus(left, right)
assert len(result) == 1
assert result[0]["s"].iri == "http://example.com/bob"
def test_empty_right_preserves_all(self, alice, bob):
left = [{"s": alice}, {"s": bob}]
result = minus(left, [])
assert len(result) == 2
def test_no_shared_variables_preserves_all(self, alice, bob):
left = [{"s": alice}]
right = [{"t": bob}]
result = minus(left, right)
assert len(result) == 1
def test_all_removed(self, alice):
left = [{"s": alice}]
right = [{"s": alice}]
result = minus(left, right)
assert len(result) == 0
def test_partial_shared_variables(self, alice, bob):
left = [{"s": alice, "p": lit("x")}, {"s": bob, "p": lit("y")}]
right = [{"s": alice}]
result = minus(left, right)
assert len(result) == 1
assert result[0]["s"].iri == "http://example.com/bob"

View file

View file

@ -3,275 +3,279 @@ Tests for Reverse Gateway Dispatcher
""" """
import pytest import pytest
from unittest.mock import MagicMock, AsyncMock, patch import asyncio
from unittest.mock import MagicMock, AsyncMock, patch, ANY
from trustgraph.rev_gateway.dispatcher import WebSocketResponder, MessageDispatcher from trustgraph.rev_gateway.dispatcher import MessageDispatcher
class TestWebSocketResponder:
"""Test cases for WebSocketResponder class"""
def test_websocket_responder_initialization(self):
"""Test WebSocketResponder initialization"""
responder = WebSocketResponder()
assert responder.response is None
assert responder.completed is False
@pytest.mark.asyncio
async def test_websocket_responder_send_method(self):
"""Test WebSocketResponder send method"""
responder = WebSocketResponder()
test_response = {"data": "test response"}
# Call send method
await responder.send(test_response)
# Verify response was stored
assert responder.response == test_response
@pytest.mark.asyncio
async def test_websocket_responder_call_method(self):
"""Test WebSocketResponder __call__ method"""
responder = WebSocketResponder()
test_response = {"result": "success"}
test_completed = True
# Call the responder
await responder(test_response, test_completed)
# Verify response and completed status were set
assert responder.response == test_response
assert responder.completed == test_completed
@pytest.mark.asyncio
async def test_websocket_responder_call_method_with_false_completion(self):
"""Test WebSocketResponder __call__ method with incomplete response"""
responder = WebSocketResponder()
test_response = {"partial": "data"}
test_completed = False
# Call the responder
await responder(test_response, test_completed)
# Verify response was set and completed is True (since send() always sets completed=True)
assert responder.response == test_response
assert responder.completed is True
class TestMessageDispatcher: class TestMessageDispatcher:
"""Test cases for MessageDispatcher class""" """Test cases for MessageDispatcher class"""
def test_message_dispatcher_initialization_with_defaults(self): def test_message_dispatcher_initialization_with_defaults(self):
"""Test MessageDispatcher initialization with default parameters"""
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
assert dispatcher.max_workers == 10 assert dispatcher.max_workers == 10
assert dispatcher.semaphore._value == 10 assert dispatcher.semaphore._value == 10
assert dispatcher.active_tasks == set() assert dispatcher.active_tasks == set()
assert dispatcher.backend is None assert dispatcher.backend is None
assert dispatcher.auth is None
assert dispatcher.dispatcher_manager is None assert dispatcher.dispatcher_manager is None
assert len(dispatcher.service_mapping) > 0 assert len(dispatcher.service_mapping) > 0
def test_message_dispatcher_initialization_with_custom_workers(self): def test_message_dispatcher_initialization_with_custom_workers(self):
"""Test MessageDispatcher initialization with custom max_workers"""
dispatcher = MessageDispatcher(max_workers=5) dispatcher = MessageDispatcher(max_workers=5)
assert dispatcher.max_workers == 5 assert dispatcher.max_workers == 5
assert dispatcher.semaphore._value == 5 assert dispatcher.semaphore._value == 5
@patch('trustgraph.rev_gateway.dispatcher.DispatcherManager') @patch('trustgraph.rev_gateway.dispatcher.DispatcherManager')
def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager): def test_message_dispatcher_initialization_with_backend(
"""Test MessageDispatcher initialization with pulsar_client and config_receiver""" self, mock_dispatcher_manager,
):
mock_backend = MagicMock() mock_backend = MagicMock()
mock_config_receiver = MagicMock() mock_config_receiver = MagicMock()
mock_auth = MagicMock()
mock_dispatcher_instance = MagicMock() mock_dispatcher_instance = MagicMock()
mock_dispatcher_manager.return_value = mock_dispatcher_instance mock_dispatcher_manager.return_value = mock_dispatcher_instance
dispatcher = MessageDispatcher( dispatcher = MessageDispatcher(
max_workers=8, max_workers=8,
config_receiver=mock_config_receiver, config_receiver=mock_config_receiver,
backend=mock_backend backend=mock_backend,
auth=mock_auth,
timeout=300,
) )
assert dispatcher.max_workers == 8 assert dispatcher.max_workers == 8
assert dispatcher.backend == mock_backend assert dispatcher.backend == mock_backend
assert dispatcher.auth == mock_auth
assert dispatcher.dispatcher_manager == mock_dispatcher_instance assert dispatcher.dispatcher_manager == mock_dispatcher_instance
mock_dispatcher_manager.assert_called_once_with( mock_dispatcher_manager.assert_called_once_with(
mock_backend, mock_config_receiver, prefix="rev-gateway" mock_backend, mock_config_receiver,
auth=mock_auth, prefix="rev-gateway", timeout=300,
) )
def test_message_dispatcher_service_mapping(self): def test_message_dispatcher_service_mapping(self):
"""Test MessageDispatcher service mapping contains expected services"""
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
expected_services = [ expected_services = [
"text-completion", "graph-rag", "agent", "embeddings", "text-completion", "graph-rag", "agent", "embeddings",
"graph-embeddings", "triples", "document-load", "text-load", "graph-embeddings", "triples", "document-load", "text-load",
"flow", "knowledge", "config", "librarian", "document-rag" "flow", "knowledge", "config", "librarian", "document-rag",
] ]
for service in expected_services: for service in expected_services:
assert service in dispatcher.service_mapping assert service in dispatcher.service_mapping
# Test specific mappings
assert dispatcher.service_mapping["text-completion"] == "text-completion"
assert dispatcher.service_mapping["document-load"] == "document" assert dispatcher.service_mapping["document-load"] == "document"
assert dispatcher.service_mapping["text-load"] == "text-document" assert dispatcher.service_mapping["text-load"] == "text-document"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_dispatcher_handle_message_without_dispatcher_manager(self): async def test_handle_message_without_dispatcher_manager(self):
"""Test MessageDispatcher handle_message without dispatcher manager"""
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
dispatcher.auth = MagicMock()
test_message = { dispatcher.auth.authenticate = AsyncMock(
"id": "test-123", return_value=MagicMock(workspace="default")
"service": "test-service", )
"request": {"data": "test"}
} sender = AsyncMock()
result = await dispatcher.handle_message(test_message) await dispatcher.handle_message(
{"id": "test-1", "service": "test", "request": {}},
assert result["id"] == "test-123" sender,
assert "error" in result["response"] )
assert "DispatcherManager not available" in result["response"]["error"]
sender.assert_called_once()
sent = sender.call_args[0][0]
assert sent["id"] == "test-1"
assert sent["error"]["message"] == "DispatcherManager not available"
assert sent["error"]["type"] == "error"
assert sent["complete"] is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_dispatcher_handle_message_with_exception(self): async def test_handle_message_auth_failure(self):
"""Test MessageDispatcher handle_message with exception during processing"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_global_service = AsyncMock(side_effect=Exception("Test error"))
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
test_message = { side_effect=Exception("auth failure")
"id": "test-456", )
"service": "text-completion", dispatcher.dispatcher_manager = MagicMock()
"request": {"prompt": "test"}
} sender = AsyncMock()
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}): await dispatcher.handle_message(
result = await dispatcher.handle_message(test_message) {"id": "test-2", "token": "bad", "service": "test", "request": {}},
sender,
assert result["id"] == "test-456" )
assert "error" in result["response"]
assert "Test error" in result["response"]["error"] sender.assert_called_once()
sent = sender.call_args[0][0]
assert sent["id"] == "test-2"
assert "auth failure" in sent["error"]["message"]
assert sent["complete"] is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_dispatcher_handle_message_global_service(self): async def test_handle_message_global_service(self):
"""Test MessageDispatcher handle_message with global service""" mock_dm = MagicMock()
mock_dispatcher_manager = MagicMock() mock_dm.invoke_global_service = AsyncMock()
mock_dispatcher_manager.invoke_global_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = True
mock_responder.response = {"result": "success"}
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
test_message = { dispatcher.auth.authenticate = AsyncMock(
"id": "test-789", return_value=MagicMock(workspace="ws1")
"service": "text-completion", )
"request": {"prompt": "hello"}
} sender = AsyncMock()
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}): with patch(
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): 'trustgraph.gateway.dispatch.manager.global_dispatchers',
result = await dispatcher.handle_message(test_message) {"text-completion": True},
):
assert result["id"] == "test-789" await dispatcher.handle_message(
assert result["response"] == {"result": "success"} {
mock_dispatcher_manager.invoke_global_service.assert_called_once() "id": "test-3",
"token": "tg_key",
"service": "text-completion",
"request": {"prompt": "hello"},
},
sender,
)
mock_dm.invoke_global_service.assert_called_once()
args, kwargs = mock_dm.invoke_global_service.call_args
assert args[0] == {"prompt": "hello"}
assert args[2] == "text-completion"
assert kwargs["workspace"] == "ws1"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_dispatcher_handle_message_flow_service(self): async def test_handle_message_flow_service(self):
"""Test MessageDispatcher handle_message with flow service""" mock_dm = MagicMock()
mock_dispatcher_manager = MagicMock() mock_dm.invoke_flow_service = AsyncMock()
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = True
mock_responder.response = {"data": "flow_result"}
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
test_message = { dispatcher.auth.authenticate = AsyncMock(
"id": "test-flow-123", return_value=MagicMock(workspace="ws2")
"service": "document-rag", )
"request": {"query": "test"},
"flow": "custom-flow" sender = AsyncMock()
}
with patch(
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}): 'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): ):
result = await dispatcher.handle_message(test_message) await dispatcher.handle_message(
{
assert result["id"] == "test-flow-123" "id": "test-4",
assert result["response"] == {"data": "flow_result"} "token": "tg_key",
mock_dispatcher_manager.invoke_flow_service.assert_called_once_with( "service": "document-rag",
{"query": "test"}, mock_responder, "custom-flow", "document-rag" "request": {"query": "test"},
"flow": "my-flow",
},
sender,
)
mock_dm.invoke_flow_service.assert_called_once_with(
{"query": "test"}, ANY, "ws2", "my-flow", "document-rag",
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_dispatcher_handle_message_incomplete_response(self): async def test_handle_message_responder_sends_frames(self):
"""Test MessageDispatcher handle_message with incomplete response""" mock_dm = MagicMock()
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_flow_service = AsyncMock() async def fake_invoke(data, responder, svc, workspace=None):
mock_responder = MagicMock() await responder({"partial": 1}, False)
mock_responder.completed = False await responder({"partial": 2}, True)
mock_responder.response = None
mock_dm.invoke_global_service = AsyncMock(side_effect=fake_invoke)
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
test_message = { dispatcher.auth.authenticate = AsyncMock(
"id": "test-incomplete", return_value=MagicMock(workspace="ws1")
"service": "agent", )
"request": {"input": "test"}
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers',
{"text-completion": True},
):
await dispatcher.handle_message(
{
"id": "test-5",
"token": "tg_key",
"service": "text-completion",
"request": {"prompt": "hi"},
},
sender,
)
assert sender.call_count == 2
first = sender.call_args_list[0][0][0]
second = sender.call_args_list[1][0][0]
assert first == {
"id": "test-5", "response": {"partial": 1}, "complete": False,
}
assert second == {
"id": "test-5", "response": {"partial": 2}, "complete": True,
} }
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-incomplete"
assert result["response"] == {"error": "No response received"}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_dispatcher_shutdown(self): async def test_handle_message_workspace_from_identity(self):
"""Test MessageDispatcher shutdown method""" mock_dm = MagicMock()
import asyncio mock_dm.invoke_flow_service = AsyncMock()
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dm
# Create actual async tasks dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="derived-ws")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
):
await dispatcher.handle_message(
{
"id": "test-6",
"token": "tg_key",
"service": "agent",
"request": {"question": "test"},
"flow": "default",
},
sender,
)
args = mock_dm.invoke_flow_service.call_args[0]
assert args[2] == "derived-ws"
@pytest.mark.asyncio
async def test_shutdown(self):
dispatcher = MessageDispatcher()
async def dummy_task(): async def dummy_task():
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
return "done"
task1 = asyncio.create_task(dummy_task()) task1 = asyncio.create_task(dummy_task())
task2 = asyncio.create_task(dummy_task()) task2 = asyncio.create_task(dummy_task())
dispatcher.active_tasks = {task1, task2} dispatcher.active_tasks = {task1, task2}
# Call shutdown
await dispatcher.shutdown() await dispatcher.shutdown()
# Verify tasks were completed
assert task1.done() assert task1.done()
assert task2.done() assert task2.done()
assert len(dispatcher.active_tasks) == 2 # Tasks remain in set but are completed
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_dispatcher_shutdown_with_no_tasks(self): async def test_shutdown_with_no_tasks(self):
"""Test MessageDispatcher shutdown with no active tasks"""
dispatcher = MessageDispatcher() dispatcher = MessageDispatcher()
# Call shutdown with no active tasks
await dispatcher.shutdown() await dispatcher.shutdown()
# Should complete without error assert dispatcher.active_tasks == set()
assert dispatcher.active_tasks == set()

View file

@ -8,22 +8,38 @@ from unittest.mock import MagicMock, AsyncMock, patch, Mock
from aiohttp import WSMsgType, ClientWebSocketResponse from aiohttp import WSMsgType, ClientWebSocketResponse
import json import json
from trustgraph.rev_gateway.service import ReverseGateway, parse_args, run from trustgraph.rev_gateway.service import ReverseGateway, run
MOCK_PATCHES = [
'trustgraph.rev_gateway.service.IamAuth',
'trustgraph.rev_gateway.service.ConfigReceiver',
'trustgraph.rev_gateway.service.MessageDispatcher',
'trustgraph.rev_gateway.service.get_pubsub',
]
def make_gateway(**overrides):
config = {"websocket_uri": "ws://localhost:7650/out"}
config.update(overrides)
return ReverseGateway(**config)
class TestReverseGateway: class TestReverseGateway:
"""Test cases for ReverseGateway class""" """Test cases for ReverseGateway class"""
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): @patch(*MOCK_PATCHES[3:4])
"""Test ReverseGateway initialization with default parameters""" def test_reverse_gateway_initialization_defaults(
mock_backend = MagicMock() self, mock_get_pubsub, mock_dispatcher,
mock_get_pubsub.return_value = mock_backend mock_config_receiver, mock_iam_auth,
):
gateway = ReverseGateway() mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
assert gateway.websocket_uri == "ws://localhost:7650/out" assert gateway.websocket_uri == "ws://localhost:7650/out"
assert gateway.host == "localhost" assert gateway.host == "localhost"
assert gateway.port == 7650 assert gateway.port == 7650
@ -33,25 +49,22 @@ class TestReverseGateway:
assert gateway.max_workers == 10 assert gateway.max_workers == 10
assert gateway.running is False assert gateway.running is False
assert gateway.reconnect_delay == 3.0 assert gateway.reconnect_delay == 3.0
assert gateway.pulsar_host == "pulsar://pulsar:6650"
assert gateway.pulsar_api_key is None
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): @patch(*MOCK_PATCHES[3:4])
"""Test ReverseGateway initialization with custom parameters""" def test_reverse_gateway_initialization_custom_params(
mock_backend = MagicMock() self, mock_get_pubsub, mock_dispatcher,
mock_get_pubsub.return_value = mock_backend mock_config_receiver, mock_iam_auth,
):
gateway = ReverseGateway( mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway(
websocket_uri="wss://example.com:8080/websocket", websocket_uri="wss://example.com:8080/websocket",
max_workers=20, max_workers=20,
pulsar_host="pulsar://custom:6650",
pulsar_api_key="test-key",
pulsar_listener="test-listener"
) )
assert gateway.websocket_uri == "wss://example.com:8080/websocket" assert gateway.websocket_uri == "wss://example.com:8080/websocket"
assert gateway.host == "example.com" assert gateway.host == "example.com"
assert gateway.port == 8080 assert gateway.port == 8080
@ -59,340 +72,360 @@ class TestReverseGateway:
assert gateway.path == "/websocket" assert gateway.path == "/websocket"
assert gateway.url == "wss://example.com:8080/websocket" assert gateway.url == "wss://example.com:8080/websocket"
assert gateway.max_workers == 20 assert gateway.max_workers == 20
assert gateway.pulsar_host == "pulsar://custom:6650"
assert gateway.pulsar_api_key == "test-key"
assert gateway.pulsar_listener == "test-listener"
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
def test_reverse_gateway_initialization_with_missing_path(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): @patch(*MOCK_PATCHES[3:4])
"""Test ReverseGateway initialization with WebSocket URI missing path""" def test_reverse_gateway_initialization_with_missing_path(
mock_backend = MagicMock() self, mock_get_pubsub, mock_dispatcher,
mock_get_pubsub.return_value = mock_backend mock_config_receiver, mock_iam_auth,
):
gateway = ReverseGateway(websocket_uri="ws://example.com") mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway(websocket_uri="ws://example.com")
assert gateway.path == "/ws" assert gateway.path == "/ws"
assert gateway.url == "ws://example.com/ws" assert gateway.url == "ws://example.com/ws"
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): @patch(*MOCK_PATCHES[3:4])
"""Test ReverseGateway initialization with invalid WebSocket scheme""" def test_reverse_gateway_initialization_invalid_scheme(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"): with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"):
ReverseGateway(websocket_uri="http://example.com") make_gateway(websocket_uri="http://example.com")
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): @patch(*MOCK_PATCHES[3:4])
"""Test ReverseGateway initialization with missing hostname""" def test_reverse_gateway_initialization_missing_hostname(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
with pytest.raises(ValueError, match="WebSocket URI must include hostname"): with pytest.raises(ValueError, match="WebSocket URI must include hostname"):
ReverseGateway(websocket_uri="ws://") make_gateway(websocket_uri="ws://")
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): @patch(*MOCK_PATCHES[3:4])
"""Test ReverseGateway creates backend with authentication""" def test_reverse_gateway_iam_auth_created(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_backend = MagicMock() mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway( gateway = make_gateway(id="test-rev-gw")
pulsar_api_key="test-key",
pulsar_listener="test-listener" mock_iam_auth.assert_called_once_with(
backend=mock_backend,
id="test-rev-gw",
) )
# Verify get_pubsub was called with the correct parameters @patch(*MOCK_PATCHES[0:1])
mock_get_pubsub.assert_called_once_with( @patch(*MOCK_PATCHES[1:2])
pulsar_host="pulsar://pulsar:6650", @patch(*MOCK_PATCHES[2:3])
pulsar_api_key="test-key", @patch(*MOCK_PATCHES[3:4])
pulsar_listener="test-listener" def test_reverse_gateway_config_receiver_gets_auth(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
mock_auth_instance = MagicMock()
mock_iam_auth.return_value = mock_auth_instance
gateway = make_gateway()
mock_config_receiver.assert_called_once_with(
mock_backend, auth=mock_auth_instance,
) )
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@patch('trustgraph.rev_gateway.service.ClientSession') @patch('trustgraph.rev_gateway.service.ClientSession')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_connect_success(
"""Test ReverseGateway successful connection""" self, mock_session_class, mock_get_pubsub,
mock_backend = MagicMock() mock_dispatcher, mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
mock_session = AsyncMock() mock_session = AsyncMock()
mock_ws = AsyncMock() mock_ws = AsyncMock()
mock_session.ws_connect.return_value = mock_ws mock_session.ws_connect.return_value = mock_ws
mock_session_class.return_value = mock_session mock_session_class.return_value = mock_session
gateway = ReverseGateway() gateway = make_gateway()
result = await gateway.connect() result = await gateway.connect()
assert result is True assert result is True
assert gateway.session == mock_session assert gateway.session == mock_session
assert gateway.ws == mock_ws assert gateway.ws == mock_ws
mock_session.ws_connect.assert_called_once_with(gateway.url) mock_session.ws_connect.assert_called_once_with(gateway.url)
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@patch('trustgraph.rev_gateway.service.ClientSession') @patch('trustgraph.rev_gateway.service.ClientSession')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_connect_failure(
"""Test ReverseGateway connection failure""" self, mock_session_class, mock_get_pubsub,
mock_backend = MagicMock() mock_dispatcher, mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
mock_session = AsyncMock() mock_session = AsyncMock()
mock_session.ws_connect.side_effect = Exception("Connection failed") mock_session.ws_connect.side_effect = Exception("Connection failed")
mock_session_class.return_value = mock_session mock_session_class.return_value = mock_session
gateway = ReverseGateway() gateway = make_gateway()
result = await gateway.connect() result = await gateway.connect()
assert result is False assert result is False
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_disconnect(
"""Test ReverseGateway disconnect""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
gateway = ReverseGateway()
gateway = make_gateway()
# Mock websocket and session
mock_ws = AsyncMock() mock_ws = AsyncMock()
mock_ws.closed = False mock_ws.closed = False
mock_session = AsyncMock() mock_session = AsyncMock()
mock_session.closed = False mock_session.closed = False
gateway.ws = mock_ws gateway.ws = mock_ws
gateway.session = mock_session gateway.session = mock_session
await gateway.disconnect() await gateway.disconnect()
mock_ws.close.assert_called_once() mock_ws.close.assert_called_once()
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
assert gateway.ws is None assert gateway.ws is None
assert gateway.session is None assert gateway.session is None
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_send_message(
"""Test ReverseGateway send message""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
gateway = ReverseGateway()
gateway = make_gateway()
# Mock websocket
mock_ws = AsyncMock() mock_ws = AsyncMock()
mock_ws.closed = False mock_ws.closed = False
gateway.ws = mock_ws gateway.ws = mock_ws
test_message = {"id": "test", "data": "hello"} test_message = {"id": "test", "data": "hello"}
await gateway.send_message(test_message) await gateway.send_message(test_message)
mock_ws.send_str.assert_called_once_with(json.dumps(test_message)) mock_ws.send_str.assert_called_once_with(json.dumps(test_message))
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_send_message_closed_connection(
"""Test ReverseGateway send message with closed connection""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
gateway = ReverseGateway()
gateway = make_gateway()
# Mock closed websocket
mock_ws = AsyncMock() mock_ws = AsyncMock()
mock_ws.closed = True mock_ws.closed = True
gateway.ws = mock_ws gateway.ws = mock_ws
test_message = {"id": "test", "data": "hello"} test_message = {"id": "test", "data": "hello"}
await gateway.send_message(test_message) await gateway.send_message(test_message)
# Should not call send_str on closed connection
mock_ws.send_str.assert_not_called() mock_ws.send_str.assert_not_called()
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_handle_message(
"""Test ReverseGateway handle message""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
mock_dispatcher_instance = AsyncMock()
mock_dispatcher_instance.handle_message.return_value = {"response": "success"} mock_dispatcher_instance = AsyncMock()
mock_dispatcher.return_value = mock_dispatcher_instance mock_dispatcher.return_value = mock_dispatcher_instance
gateway = ReverseGateway() gateway = make_gateway()
# Mock send_message
gateway.send_message = AsyncMock()
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
await gateway.handle_message(test_message)
mock_dispatcher_instance.handle_message.assert_called_once_with({
"id": "test",
"service": "test-service",
"request": {"data": "test"}
})
gateway.send_message.assert_called_once_with({"response": "success"})
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@pytest.mark.asyncio
async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway handle message with invalid JSON"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
# Mock send_message
gateway.send_message = AsyncMock() gateway.send_message = AsyncMock()
test_message = 'invalid json' test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
# Should not raise exception
await gateway.handle_message(test_message) await gateway.handle_message(test_message)
# Should not call send_message due to error mock_dispatcher_instance.handle_message.assert_called_once_with(
{
"id": "test",
"service": "test-service",
"request": {"data": "test"},
},
gateway.send_message,
)
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_handle_message_invalid_json(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.send_message = AsyncMock()
await gateway.handle_message('invalid json')
gateway.send_message.assert_not_called() gateway.send_message.assert_not_called()
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_listen_text_message(
"""Test ReverseGateway listen with text message""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
gateway = ReverseGateway()
gateway = make_gateway()
gateway.running = True gateway.running = True
# Mock websocket
mock_ws = AsyncMock() mock_ws = AsyncMock()
mock_ws.closed = False mock_ws.closed = False
gateway.ws = mock_ws gateway.ws = mock_ws
# Mock handle_message
gateway.handle_message = AsyncMock() gateway.handle_message = AsyncMock()
# Mock message
mock_msg = MagicMock() mock_msg = MagicMock()
mock_msg.type = WSMsgType.TEXT mock_msg.type = WSMsgType.TEXT
mock_msg.data = '{"test": "message"}' mock_msg.data = '{"test": "message"}'
# Mock receive to return message once, then raise exception to stop loop
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")] mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
# listen() catches exceptions and breaks, so no exception should be raised
await gateway.listen() await gateway.listen()
gateway.handle_message.assert_called_once_with('{"test": "message"}') gateway.handle_message.assert_called_once_with('{"test": "message"}')
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_listen_binary_message(
"""Test ReverseGateway listen with binary message""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
gateway = ReverseGateway()
gateway = make_gateway()
gateway.running = True gateway.running = True
# Mock websocket
mock_ws = AsyncMock() mock_ws = AsyncMock()
mock_ws.closed = False mock_ws.closed = False
gateway.ws = mock_ws gateway.ws = mock_ws
# Mock handle_message
gateway.handle_message = AsyncMock() gateway.handle_message = AsyncMock()
# Mock message
mock_msg = MagicMock() mock_msg = MagicMock()
mock_msg.type = WSMsgType.BINARY mock_msg.type = WSMsgType.BINARY
mock_msg.data = b'{"test": "binary"}' mock_msg.data = b'{"test": "binary"}'
# Mock receive to return message once, then raise exception to stop loop
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")] mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
# listen() catches exceptions and breaks, so no exception should be raised
await gateway.listen() await gateway.listen()
gateway.handle_message.assert_called_once_with('{"test": "binary"}') gateway.handle_message.assert_called_once_with('{"test": "binary"}')
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_listen_close_message(
"""Test ReverseGateway listen with close message""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
gateway = ReverseGateway()
gateway = make_gateway()
gateway.running = True gateway.running = True
# Mock websocket
mock_ws = AsyncMock() mock_ws = AsyncMock()
mock_ws.closed = False mock_ws.closed = False
gateway.ws = mock_ws gateway.ws = mock_ws
# Mock handle_message
gateway.handle_message = AsyncMock() gateway.handle_message = AsyncMock()
# Mock message
mock_msg = MagicMock() mock_msg = MagicMock()
mock_msg.type = WSMsgType.CLOSE mock_msg.type = WSMsgType.CLOSE
# Mock receive to return close message
mock_ws.receive.return_value = mock_msg mock_ws.receive.return_value = mock_msg
await gateway.listen() await gateway.listen()
# Should not call handle_message for close message
gateway.handle_message.assert_not_called() gateway.handle_message.assert_not_called()
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_shutdown(
"""Test ReverseGateway shutdown""" self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_backend = MagicMock() mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend mock_get_pubsub.return_value = mock_backend
mock_dispatcher_instance = AsyncMock() mock_dispatcher_instance = AsyncMock()
mock_dispatcher.return_value = mock_dispatcher_instance mock_dispatcher.return_value = mock_dispatcher_instance
gateway = ReverseGateway() gateway = make_gateway()
gateway.running = True gateway.running = True
# Mock disconnect
gateway.disconnect = AsyncMock() gateway.disconnect = AsyncMock()
await gateway.shutdown() await gateway.shutdown()
@ -402,46 +435,50 @@ class TestReverseGateway:
gateway.disconnect.assert_called_once() gateway.disconnect.assert_called_once()
mock_backend.close.assert_called_once() mock_backend.close.assert_called_once()
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): @patch(*MOCK_PATCHES[3:4])
"""Test ReverseGateway stop""" def test_reverse_gateway_stop(
mock_backend = MagicMock() self, mock_get_pubsub, mock_dispatcher,
mock_get_pubsub.return_value = mock_backend mock_config_receiver, mock_iam_auth,
):
gateway = ReverseGateway() mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.running = True gateway.running = True
gateway.stop() gateway.stop()
assert gateway.running is False assert gateway.running is False
class TestReverseGatewayRun: class TestReverseGatewayRun:
"""Test cases for ReverseGateway run method""" """Test cases for ReverseGateway run method"""
@patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch(*MOCK_PATCHES[0:1])
@patch('trustgraph.rev_gateway.service.MessageDispatcher') @patch(*MOCK_PATCHES[1:2])
@patch('trustgraph.rev_gateway.service.get_pubsub') @patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): async def test_reverse_gateway_run_successful_cycle(
"""Test ReverseGateway run method with successful connect/listen cycle""" self, mock_get_pubsub, mock_dispatcher,
mock_backend = MagicMock() mock_config_receiver, mock_iam_auth,
mock_get_pubsub.return_value = mock_backend ):
mock_get_pubsub.return_value = MagicMock()
mock_auth_instance = AsyncMock()
mock_iam_auth.return_value = mock_auth_instance
mock_config_receiver_instance = AsyncMock() mock_config_receiver_instance = AsyncMock()
mock_config_receiver.return_value = mock_config_receiver_instance mock_config_receiver.return_value = mock_config_receiver_instance
gateway = ReverseGateway() gateway = make_gateway()
# Mock methods
gateway.connect = AsyncMock(return_value=True)
gateway.listen = AsyncMock() gateway.listen = AsyncMock()
gateway.disconnect = AsyncMock() gateway.disconnect = AsyncMock()
gateway.shutdown = AsyncMock() gateway.shutdown = AsyncMock()
# Stop after one iteration
call_count = 0 call_count = 0
async def mock_connect(): async def mock_connect():
nonlocal call_count nonlocal call_count
@ -451,91 +488,13 @@ class TestReverseGatewayRun:
else: else:
gateway.running = False gateway.running = False
return False return False
gateway.connect = mock_connect gateway.connect = mock_connect
await gateway.run() await gateway.run()
mock_auth_instance.start.assert_called_once()
mock_config_receiver_instance.start.assert_called_once() mock_config_receiver_instance.start.assert_called_once()
gateway.listen.assert_called_once() gateway.listen.assert_called_once()
# disconnect is called twice: once in the main loop, once in shutdown
assert gateway.disconnect.call_count == 2 assert gateway.disconnect.call_count == 2
gateway.shutdown.assert_called_once() gateway.shutdown.assert_called_once()
class TestReverseGatewayArgs:
"""Test cases for argument parsing and run function"""
def test_parse_args_defaults(self):
"""Test parse_args with default values"""
import sys
# Mock sys.argv
original_argv = sys.argv
sys.argv = ['reverse-gateway']
try:
args = parse_args()
assert args.websocket_uri is None
assert args.max_workers == 10
assert args.pulsar_host is None
assert args.pulsar_api_key is None
assert args.pulsar_listener is None
finally:
sys.argv = original_argv
def test_parse_args_custom_values(self):
"""Test parse_args with custom values"""
import sys
# Mock sys.argv
original_argv = sys.argv
sys.argv = [
'reverse-gateway',
'--websocket-uri', 'ws://custom:8080/ws',
'--max-workers', '20',
'--pulsar-host', 'pulsar://custom:6650',
'--pulsar-api-key', 'test-key',
'--pulsar-listener', 'test-listener'
]
try:
args = parse_args()
assert args.websocket_uri == 'ws://custom:8080/ws'
assert args.max_workers == 20
assert args.pulsar_host == 'pulsar://custom:6650'
assert args.pulsar_api_key == 'test-key'
assert args.pulsar_listener == 'test-listener'
finally:
sys.argv = original_argv
@patch('trustgraph.rev_gateway.service.ReverseGateway')
@patch('asyncio.run')
def test_run_function(self, mock_asyncio_run, mock_gateway_class):
"""Test run function"""
import sys
# Mock sys.argv
original_argv = sys.argv
sys.argv = ['reverse-gateway', '--max-workers', '15']
try:
mock_gateway_instance = MagicMock()
mock_gateway_instance.url = "ws://localhost:7650/out"
mock_gateway_instance.pulsar_host = "pulsar://pulsar:6650"
mock_gateway_class.return_value = mock_gateway_instance
run()
mock_gateway_class.assert_called_once_with(
websocket_uri=None,
max_workers=15,
pulsar_host=None,
pulsar_api_key=None,
pulsar_listener=None
)
mock_asyncio_run.assert_called_once_with(mock_gateway_instance.run())
finally:
sys.argv = original_argv

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from typing import Any from typing import Any
from . request_response_spec import RequestResponse, RequestResponseSpec from . request_response_spec import RequestResponse, RequestResponseSpec
@ -44,6 +45,60 @@ def from_value(x: Any) -> Any:
return Term(type=LITERAL, value=str(x)) return Term(type=LITERAL, value=str(x))
class TriplesClient(RequestResponse): class TriplesClient(RequestResponse):
async def query_gen(self, s=None, p=None, o=None, limit=20,
collection="default",
batch_size=20, timeout=30, g=None):
"""Async generator yielding Triple objects as batches arrive."""
queue = asyncio.Queue()
done = False
async def recipient(resp):
if resp.error:
raise RuntimeError(resp.error.message)
batch = [
Triple(to_value(v.s), to_value(v.p), to_value(v.o))
for v in resp.triples
]
await queue.put(batch)
if resp.is_final:
await queue.put(None)
return resp.is_final
# Launch the streaming request as a background task
task = asyncio.ensure_future(self.request(
TriplesQueryRequest(
s=from_value(s),
p=from_value(p),
o=from_value(o),
limit=limit,
collection=collection,
streaming=True,
batch_size=batch_size,
g=g,
),
timeout=timeout,
recipient=recipient,
))
try:
while True:
batch = await queue.get()
if batch is None:
break
for triple in batch:
yield triple
finally:
if not task.done():
task.cancel()
try:
await task
except (asyncio.CancelledError, Exception):
pass
async def query(self, s=None, p=None, o=None, limit=20, async def query(self, s=None, p=None, o=None, limit=20,
collection="default", collection="default",
timeout=30, g=None): timeout=30, g=None):

View file

@ -7,7 +7,7 @@ Provides semantic query understanding, ontology matching, and answer generation.
from .query_service import OntoRAGQueryService, QueryRequest, QueryResponse from .query_service import OntoRAGQueryService, QueryRequest, QueryResponse
from .question_analyzer import QuestionAnalyzer, QuestionComponents, QuestionType from .question_analyzer import QuestionAnalyzer, QuestionComponents, QuestionType
from .ontology_matcher import OntologyMatcher, QueryOntologySubset from .ontology_matcher import OntologyMatcherForQueries, QueryOntologySubset
from .backend_router import BackendRouter, BackendType, QueryRoute from .backend_router import BackendRouter, BackendType, QueryRoute
from .sparql_generator import SPARQLGenerator, SPARQLQuery from .sparql_generator import SPARQLGenerator, SPARQLQuery
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
@ -27,7 +27,7 @@ __all__ = [
'QuestionType', 'QuestionType',
# Ontology matching # Ontology matching
'OntologyMatcher', 'OntologyMatcherForQueries',
'QueryOntologySubset', 'QueryOntologySubset',
# Backend routing # Backend routing

View file

@ -4,6 +4,7 @@ Provides comprehensive monitoring of system performance, query patterns, and res
""" """
import logging import logging
import re
import time import time
import asyncio import asyncio
import inspect import inspect
@ -276,6 +277,26 @@ class MetricsCollector:
return f"{name}{{{label_str}}}" return f"{name}{{{label_str}}}"
def _extract_metric_label(metric_name: str, label: str) -> Optional[str]:
"""Extract a label value from an internal metric key."""
labels_start = metric_name.find('{')
labels_end = metric_name.find('}', labels_start + 1)
if labels_start == -1 or labels_end == -1:
return None
labels = metric_name[labels_start + 1:labels_end]
label_match = re.search(
rf'(?:^|,){re.escape(label)}=(?:"([^"]*)"|([^,]*))',
labels,
)
if not label_match:
return None
quoted_value, unquoted_value = label_match.groups()
return quoted_value if quoted_value is not None else unquoted_value
class PerformanceMonitor: class PerformanceMonitor:
"""Monitors system performance and component health.""" """Monitors system performance and component health."""
@ -474,8 +495,8 @@ class PerformanceMonitor:
# Cache performance # Cache performance
cache_types = set() cache_types = set()
for metric_name in self.metrics_collector.counters.keys(): for metric_name in self.metrics_collector.counters.keys():
if 'cache_type=' in metric_name: cache_type = _extract_metric_label(metric_name, 'cache_type')
cache_type = metric_name.split('cache_type=')[1].split(',')[0].split('}')[0] if cache_type is not None:
cache_types.add(cache_type) cache_types.add(cache_type)
for cache_type in cache_types: for cache_type in cache_types:

View file

@ -7,10 +7,10 @@ import logging
from typing import List, Dict, Any, Set, Optional from typing import List, Dict, Any, Set, Optional
from dataclasses import dataclass from dataclasses import dataclass
from ...extract.kg.ontology.ontology_loader import Ontology, OntologyLoader from trustgraph.extract.kg.ontology.ontology_loader import Ontology, OntologyLoader
from ...extract.kg.ontology.ontology_embedder import OntologyEmbedder from trustgraph.extract.kg.ontology.ontology_embedder import OntologyEmbedder
from ...extract.kg.ontology.text_processor import TextSegment from trustgraph.extract.kg.ontology.text_processor import TextSegment
from ...extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset from trustgraph.extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset
from .question_analyzer import QuestionComponents, QuestionType from .question_analyzer import QuestionComponents, QuestionType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -8,13 +8,13 @@ from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from ....flow.flow_processor import FlowProcessor from trustgraph.base.flow_processor import FlowProcessor
from ....tables.config import ConfigTableStore from trustgraph.tables.config import ConfigTableStore
from ...extract.kg.ontology.ontology_loader import OntologyLoader from trustgraph.extract.kg.ontology.ontology_loader import OntologyLoader
from ...extract.kg.ontology.vector_store import InMemoryVectorStore from trustgraph.extract.kg.ontology.vector_store import InMemoryVectorStore
from .question_analyzer import QuestionAnalyzer, QuestionComponents from .question_analyzer import QuestionAnalyzer, QuestionComponents
from .ontology_matcher import OntologyMatcher, QueryOntologySubset from .ontology_matcher import OntologyMatcherForQueries, QueryOntologySubset
from .backend_router import BackendRouter, QueryRoute, BackendType from .backend_router import BackendRouter, QueryRoute, BackendType
from .sparql_generator import SPARQLGenerator, SPARQLQuery from .sparql_generator import SPARQLGenerator, SPARQLQuery
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
@ -105,7 +105,7 @@ class OntoRAGQueryService(FlowProcessor):
# Initialize ontology matcher # Initialize ontology matcher
matcher_config = self.config.get('ontology_matcher', {}) matcher_config = self.config.get('ontology_matcher', {})
self.ontology_matcher = OntologyMatcher( self.ontology_matcher = OntologyMatcherForQueries(
vector_store=self.vector_store, vector_store=self.vector_store,
embedding_service=self.embedding_service, embedding_service=self.embedding_service,
config=matcher_config config=matcher_config

View file

@ -28,7 +28,7 @@ try:
except ImportError: except ImportError:
CASSANDRA_AVAILABLE = False CASSANDRA_AVAILABLE = False
from ....tables.config import ConfigTableStore from trustgraph.tables.config import ConfigTableStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -4,6 +4,10 @@ SPARQL algebra evaluator.
Recursively evaluates an rdflib SPARQL algebra tree by issuing triple Recursively evaluates an rdflib SPARQL algebra tree by issuing triple
pattern queries via TriplesClient (streaming) and performing in-memory pattern queries via TriplesClient (streaming) and performing in-memory
joins, filters, and projections. joins, filters, and projections.
Handlers are async generators that yield solutions incrementally.
Blocking operators (joins, sort, group, distinct) materialise their
upstream into a list at the boundary, then yield results.
""" """
import logging import logging
@ -17,7 +21,7 @@ from ... knowledge import Uri
from ... knowledge import Literal as KgLiteral from ... knowledge import Literal as KgLiteral
from . parser import rdflib_term_to_term from . parser import rdflib_term_to_term
from . solutions import ( from . solutions import (
hash_join, left_join, union, project, distinct, hash_join, left_join, minus, union, project, distinct,
order_by, slice_solutions, _term_key, order_by, slice_solutions, _term_key,
) )
from . expressions import evaluate_expression, _effective_boolean from . expressions import evaluate_expression, _effective_boolean
@ -34,56 +38,56 @@ async def evaluate(node, triples_client, collection, limit=10000):
""" """
Evaluate a SPARQL algebra node. Evaluate a SPARQL algebra node.
Args: Yields solutions (dicts mapping variable names to Term values)
node: rdflib CompValue algebra node incrementally as an async generator.
triples_client: TriplesClient instance for triple pattern queries
collection: collection identifier
limit: safety limit on results
Returns:
list of solutions (dicts mapping variable names to Term values)
""" """
if not isinstance(node, CompValue): if not isinstance(node, CompValue):
logger.warning(f"Expected CompValue, got {type(node)}: {node}") logger.warning(f"Expected CompValue, got {type(node)}: {node}")
return [{}] yield {}
return
name = node.name name = node.name
handler = _HANDLERS.get(name) handler = _HANDLERS.get(name)
if handler is None: if handler is None:
logger.warning(f"Unsupported algebra node: {name}") logger.warning(f"Unsupported algebra node: {name}")
return [{}] yield {}
return
return await handler(node, triples_client, collection, limit) async for sol in handler(node, triples_client, collection, limit):
yield sol
# --- Node handlers --- async def materialise(node, triples_client, collection, limit=10000):
"""Collect all solutions from evaluate() into a list."""
return [sol async for sol in evaluate(node, triples_client, collection, limit)]
# --- Node handlers (async generators) ---
async def _eval_select_query(node, tc, collection, limit): async def _eval_select_query(node, tc, collection, limit):
"""Evaluate a SelectQuery node.""" async for sol in evaluate(node.p, tc, collection, limit):
return await evaluate(node.p, tc, collection, limit) yield sol
async def _eval_project(node, tc, collection, limit): async def _eval_project(node, tc, collection, limit):
"""Evaluate a Project node (SELECT variable projection)."""
solutions = await evaluate(node.p, tc, collection, limit)
variables = [str(v) for v in node.PV] variables = [str(v) for v in node.PV]
return project(solutions, variables) async for sol in evaluate(node.p, tc, collection, limit):
yield {v: sol[v] for v in variables if v in sol}
async def _eval_bgp(node, tc, collection, limit): async def _eval_bgp(node, tc, collection, limit):
""" """
Evaluate a Basic Graph Pattern. Evaluate a Basic Graph Pattern.
Issues streaming triple pattern queries and joins results. Patterns Patterns are ordered by selectivity and evaluated sequentially.
are ordered by selectivity (more bound terms first) and evaluated For the final pattern, results stream directly from the triple store.
sequentially with bound-variable substitution.
""" """
triples = node.triples triples = node.triples
if not triples: if not triples:
return [{}] yield {}
return
# Sort patterns by selectivity: more bound terms = more selective
def selectivity(pattern): def selectivity(pattern):
return sum(1 for t in pattern if not isinstance(t, Variable)) return sum(1 for t in pattern if not isinstance(t, Variable))
@ -91,55 +95,222 @@ async def _eval_bgp(node, tc, collection, limit):
enumerate(triples), key=lambda x: -selectivity(x[1]) enumerate(triples), key=lambda x: -selectivity(x[1])
) )
# For all patterns except the last, we must materialise intermediate
# solutions because each pattern depends on bindings from prior ones.
# The last pattern streams directly.
solutions = [{}] solutions = [{}]
for _, pattern in sorted_patterns: for pattern_idx, (_, pattern) in enumerate(sorted_patterns):
s_tmpl, p_tmpl, o_tmpl = pattern s_tmpl, p_tmpl, o_tmpl = pattern
is_last = (pattern_idx == len(sorted_patterns) - 1)
new_solutions = [] if is_last:
# Stream the final pattern — yield as triples arrive
count = 0
for sol in solutions:
s_val = _resolve_term(s_tmpl, sol)
p_val = _resolve_term(p_tmpl, sol)
o_val = _resolve_term(o_tmpl, sol)
for sol in solutions: async for triple in tc.query_gen(
# Substitute known bindings into the pattern s=s_val, p=p_val, o=o_val,
s_val = _resolve_term(s_tmpl, sol) limit=limit, collection=collection,
p_val = _resolve_term(p_tmpl, sol) ):
o_val = _resolve_term(o_tmpl, sol) binding = dict(sol)
if isinstance(s_tmpl, Variable):
binding[str(s_tmpl)] = _to_term(triple.s)
if isinstance(p_tmpl, Variable):
binding[str(p_tmpl)] = _to_term(triple.p)
if isinstance(o_tmpl, Variable):
binding[str(o_tmpl)] = _to_term(triple.o)
yield binding
count += 1
if count >= limit:
return
else:
# Materialise intermediate patterns
new_solutions = []
for sol in solutions:
s_val = _resolve_term(s_tmpl, sol)
p_val = _resolve_term(p_tmpl, sol)
o_val = _resolve_term(o_tmpl, sol)
# Query the triples store async for triple in tc.query_gen(
results = await _query_pattern( s=s_val, p=p_val, o=o_val,
tc, s_val, p_val, o_val, collection, limit limit=limit, collection=collection,
) ):
binding = dict(sol)
if isinstance(s_tmpl, Variable):
binding[str(s_tmpl)] = _to_term(triple.s)
if isinstance(p_tmpl, Variable):
binding[str(p_tmpl)] = _to_term(triple.p)
if isinstance(o_tmpl, Variable):
binding[str(o_tmpl)] = _to_term(triple.o)
new_solutions.append(binding)
# Map results back to variable bindings, solutions = new_solutions
# converting Uri/Literal to Term objects if not solutions:
for triple in results: return
binding = dict(sol)
if isinstance(s_tmpl, Variable):
binding[str(s_tmpl)] = _to_term(triple.s)
if isinstance(p_tmpl, Variable):
binding[str(p_tmpl)] = _to_term(triple.p)
if isinstance(o_tmpl, Variable):
binding[str(o_tmpl)] = _to_term(triple.o)
new_solutions.append(binding)
solutions = new_solutions
if not solutions: # --- Blocking operators: materialise upstream, then yield ---
break
return solutions[:limit] def _is_small_node(node):
"""Check if a node is likely to produce a small number of solutions."""
if not isinstance(node, CompValue):
return False
if node.name in ("values", "ToMultiSet"):
return True
if node.name == "Extend" and hasattr(node, "p"):
return _is_small_node(node.p)
return False
async def _eval_join(node, tc, collection, limit): async def _eval_join(node, tc, collection, limit):
"""Evaluate a Join node.""" # Bind join: if one side is small (e.g. VALUES), materialise it and
left = await evaluate(node.p1, tc, collection, limit) # substitute its bindings into the other side's evaluation. This
right = await evaluate(node.p2, tc, collection, limit) # turns wildcard BGP queries into selective ones.
return hash_join(left, right)[:limit] if _is_small_node(node.p1):
yield_from = _bind_join(node.p1, node.p2, tc, collection, limit)
elif _is_small_node(node.p2):
yield_from = _bind_join(node.p2, node.p1, tc, collection, limit)
else:
yield_from = _hash_join(node, tc, collection, limit)
async for sol in yield_from:
yield sol
async def _hash_join(node, tc, collection, limit):
left = await materialise(node.p1, tc, collection, limit)
right = await materialise(node.p2, tc, collection, limit)
for sol in hash_join(left, right)[:limit]:
yield sol
async def _bind_join(small_node, big_node, tc, collection, limit):
"""Iterate over the small side and inject bindings into the big side."""
small_sols = await materialise(small_node, tc, collection, limit)
count = 0
for binding in small_sols:
async for sol in _evaluate_with_bindings(
big_node, binding, tc, collection, limit
):
yield sol
count += 1
if count >= limit:
return
def _merge_compatible(left, right):
"""Merge two solutions if compatible (shared vars have equal values)."""
merged = dict(left)
for k, v in right.items():
if k in merged:
if _term_key(merged[k]) != _term_key(v):
return None
else:
merged[k] = v
return merged
async def _evaluate_with_bindings(node, bindings, tc, collection, limit):
"""Evaluate a node with pre-seeded variable bindings.
For BGP nodes, the bindings are injected so _resolve_term sees them,
turning wildcard queries into selective ones. For other node types,
evaluate normally and merge/filter against the bindings.
"""
if isinstance(node, CompValue) and node.name == "BGP":
async for sol in _eval_bgp_with_bindings(
node, bindings, tc, collection, limit
):
yield sol
else:
async for sol in evaluate(node, tc, collection, limit):
merged = _merge_compatible(bindings, sol)
if merged is not None:
yield merged
async def _eval_bgp_with_bindings(node, bindings, tc, collection, limit):
"""Evaluate a BGP with pre-seeded bindings so variables resolve to terms."""
triples = node.triples
if not triples:
yield dict(bindings)
return
def selectivity(pattern):
score = 0
for t in pattern:
if not isinstance(t, Variable):
score += 1
elif str(t) in bindings:
score += 1
return score
sorted_patterns = sorted(
enumerate(triples), key=lambda x: -selectivity(x[1])
)
solutions = [dict(bindings)]
for pattern_idx, (_, pattern) in enumerate(sorted_patterns):
s_tmpl, p_tmpl, o_tmpl = pattern
is_last = (pattern_idx == len(sorted_patterns) - 1)
if is_last:
count = 0
for sol in solutions:
s_val = _resolve_term(s_tmpl, sol)
p_val = _resolve_term(p_tmpl, sol)
o_val = _resolve_term(o_tmpl, sol)
async for triple in tc.query_gen(
s=s_val, p=p_val, o=o_val,
limit=limit, collection=collection,
):
binding = dict(sol)
if isinstance(s_tmpl, Variable):
binding[str(s_tmpl)] = _to_term(triple.s)
if isinstance(p_tmpl, Variable):
binding[str(p_tmpl)] = _to_term(triple.p)
if isinstance(o_tmpl, Variable):
binding[str(o_tmpl)] = _to_term(triple.o)
yield binding
count += 1
if count >= limit:
return
else:
new_solutions = []
for sol in solutions:
s_val = _resolve_term(s_tmpl, sol)
p_val = _resolve_term(p_tmpl, sol)
o_val = _resolve_term(o_tmpl, sol)
async for triple in tc.query_gen(
s=s_val, p=p_val, o=o_val,
limit=limit, collection=collection,
):
binding = dict(sol)
if isinstance(s_tmpl, Variable):
binding[str(s_tmpl)] = _to_term(triple.s)
if isinstance(p_tmpl, Variable):
binding[str(p_tmpl)] = _to_term(triple.p)
if isinstance(o_tmpl, Variable):
binding[str(o_tmpl)] = _to_term(triple.o)
new_solutions.append(binding)
solutions = new_solutions
if not solutions:
return
async def _eval_left_join(node, tc, collection, limit): async def _eval_left_join(node, tc, collection, limit):
"""Evaluate a LeftJoin node (OPTIONAL).""" # Buffer right side for hash index; stream left through probe
left_sols = await evaluate(node.p1, tc, collection, limit) left_sols = await materialise(node.p1, tc, collection, limit)
right_sols = await evaluate(node.p2, tc, collection, limit) right_sols = await materialise(node.p2, tc, collection, limit)
filter_fn = None filter_fn = None
if hasattr(node, "expr") and node.expr is not None: if hasattr(node, "expr") and node.expr is not None:
@ -149,42 +320,35 @@ async def _eval_left_join(node, tc, collection, limit):
evaluate_expression(expr, sol) evaluate_expression(expr, sol)
) )
return left_join(left_sols, right_sols, filter_fn)[:limit] for sol in left_join(left_sols, right_sols, filter_fn)[:limit]:
yield sol
async def _eval_union(node, tc, collection, limit): async def _eval_minus(node, tc, collection, limit):
"""Evaluate a Union node.""" left = await materialise(node.p1, tc, collection, limit)
left = await evaluate(node.p1, tc, collection, limit) right = await materialise(node.p2, tc, collection, limit)
right = await evaluate(node.p2, tc, collection, limit) for sol in minus(left, right):
return union(left, right)[:limit] yield sol
async def _eval_filter(node, tc, collection, limit):
"""Evaluate a Filter node."""
solutions = await evaluate(node.p, tc, collection, limit)
expr = node.expr
return [
sol for sol in solutions
if _effective_boolean(evaluate_expression(expr, sol))
]
async def _eval_distinct(node, tc, collection, limit): async def _eval_distinct(node, tc, collection, limit):
"""Evaluate a Distinct node.""" seen = set()
solutions = await evaluate(node.p, tc, collection, limit) async for sol in evaluate(node.p, tc, collection, limit):
return distinct(solutions) key = tuple(sorted(
(k, _term_key(v)) for k, v in sol.items()
))
if key not in seen:
seen.add(key)
yield sol
async def _eval_reduced(node, tc, collection, limit): async def _eval_reduced(node, tc, collection, limit):
"""Evaluate a Reduced node (like Distinct but implementation-defined).""" async for sol in _eval_distinct(node, tc, collection, limit):
# Treat same as Distinct yield sol
solutions = await evaluate(node.p, tc, collection, limit)
return distinct(solutions)
async def _eval_order_by(node, tc, collection, limit): async def _eval_order_by(node, tc, collection, limit):
"""Evaluate an OrderBy node.""" solutions = await materialise(node.p, tc, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
key_fns = [] key_fns = []
for cond in node.expr: for cond in node.expr:
@ -196,36 +360,104 @@ async def _eval_order_by(node, tc, collection, limit):
ascending, ascending,
)) ))
else: else:
# Simple variable or expression
key_fns.append(( key_fns.append((
lambda sol, e=cond: evaluate_expression(e, sol), lambda sol, e=cond: evaluate_expression(e, sol),
True, True,
)) ))
return order_by(solutions, key_fns) for sol in order_by(solutions, key_fns):
yield sol
# --- Streamable operators ---
async def _eval_slice(node, tc, collection, limit): async def _eval_slice(node, tc, collection, limit):
"""Evaluate a Slice node (LIMIT/OFFSET).""" offset = node.start or 0
# Pass tighter limit downstream if possible length = node.length
inner_limit = limit skipped = 0
if node.length is not None: emitted = 0
offset = node.start or 0
inner_limit = min(limit, offset + node.length)
solutions = await evaluate(node.p, tc, collection, inner_limit) async for sol in evaluate(node.p, tc, collection, limit):
return slice_solutions(solutions, node.start or 0, node.length) if skipped < offset:
skipped += 1
continue
yield sol
emitted += 1
if length is not None and emitted >= length:
return
async def _eval_union(node, tc, collection, limit):
async for sol in evaluate(node.p1, tc, collection, limit):
yield sol
async for sol in evaluate(node.p2, tc, collection, limit):
yield sol
async def _check_exists(graph_node, sol, tc, collection, limit):
"""Evaluate an EXISTS graph pattern against a solution."""
async for r in evaluate(graph_node, tc, collection, limit):
shared = set(sol.keys()) & set(r.keys())
if all(
_term_key(sol[v]) == _term_key(r[v])
for v in shared
if sol.get(v) is not None and r.get(v) is not None
):
return True
return False
async def _pre_eval_exists(expr, sol, tc, collection, limit, cache):
"""Walk an expression tree, pre-evaluate EXISTS/NOT EXISTS, cache results."""
if not isinstance(expr, CompValue):
return
if expr.name in ("Builtin_EXISTS", "Builtin_NOTEXISTS"):
key = id(expr.graph), id(sol)
if key not in cache:
cache[key] = await _check_exists(
expr.graph, sol, tc, collection, limit
)
return
for attr in ("expr", "other", "arg", "arg1", "arg2", "arg3"):
child = getattr(expr, attr, None)
if child is None:
continue
if isinstance(child, CompValue):
await _pre_eval_exists(child, sol, tc, collection, limit, cache)
elif isinstance(child, (list, tuple)):
for item in child:
if isinstance(item, CompValue):
await _pre_eval_exists(
item, sol, tc, collection, limit, cache
)
async def _eval_filter(node, tc, collection, limit):
expr = node.expr
exists_cache = {}
def exists_cb(graph_node, sol):
key = id(graph_node), id(sol)
return exists_cache.get(key, False)
async for sol in evaluate(node.p, tc, collection, limit):
await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache)
if _effective_boolean(evaluate_expression(expr, sol, exists_cb=exists_cb)):
yield sol
async def _eval_extend(node, tc, collection, limit): async def _eval_extend(node, tc, collection, limit):
"""Evaluate an Extend node (BIND)."""
solutions = await evaluate(node.p, tc, collection, limit)
var_name = str(node.var) var_name = str(node.var)
expr = node.expr expr = node.expr
exists_cache = {}
result = [] def exists_cb(graph_node, sol):
for sol in solutions: key = id(graph_node), id(sol)
val = evaluate_expression(expr, sol) return exists_cache.get(key, False)
async for sol in evaluate(node.p, tc, collection, limit):
await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache)
val = evaluate_expression(expr, sol, exists_cb=exists_cb)
new_sol = dict(sol) new_sol = dict(sol)
if isinstance(val, Term): if isinstance(val, Term):
new_sol[var_name] = val new_sol[var_name] = val
@ -240,16 +472,14 @@ async def _eval_extend(node, tc, collection, limit):
) )
elif val is not None: elif val is not None:
new_sol[var_name] = Term(type=LITERAL, value=str(val)) new_sol[var_name] = Term(type=LITERAL, value=str(val))
result.append(new_sol) yield new_sol
return result
# --- Aggregation (blocking) ---
async def _eval_group(node, tc, collection, limit): async def _eval_group(node, tc, collection, limit):
"""Evaluate a Group node (GROUP BY with aggregation).""" solutions = await materialise(node.p, tc, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
# Extract grouping expressions
group_exprs = [] group_exprs = []
if hasattr(node, "expr") and node.expr: if hasattr(node, "expr") and node.expr:
for expr in node.expr: for expr in node.expr:
@ -260,7 +490,6 @@ async def _eval_group(node, tc, collection, limit):
else: else:
group_exprs.append((expr, None)) group_exprs.append((expr, None))
# Group solutions
groups = defaultdict(list) groups = defaultdict(list)
for sol in solutions: for sol in solutions:
key_parts = [] key_parts = []
@ -270,81 +499,72 @@ async def _eval_group(node, tc, collection, limit):
groups[tuple(key_parts)].append(sol) groups[tuple(key_parts)].append(sol)
if not group_exprs: if not group_exprs:
# No GROUP BY - entire result is one group
groups[()].extend(solutions) groups[()].extend(solutions)
# Build grouped solutions (one per group)
result = []
for key, group_sols in groups.items(): for key, group_sols in groups.items():
sol = {} sol = {}
# Include group key variables
if group_sols: if group_sols:
for (expr, var_name), k in zip(group_exprs, key): for (expr, var_name), k in zip(group_exprs, key):
if var_name and group_sols: if var_name and group_sols:
sol[var_name] = evaluate_expression(expr, group_sols[0]) sol[var_name] = evaluate_expression(expr, group_sols[0])
sol["__group__"] = group_sols sol["__group__"] = group_sols
result.append(sol) yield sol
return result
async def _eval_aggregate_join(node, tc, collection, limit): async def _eval_aggregate_join(node, tc, collection, limit):
"""Evaluate an AggregateJoin (aggregation functions after GROUP BY).""" async for sol in evaluate(node.p, tc, collection, limit):
solutions = await evaluate(node.p, tc, collection, limit)
result = []
for sol in solutions:
group = sol.get("__group__", [sol]) group = sol.get("__group__", [sol])
new_sol = {k: v for k, v in sol.items() if k != "__group__"} new_sol = {k: v for k, v in sol.items() if k != "__group__"}
# Apply aggregate functions
if hasattr(node, "A") and node.A: if hasattr(node, "A") and node.A:
for agg in node.A: for agg in node.A:
var_name = str(agg.res) var_name = str(agg.res)
agg_val = _compute_aggregate(agg, group) agg_val = _compute_aggregate(agg, group)
new_sol[var_name] = agg_val new_sol[var_name] = agg_val
result.append(new_sol) yield new_sol
return result
async def _eval_graph(node, tc, collection, limit): async def _eval_graph(node, tc, collection, limit):
"""Evaluate a Graph node (GRAPH clause)."""
term = node.term term = node.term
if isinstance(term, URIRef): if isinstance(term, URIRef):
# GRAPH <uri> { ... } — fixed graph
# We'd need to pass graph to triples queries
# For now, evaluate inner pattern normally
logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired") logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired")
return await evaluate(node.p, tc, collection, limit)
elif isinstance(term, Variable): elif isinstance(term, Variable):
# GRAPH ?g { ... } — variable graph
logger.info(f"GRAPH ?{term} clause - variable graph not yet wired") logger.info(f"GRAPH ?{term} clause - variable graph not yet wired")
return await evaluate(node.p, tc, collection, limit)
else: async for sol in evaluate(node.p, tc, collection, limit):
return await evaluate(node.p, tc, collection, limit) yield sol
async def _eval_values(node, tc, collection, limit): async def _eval_values(node, tc, collection, limit):
"""Evaluate a VALUES clause (inline data).""" # rdflib has two representations for VALUES:
variables = [str(v) for v in node.var] # 1. var=[Variable...], value=[[val, ...], ...] — positional
solutions = [] # 2. var=None, res=[{Variable: val, ...}, ...] — dict-based
if hasattr(node, "res") and node.res:
for row in node.res:
sol = {}
for var, val in row.items():
if val is not None and str(val) != "UNDEF":
sol[str(var)] = rdflib_term_to_term(val)
yield sol
return
if not node.var or not node.value:
yield {}
return
variables = [str(v) for v in node.var]
for row in node.value: for row in node.value:
sol = {} sol = {}
for var_name, val in zip(variables, row): for var_name, val in zip(variables, row):
if val is not None and str(val) != "UNDEF": if val is not None and str(val) != "UNDEF":
sol[var_name] = rdflib_term_to_term(val) sol[var_name] = rdflib_term_to_term(val)
solutions.append(sol) yield sol
return solutions
async def _eval_to_multiset(node, tc, collection, limit): async def _eval_to_multiset(node, tc, collection, limit):
"""Evaluate a ToMultiSet node (subquery).""" async for sol in evaluate(node.p, tc, collection, limit):
return await evaluate(node.p, tc, collection, limit) yield sol
# --- Aggregate computation --- # --- Aggregate computation ---
@ -353,7 +573,6 @@ def _compute_aggregate(agg, group):
"""Compute a single aggregate function over a group of solutions.""" """Compute a single aggregate function over a group of solutions."""
agg_name = agg.name if hasattr(agg, "name") else "" agg_name = agg.name if hasattr(agg, "name") else ""
# Get the expression to aggregate
expr = agg.vars if hasattr(agg, "vars") else None expr = agg.vars if hasattr(agg, "vars") else None
if agg_name == "Aggregate_Count": if agg_name == "Aggregate_Count":
@ -525,6 +744,7 @@ _HANDLERS = {
"Join": _eval_join, "Join": _eval_join,
"LeftJoin": _eval_left_join, "LeftJoin": _eval_left_join,
"Union": _eval_union, "Union": _eval_union,
"Minus": _eval_minus,
"Filter": _eval_filter, "Filter": _eval_filter,
"Distinct": _eval_distinct, "Distinct": _eval_distinct,
"Reduced": _eval_reduced, "Reduced": _eval_reduced,

View file

@ -5,9 +5,15 @@ Evaluates rdflib algebra expression nodes against a solution (variable
binding) to produce a value or boolean result. binding) to produce a value or boolean result.
""" """
import hashlib
import math
import random
import re import re
import logging import logging
import operator import operator
import uuid
from datetime import datetime, date, timezone
from urllib.parse import quote
from rdflib.term import Variable, URIRef, Literal, BNode from rdflib.term import Variable, URIRef, Literal, BNode
from rdflib.plugins.sparql.parserutils import CompValue from rdflib.plugins.sparql.parserutils import CompValue
@ -17,23 +23,31 @@ from . parser import rdflib_term_to_term
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_exists_callback = None
class ExpressionError(Exception): class ExpressionError(Exception):
"""Raised when a SPARQL expression cannot be evaluated.""" """Raised when a SPARQL expression cannot be evaluated."""
pass pass
def evaluate_expression(expr, solution): def evaluate_expression(expr, solution, exists_cb=None):
""" """
Evaluate a SPARQL expression against a solution binding. Evaluate a SPARQL expression against a solution binding.
Args: Args:
expr: rdflib algebra expression node expr: rdflib algebra expression node
solution: dict mapping variable names to Term values solution: dict mapping variable names to Term values
exists_cb: optional callback(graph_node, solution) -> bool for
EXISTS/NOT EXISTS evaluation; provided by algebra.py
Returns: Returns:
The result value (Term, bool, number, string, or None) The result value (Term, bool, number, string, or None)
""" """
global _exists_callback
if exists_cb is not None:
_exists_callback = exists_cb
if expr is None: if expr is None:
return True return True
@ -111,6 +125,13 @@ def _evaluate_comp_value(node, solution):
if name == "MultiplicativeExpression": if name == "MultiplicativeExpression":
return _eval_multiplicative(node, solution) return _eval_multiplicative(node, solution)
# IN / NOT IN — must be checked before the generic Builtin_ dispatch
if name == "Builtin_IN":
return _eval_in(node, solution)
if name == "Builtin_NOTIN":
return not _eval_in(node, solution)
# SPARQL built-in functions # SPARQL built-in functions
if name.startswith("Builtin_"): if name.startswith("Builtin_"):
return _eval_builtin(name, node, solution) return _eval_builtin(name, node, solution)
@ -119,27 +140,10 @@ def _evaluate_comp_value(node, solution):
if name == "Function": if name == "Function":
return _eval_function(node, solution) return _eval_function(node, solution)
# Exists / NotExists
if name == "Builtin_EXISTS":
# EXISTS requires graph pattern evaluation - not handled here
logger.warning("EXISTS not supported in filter expressions")
return True
if name == "Builtin_NOTEXISTS":
logger.warning("NOT EXISTS not supported in filter expressions")
return True
# TrueFilter (used with OPTIONAL) # TrueFilter (used with OPTIONAL)
if name == "TrueFilter": if name == "TrueFilter":
return True return True
# IN / NOT IN
if name == "Builtin_IN":
return _eval_in(node, solution)
if name == "Builtin_NOTIN":
return not _eval_in(node, solution)
logger.warning(f"Unknown CompValue expression: {name}") logger.warning(f"Unknown CompValue expression: {name}")
return None return None
@ -165,6 +169,22 @@ def _eval_relational(node, solution):
">=": operator.ge, ">=": operator.ge,
} }
if str(op) == "IN":
items = node.other if isinstance(node.other, list) else [node.other]
for item in items:
other_val = evaluate_expression(item, solution)
if _comparable_value(left) == _comparable_value(other_val):
return True
return False
if str(op) == "NOT IN":
items = node.other if isinstance(node.other, list) else [node.other]
for item in items:
other_val = evaluate_expression(item, solution)
if _comparable_value(left) == _comparable_value(other_val):
return False
return True
op_fn = ops.get(str(op)) op_fn = ops.get(str(op))
if op_fn is None: if op_fn is None:
logger.warning(f"Unknown relational operator: {op}") logger.warning(f"Unknown relational operator: {op}")
@ -335,6 +355,197 @@ def _eval_builtin(name, node, solution):
return val return val
return None return None
if builtin == "YEAR":
dt = _to_datetime(evaluate_expression(node.arg, solution))
return dt.year if dt is not None else None
if builtin == "MONTH":
dt = _to_datetime(evaluate_expression(node.arg, solution))
return dt.month if dt is not None else None
if builtin == "DAY":
dt = _to_datetime(evaluate_expression(node.arg, solution))
return dt.day if dt is not None else None
if builtin == "HOURS":
dt = _to_datetime(evaluate_expression(node.arg, solution))
if dt is None:
return None
return dt.hour if isinstance(dt, datetime) else 0
if builtin == "MINUTES":
dt = _to_datetime(evaluate_expression(node.arg, solution))
if dt is None:
return None
return dt.minute if isinstance(dt, datetime) else 0
if builtin == "SECONDS":
dt = _to_datetime(evaluate_expression(node.arg, solution))
if dt is None:
return None
return dt.second if isinstance(dt, datetime) else 0
if builtin == "FLOOR":
val = _to_numeric(evaluate_expression(node.arg, solution))
if val is None:
return None
return int(math.floor(val))
if builtin == "CEIL":
val = _to_numeric(evaluate_expression(node.arg, solution))
if val is None:
return None
return int(math.ceil(val))
if builtin == "ABS":
val = _to_numeric(evaluate_expression(node.arg, solution))
if val is None:
return None
return abs(val)
if builtin == "ROUND":
val = _to_numeric(evaluate_expression(node.arg, solution))
if val is None:
return None
return round(val)
if builtin == "STRBEFORE":
string = _to_string(evaluate_expression(node.arg1, solution))
sep = _to_string(evaluate_expression(node.arg2, solution))
idx = string.find(sep)
if idx < 0:
return Term(type=LITERAL, value="")
return Term(type=LITERAL, value=string[:idx])
if builtin == "STRAFTER":
string = _to_string(evaluate_expression(node.arg1, solution))
sep = _to_string(evaluate_expression(node.arg2, solution))
idx = string.find(sep)
if idx < 0:
return Term(type=LITERAL, value="")
return Term(type=LITERAL, value=string[idx + len(sep):])
if builtin == "ENCODE_FOR_URI":
val = _to_string(evaluate_expression(node.arg, solution))
return Term(type=LITERAL, value=quote(val, safe=""))
if builtin == "REPLACE":
string = _to_string(evaluate_expression(node.arg, solution))
pattern = _to_string(evaluate_expression(node.pattern, solution))
replacement = _to_string(
evaluate_expression(node.replacement, solution)
)
flags_str = ""
if hasattr(node, "flags") and node.flags is not None:
flags_str = _to_string(evaluate_expression(node.flags, solution))
re_flags = 0
if "i" in flags_str:
re_flags |= re.IGNORECASE
if "m" in flags_str:
re_flags |= re.MULTILINE
if "s" in flags_str:
re_flags |= re.DOTALL
try:
result = re.sub(pattern, replacement, string, flags=re_flags)
return Term(type=LITERAL, value=result)
except re.error:
return None
if builtin == "SUBSTR":
string = _to_string(evaluate_expression(node.arg, solution))
start = _to_numeric(evaluate_expression(node.start, solution))
if start is None:
return None
start_idx = max(int(start) - 1, 0)
if hasattr(node, "length") and node.length is not None:
length = _to_numeric(evaluate_expression(node.length, solution))
if length is None:
return None
return Term(
type=LITERAL, value=string[start_idx:start_idx + int(length)]
)
return Term(type=LITERAL, value=string[start_idx:])
if builtin == "EXISTS":
if _exists_callback is not None:
return _exists_callback(node.graph, solution)
logger.warning("EXISTS requires an exists_cb; not available")
return True
if builtin == "NOTEXISTS":
if _exists_callback is not None:
return not _exists_callback(node.graph, solution)
logger.warning("NOT EXISTS requires an exists_cb; not available")
return True
if builtin == "LANGMATCHES":
tag = _to_string(evaluate_expression(node.arg1, solution))
rng = _to_string(evaluate_expression(node.arg2, solution))
if rng == "*":
return len(tag) > 0
return tag.lower().startswith(rng.lower())
if builtin == "IRI" or builtin == "URI":
val = _to_string(evaluate_expression(node.arg, solution))
return Term(type=IRI, iri=val)
if builtin == "BNODE":
if hasattr(node, "arg") and node.arg is not None:
label = _to_string(evaluate_expression(node.arg, solution))
return Term(type=BLANK, id=label)
return Term(type=BLANK, id=str(uuid.uuid4()))
if builtin == "NOW":
now = datetime.now(timezone.utc)
return Term(
type=LITERAL,
value=now.strftime("%Y-%m-%dT%H:%M:%S%z"),
datatype="http://www.w3.org/2001/XMLSchema#dateTime",
)
if builtin == "TZ":
dt = _to_datetime(evaluate_expression(node.arg, solution))
if dt is None:
return Term(type=LITERAL, value="")
if dt.tzinfo is not None:
offset = dt.strftime("%z")
if offset:
return Term(type=LITERAL, value=offset[:3] + ":" + offset[3:])
return Term(type=LITERAL, value="")
if builtin == "RAND":
return random.random()
if builtin == "UUID":
return Term(type=IRI, iri="urn:uuid:" + str(uuid.uuid4()))
if builtin == "STRUUID":
return Term(type=LITERAL, value=str(uuid.uuid4()))
if builtin == "MD5":
val = _to_string(evaluate_expression(node.arg, solution))
return Term(
type=LITERAL, value=hashlib.md5(val.encode()).hexdigest()
)
if builtin == "SHA1":
val = _to_string(evaluate_expression(node.arg, solution))
return Term(
type=LITERAL, value=hashlib.sha1(val.encode()).hexdigest()
)
if builtin == "SHA256":
val = _to_string(evaluate_expression(node.arg, solution))
return Term(
type=LITERAL, value=hashlib.sha256(val.encode()).hexdigest()
)
if builtin == "SHA512":
val = _to_string(evaluate_expression(node.arg, solution))
return Term(
type=LITERAL, value=hashlib.sha512(val.encode()).hexdigest()
)
if builtin == "sameTerm": if builtin == "sameTerm":
left = evaluate_expression(node.arg1, solution) left = evaluate_expression(node.arg1, solution)
right = evaluate_expression(node.arg2, solution) right = evaluate_expression(node.arg2, solution)
@ -454,6 +665,27 @@ def _to_numeric(val):
return None return None
def _to_datetime(val):
"""Convert a value to a date or datetime object."""
if val is None:
return None
s = _to_string(val)
for fmt in (
"%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%S%z",
"%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M",
"%Y-%m-%d",
):
try:
return datetime.strptime(s, fmt)
except ValueError:
continue
try:
return datetime.fromisoformat(s)
except ValueError:
pass
return None
def _comparable_value(val): def _comparable_value(val):
""" """
Convert a value to a form suitable for comparison. Convert a value to a form suitable for comparison.

View file

@ -12,7 +12,7 @@ from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import TriplesClientSpec from ... base import TriplesClientSpec
from . parser import parse_sparql, ParseError from . parser import parse_sparql, ParseError
from . algebra import evaluate, EvaluationError from . algebra import evaluate, materialise, EvaluationError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -66,11 +66,10 @@ class Processor(FlowProcessor):
logger.debug(f"Handling SPARQL query request {id}...") logger.debug(f"Handling SPARQL query request {id}...")
response = await self.execute_sparql(request, flow) if request.streaming:
await self.execute_sparql_streaming(request, flow, id)
if request.streaming and response.query_type == "select":
await self.send_streaming(response, flow, id, request)
else: else:
response = await self.execute_sparql(request, flow)
await flow("response").send( await flow("response").send(
response, properties={"id": id} response, properties={"id": id}
) )
@ -92,37 +91,77 @@ class Processor(FlowProcessor):
await flow("response").send(r, properties={"id": id}) await flow("response").send(r, properties={"id": id})
async def send_streaming(self, response, flow, id, request): async def execute_sparql_streaming(self, request, flow, id):
"""Send SELECT results in batches.""" """Execute a SPARQL query and stream results as they arrive."""
bindings = response.bindings try:
parsed = parse_sparql(request.query)
except ParseError as e:
await flow("response").send(
SparqlQueryResponse(
error=Error(
type="sparql-parse-error",
message=str(e),
),
),
properties={"id": id}
)
return
if parsed.query_type != "select":
response = await self._execute_non_select(parsed, request, flow)
await flow("response").send(response, properties={"id": id})
return
triples_client = flow("triples-request")
variables = parsed.variables
batch_size = request.batch_size if request.batch_size > 0 else 20 batch_size = request.batch_size if request.batch_size > 0 else 20
batch = []
for i in range(0, len(bindings), batch_size): try:
batch = bindings[i:i + batch_size] async for sol in evaluate(
is_final = (i + batch_size >= len(bindings)) parsed.algebra,
r = SparqlQueryResponse( triples_client,
query_type=response.query_type, collection=request.collection or "default",
variables=response.variables, limit=request.limit or 10000,
bindings=batch, ):
is_final=is_final, values = [sol.get(v) for v in variables]
) batch.append(SparqlBinding(values=values))
await flow("response").send(r, properties={"id": id})
# Handle empty results if len(batch) >= batch_size:
if len(bindings) == 0: r = SparqlQueryResponse(
r = SparqlQueryResponse( query_type="select",
query_type=response.query_type, variables=variables,
variables=response.variables, bindings=batch,
bindings=[], is_final=False,
is_final=True, )
await flow("response").send(r, properties={"id": id})
batch = []
except EvaluationError as e:
await flow("response").send(
SparqlQueryResponse(
error=Error(
type="sparql-evaluation-error",
message=str(e),
),
),
properties={"id": id}
) )
await flow("response").send(r, properties={"id": id}) return
# Final batch (may be empty for zero results)
r = SparqlQueryResponse(
query_type="select",
variables=variables,
bindings=batch,
is_final=True,
)
await flow("response").send(r, properties={"id": id})
async def execute_sparql(self, request, flow): async def execute_sparql(self, request, flow):
"""Parse and evaluate a SPARQL query.""" """Parse and evaluate a SPARQL query (non-streaming)."""
# Parse the SPARQL query
try: try:
parsed = parse_sparql(request.query) parsed = parse_sparql(request.query)
except ParseError as e: except ParseError as e:
@ -133,12 +172,31 @@ class Processor(FlowProcessor):
), ),
) )
# Get the triples client from the flow if parsed.query_type == "select":
triples_client = flow("triples-request") triples_client = flow("triples-request")
try:
solutions = await materialise(
parsed.algebra,
triples_client,
collection=request.collection or "default",
limit=request.limit or 10000,
)
except EvaluationError as e:
return SparqlQueryResponse(
error=Error(
type="sparql-evaluation-error",
message=str(e),
),
)
return self._build_select_response(parsed, solutions)
# Evaluate the algebra return await self._execute_non_select(parsed, request, flow)
async def _execute_non_select(self, parsed, request, flow):
"""Execute ASK, CONSTRUCT, or DESCRIBE queries."""
triples_client = flow("triples-request")
try: try:
solutions = await evaluate( solutions = await materialise(
parsed.algebra, parsed.algebra,
triples_client, triples_client,
collection=request.collection or "default", collection=request.collection or "default",
@ -152,10 +210,7 @@ class Processor(FlowProcessor):
), ),
) )
# Build response based on query type if parsed.query_type == "ask":
if parsed.query_type == "select":
return self._build_select_response(parsed, solutions)
elif parsed.query_type == "ask":
return self._build_ask_response(solutions) return self._build_ask_response(solutions)
elif parsed.query_type == "construct": elif parsed.query_type == "construct":
return self._build_construct_response(parsed, solutions) return self._build_construct_response(parsed, solutions)

View file

@ -150,6 +150,30 @@ def left_join(left, right, filter_fn=None):
return results return results
def minus(left, right):
"""
MINUS operation: remove left solutions that are compatible with any
right solution sharing at least one variable.
"""
if not right:
return list(left)
right_vars = set()
for sol in right:
right_vars.update(sol.keys())
results = []
for sol_l in left:
shared = set(sol_l.keys()) & right_vars
if not shared:
results.append(sol_l)
continue
if not any(_compatible(sol_l, sol_r) for sol_r in right):
results.append(sol_l)
return results
def union(left, right): def union(left, right):
"""Union two solution sequences (concatenation).""" """Union two solution sequences (concatenation)."""
return list(left) + list(right) return list(left) + list(right)
@ -177,6 +201,28 @@ def distinct(solutions):
return results return results
def _sort_comparable(val):
"""Convert a value to a form suitable for sort ordering."""
if val is None:
return (0, "")
if isinstance(val, (int, float)):
return (2, val)
if isinstance(val, Term):
if val.type == LITERAL:
try:
if "." in val.value:
return (2, float(val.value))
return (2, int(val.value))
except (ValueError, TypeError):
pass
return (3, val.value)
elif val.type == IRI:
return (4, val.iri)
elif val.type == BLANK:
return (5, val.id)
return (6, str(val))
def order_by(solutions, key_fns): def order_by(solutions, key_fns):
""" """
Sort solutions by the given key functions. Sort solutions by the given key functions.
@ -191,14 +237,7 @@ def order_by(solutions, key_fns):
keys = [] keys = []
for fn, ascending in key_fns: for fn, ascending in key_fns:
val = fn(sol) val = fn(sol)
# Convert to comparable form keys.append(_sort_comparable(val))
if val is None:
comparable = ("", "")
elif isinstance(val, Term):
comparable = _term_key(val)
else:
comparable = ("v", str(val))
keys.append(comparable)
return keys return keys
# Handle ascending/descending # Handle ascending/descending
@ -224,10 +263,8 @@ def _mixed_sort(solutions, key_fns):
def compare(a, b): def compare(a, b):
for fn, ascending in key_fns: for fn, ascending in key_fns:
va = fn(a) ka = _sort_comparable(fn(a))
vb = fn(b) kb = _sort_comparable(fn(b))
ka = _term_key(va) if isinstance(va, Term) else ("v", str(va)) if va is not None else ("", "")
kb = _term_key(vb) if isinstance(vb, Term) else ("v", str(vb)) if vb is not None else ("", "")
if ka < kb: if ka < kb:
return -1 if ascending else 1 return -1 if ascending else 1

View file

@ -1,130 +1,140 @@
import asyncio import asyncio
import logging import logging
import uuid import uuid
from typing import Dict, Any, Optional from typing import Dict, Any, Optional, Callable, Awaitable
from trustgraph.messaging import TranslatorRegistry
from ..gateway.dispatch.manager import DispatcherManager from ..gateway.dispatch.manager import DispatcherManager
logger = logging.getLogger("dispatcher") logger = logging.getLogger("dispatcher")
logger.setLevel(logging.INFO)
class WebSocketResponder:
"""Simple responder that captures response for websocket return""" class _TokenShim:
def __init__(self): def __init__(self, token):
self.response = None self.headers = (
self.completed = False {"Authorization": f"Bearer {token}"} if token else {}
)
async def send(self, data):
"""Capture the response data"""
self.response = data
self.completed = True
async def __call__(self, data, final=False):
"""Make the responder callable for compatibility with requestor"""
await self.send(data)
if final:
self.completed = True
class MessageDispatcher: class MessageDispatcher:
def __init__(self, max_workers: int = 10, config_receiver=None, backend=None): def __init__(self, max_workers=10, config_receiver=None, backend=None,
auth=None, timeout=120):
self.max_workers = max_workers self.max_workers = max_workers
self.semaphore = asyncio.Semaphore(max_workers) self.semaphore = asyncio.Semaphore(max_workers)
self.active_tasks = set() self.active_tasks = set()
self.backend = backend self.backend = backend
self.auth = auth
# Use DispatcherManager for flow and service management if backend and config_receiver and auth:
if backend and config_receiver: self.dispatcher_manager = DispatcherManager(
self.dispatcher_manager = DispatcherManager(backend, config_receiver, prefix="rev-gateway") backend, config_receiver,
auth=auth,
prefix="rev-gateway",
timeout=timeout,
)
else: else:
self.dispatcher_manager = None self.dispatcher_manager = None
logger.warning("No backend or config_receiver provided - using fallback mode") logger.warning(
"Missing backend, config_receiver, or auth "
# Service name mapping from websocket protocol to translator registry "— using fallback mode"
)
self.service_mapping = { self.service_mapping = {
"text-completion": "text-completion", "text-completion": "text-completion",
"graph-rag": "graph-rag", "graph-rag": "graph-rag",
"agent": "agent", "agent": "agent",
"embeddings": "embeddings", "embeddings": "embeddings",
"graph-embeddings": "graph-embeddings", "graph-embeddings": "graph-embeddings",
"triples": "triples", "triples": "triples",
"document-load": "document", "document-load": "document",
"text-load": "text-document", "text-load": "text-document",
"flow": "flow", "flow": "flow",
"knowledge": "knowledge", "knowledge": "knowledge",
"config": "config", "config": "config",
"librarian": "librarian", "librarian": "librarian",
"document-rag": "document-rag" "document-rag": "document-rag",
} }
async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]: async def handle_message(
self, message: Dict[Any, Any],
sender: Callable[[dict], Awaitable[None]],
):
async with self.semaphore: async with self.semaphore:
task = asyncio.create_task(self._process_message(message)) task = asyncio.create_task(
self._process_message(message, sender)
)
self.active_tasks.add(task) self.active_tasks.add(task)
try: try:
result = await task await task
return result
finally: finally:
self.active_tasks.discard(task) self.active_tasks.discard(task)
async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]: async def _authenticate(self, token):
if not self.auth:
raise RuntimeError("Auth not configured")
return await self.auth.authenticate(_TokenShim(token))
async def _process_message(
self, message: Dict[Any, Any],
sender: Callable[[dict], Awaitable[None]],
):
request_id = message.get('id', str(uuid.uuid4())) request_id = message.get('id', str(uuid.uuid4()))
service = message.get('service') service = message.get('service')
request_data = message.get('request', {}) request_data = message.get('request', {})
flow_id = message.get('flow', 'default') # Default flow token = message.get('token', '')
flow_id = message.get('flow', 'default')
logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}")
logger.info(
f"Processing message {request_id} for service "
f"{service} on flow {flow_id}"
)
try: try:
if not self.dispatcher_manager: if not self.dispatcher_manager:
raise RuntimeError("DispatcherManager not available - backend and config_receiver required") raise RuntimeError(
"DispatcherManager not available"
# Use DispatcherManager for flow-based processing )
responder = WebSocketResponder()
identity = await self._authenticate(token)
# Map websocket service name to dispatcher service name workspace = identity.workspace
async def responder(resp, fin):
await sender({
"id": request_id,
"response": resp,
"complete": fin,
})
dispatcher_service = self.service_mapping.get(service, service) dispatcher_service = self.service_mapping.get(service, service)
# Check if this is a global service or flow service
from ..gateway.dispatch.manager import global_dispatchers from ..gateway.dispatch.manager import global_dispatchers
if dispatcher_service in global_dispatchers: if dispatcher_service in global_dispatchers:
# Use global service dispatcher
await self.dispatcher_manager.invoke_global_service( await self.dispatcher_manager.invoke_global_service(
request_data, responder, dispatcher_service request_data, responder, dispatcher_service,
workspace=workspace,
) )
else: else:
# Use DispatcherManager to process the request through Pulsar queues
await self.dispatcher_manager.invoke_flow_service( await self.dispatcher_manager.invoke_flow_service(
request_data, responder, flow_id, dispatcher_service request_data, responder, workspace, flow_id,
dispatcher_service,
) )
# Get the response from the responder
if responder.completed:
response_data = responder.response
else:
response_data = {'error': 'No response received'}
response = {
'id': request_id,
'response': response_data
}
except Exception as e: except Exception as e:
logger.error(f"Error processing message {request_id}: {e}") logger.error(f"Error processing message {request_id}: {e}")
response = { await sender({
'id': request_id, "id": request_id,
'response': {'error': str(e)} "error": {"message": str(e), "type": "error"},
} "complete": True,
})
logger.info(f"Completed processing message {request_id}") logger.info(f"Completed processing message {request_id}")
return response
async def shutdown(self): async def shutdown(self):
if self.active_tasks: if self.active_tasks:
logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete") logger.info(
f"Waiting for {len(self.active_tasks)} active "
f"tasks to complete"
)
await asyncio.gather(*self.active_tasks, return_exceptions=True) await asyncio.gather(*self.active_tasks, return_exceptions=True)
# DispatcherManager handles its own cleanup
logger.info("Dispatcher shutdown complete") logger.info("Dispatcher shutdown complete")

View file

@ -1,3 +1,9 @@
"""
Reverse gateway. Initiates outbound WebSocket connections to a remote
relay and dispatches incoming requests through the same DispatcherManager
pipeline as api-gateway.
"""
import asyncio import asyncio
import argparse import argparse
import logging import logging
@ -9,82 +15,86 @@ from typing import Optional
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
from .dispatcher import MessageDispatcher from .dispatcher import MessageDispatcher
from ..gateway.auth import IamAuth
from ..gateway.config.receiver import ConfigReceiver from ..gateway.config.receiver import ConfigReceiver
from ..base import get_pubsub from ..base.pubsub import get_pubsub, add_pubsub_args
from ..base.logging import setup_logging, add_logging_args
logger = logging.getLogger("rev_gateway") logger = logging.getLogger("rev_gateway")
logger.setLevel(logging.INFO)
default_websocket = "ws://localhost:7650/out" default_websocket = "ws://localhost:7650/out"
default_timeout = 600
class ReverseGateway: class ReverseGateway:
def __init__(self, websocket_uri: str = None, max_workers: int = 10, def __init__(self, **config):
pulsar_host: str = None, pulsar_api_key: str = None, websocket_uri = config.get("websocket_uri")
pulsar_listener: str = None):
# Set default WebSocket URI with environment variable support
if websocket_uri is None: if websocket_uri is None:
websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket) websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket)
# Parse and validate the WebSocket URI
parsed_uri = urlparse(websocket_uri) parsed_uri = urlparse(websocket_uri)
if parsed_uri.scheme not in ('ws', 'wss'): if parsed_uri.scheme not in ('ws', 'wss'):
raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}") raise ValueError(
f"WebSocket URI must use ws:// or wss:// scheme, "
f"got: {parsed_uri.scheme}"
)
if not parsed_uri.netloc: if not parsed_uri.netloc:
raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}") raise ValueError(
f"WebSocket URI must include hostname, "
# Store parsed components for debugging/logging f"got: {websocket_uri}"
)
self.websocket_uri = websocket_uri self.websocket_uri = websocket_uri
self.host = parsed_uri.hostname self.host = parsed_uri.hostname
self.port = parsed_uri.port self.port = parsed_uri.port
self.scheme = parsed_uri.scheme self.scheme = parsed_uri.scheme
self.path = parsed_uri.path or "/ws" self.path = parsed_uri.path or "/ws"
# Construct the full URL (in case path was missing)
if not parsed_uri.path: if not parsed_uri.path:
self.url = f"{self.scheme}://{parsed_uri.netloc}/ws" self.url = f"{self.scheme}://{parsed_uri.netloc}/ws"
else: else:
self.url = websocket_uri self.url = websocket_uri
self.max_workers = max_workers self.max_workers = int(config.get("max_workers", 10))
self.timeout = int(config.get("timeout", default_timeout))
self.ws: Optional[ClientWebSocketResponse] = None self.ws: Optional[ClientWebSocketResponse] = None
self.session: Optional[ClientSession] = None self.session: Optional[ClientSession] = None
self.running = False self.running = False
self.reconnect_delay = 3.0 self.reconnect_delay = 3.0
# Pulsar configuration
self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None)
self.pulsar_listener = pulsar_listener
# Create backend using factory self.backend = get_pubsub(**config)
backend_params = {
'pulsar_host': self.pulsar_host,
'pulsar_api_key': self.pulsar_api_key,
'pulsar_listener': self.pulsar_listener,
}
self.backend = get_pubsub(**backend_params)
# Initialize config receiver self.auth = IamAuth(
self.config_receiver = ConfigReceiver(self.backend) backend=self.backend,
id=config.get("id", "rev-gateway"),
)
self.config_receiver = ConfigReceiver(
self.backend, auth=self.auth,
)
self.dispatcher = MessageDispatcher(
self.max_workers, self.config_receiver, self.backend,
auth=self.auth, timeout=self.timeout,
)
# Initialize dispatcher with config_receiver and backend - must be created after config_receiver
self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.backend)
async def connect(self) -> bool: async def connect(self) -> bool:
try: try:
if self.session is None: if self.session is None:
self.session = ClientSession() self.session = ClientSession()
logger.info(f"Connecting to {self.url}") logger.info(f"Connecting to {self.url}")
self.ws = await self.session.ws_connect(self.url) self.ws = await self.session.ws_connect(self.url)
logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}") logger.info(
f"WebSocket connection established to "
f"{self.host}:{self.port or 'default'}"
)
return True return True
except Exception as e: except Exception as e:
logger.error(f"Failed to connect to {self.url}: {e}") logger.error(f"Failed to connect to {self.url}: {e}")
return False return False
async def disconnect(self): async def disconnect(self):
if self.ws and not self.ws.closed: if self.ws and not self.ws.closed:
await self.ws.close() await self.ws.close()
@ -92,32 +102,31 @@ class ReverseGateway:
await self.session.close() await self.session.close()
self.ws = None self.ws = None
self.session = None self.session = None
async def send_message(self, message: dict): async def send_message(self, message: dict):
if self.ws and not self.ws.closed: if self.ws and not self.ws.closed:
try: try:
await self.ws.send_str(json.dumps(message)) await self.ws.send_str(json.dumps(message))
except Exception as e: except Exception as e:
logger.error(f"Failed to send message: {e}") logger.error(f"Failed to send message: {e}")
async def handle_message(self, message: str): async def handle_message(self, message: str):
try: try:
logger.debug(f"Received message: {message}") logger.debug(f"Received message: {message}")
msg_data = json.loads(message) msg_data = json.loads(message)
response = await self.dispatcher.handle_message(msg_data) await self.dispatcher.handle_message(
msg_data, self.send_message,
if response: )
await self.send_message(response)
except Exception as e: except Exception as e:
logger.error(f"Error handling message: {e}") logger.error(f"Error handling message: {e}")
async def listen(self): async def listen(self):
while self.running and self.ws and not self.ws.closed: while self.running and self.ws and not self.ws.closed:
try: try:
msg = await self.ws.receive() msg = await self.ws.receive()
if msg.type == WSMsgType.TEXT: if msg.type == WSMsgType.TEXT:
await self.handle_message(msg.data) await self.handle_message(msg.data)
elif msg.type == WSMsgType.BINARY: elif msg.type == WSMsgType.BINARY:
@ -125,31 +134,33 @@ class ReverseGateway:
elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR): elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
logger.warning("WebSocket closed or error occurred") logger.warning("WebSocket closed or error occurred")
break break
except Exception as e: except Exception as e:
logger.error(f"Error in listen loop: {e}") logger.error(f"Error in listen loop: {e}")
break break
async def run(self): async def run(self):
self.running = True self.running = True
logger.info("Starting reverse gateway") logger.info("Starting reverse gateway")
# Start config receiver await self.auth.start()
logger.info("Starting config receiver")
await self.config_receiver.start() await self.config_receiver.start()
while self.running: while self.running:
try: try:
if await self.connect(): if await self.connect():
await self.listen() await self.listen()
else: else:
logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds") logger.warning(
f"Connection failed, retrying in "
f"{self.reconnect_delay} seconds"
)
await self.disconnect() await self.disconnect()
if self.running: if self.running:
await asyncio.sleep(self.reconnect_delay) await asyncio.sleep(self.reconnect_delay)
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Shutdown requested") logger.info("Shutdown requested")
break break
@ -157,77 +168,69 @@ class ReverseGateway:
logger.error(f"Unexpected error: {e}") logger.error(f"Unexpected error: {e}")
if self.running: if self.running:
await asyncio.sleep(self.reconnect_delay) await asyncio.sleep(self.reconnect_delay)
await self.shutdown() await self.shutdown()
async def shutdown(self): async def shutdown(self):
logger.info("Shutting down reverse gateway") logger.info("Shutting down reverse gateway")
self.running = False self.running = False
await self.dispatcher.shutdown() await self.dispatcher.shutdown()
await self.disconnect() await self.disconnect()
# Close backend
if hasattr(self, 'backend'): if hasattr(self, 'backend'):
self.backend.close() self.backend.close()
def stop(self): def stop(self):
self.running = False self.running = False
def parse_args():
def run():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog="reverse-gateway", prog="reverse-gateway",
description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge" description=__doc__,
) )
parser.add_argument(
'--id',
default='rev-gateway',
help='Service identifier (default: rev-gateway)',
)
parser.add_argument( parser.add_argument(
'--websocket-uri', '--websocket-uri',
default=None, default=None,
help=f'WebSocket URI to connect to (default: {default_websocket} or WEBSOCKET_URI env var)' help=f'WebSocket URI to connect to (default: {default_websocket})',
) )
parser.add_argument( parser.add_argument(
'--max-workers', '--max-workers',
type=int, type=int,
default=10, default=10,
help='Maximum concurrent message handlers (default: 10)' help='Maximum concurrent message handlers (default: 10)',
) )
parser.add_argument(
'-p', '--pulsar-host',
default=None,
help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)'
)
parser.add_argument(
'--pulsar-api-key',
default=None,
help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)'
)
parser.add_argument(
'--pulsar-listener',
default=None,
help='Pulsar listener name'
)
return parser.parse_args()
def run(): parser.add_argument(
args = parse_args() '--timeout',
type=int,
gateway = ReverseGateway( default=default_timeout,
websocket_uri=args.websocket_uri, help=f'Request timeout in seconds (default: {default_timeout})',
max_workers=args.max_workers,
pulsar_host=args.pulsar_host,
pulsar_api_key=args.pulsar_api_key,
pulsar_listener=args.pulsar_listener
) )
add_pubsub_args(parser)
add_logging_args(parser)
args = parser.parse_args()
args = vars(args)
setup_logging(args)
gateway = ReverseGateway(**args)
logger.info(f"Starting reverse gateway:") logger.info(f"Starting reverse gateway:")
logger.info(f" WebSocket URI: {gateway.url}") logger.info(f" WebSocket URI: {gateway.url}")
logger.info(f" Max workers: {args.max_workers}") logger.info(f" Max workers: {gateway.max_workers}")
logger.info(f" Pulsar host: {gateway.pulsar_host}")
try: try:
asyncio.run(gateway.run()) asyncio.run(gateway.run())
except KeyboardInterrupt: except KeyboardInterrupt: