Merge pull request #953 from trustgraph-ai/release/v2.5

release/v2.5 -> master
This commit is contained in:
cybermaggedon 2026-05-26 15:01:44 +01:00 committed by GitHub
commit 36eadbda3a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 2632 additions and 1140 deletions

View file

@ -25,7 +25,7 @@ BUCKET_URL = "https://storage.googleapis.com/trustgraph-library"
INDEX_URL = f"{BUCKET_URL}/index.json"
default_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/")
default_user = "trustgraph"
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
@ -113,7 +113,7 @@ def convert_metadata(metadata_json):
return triples
def load_document(api, user, doc_entry):
def load_document(api, doc_entry):
"""Fetch metadata and content for a document, then load into TrustGraph."""
doc_id = doc_entry["id"]
title = doc_entry["title"]
@ -133,7 +133,6 @@ def load_document(api, user, doc_entry):
api.add_document(
id=doc["id"],
metadata=metadata,
user=user,
kind=doc["kind"],
title=doc["title"],
comments=doc["comments"],
@ -144,12 +143,12 @@ def load_document(api, user, doc_entry):
print(f" done.")
def load_documents(api, user, docs):
def load_documents(api, docs):
"""Load a list of documents."""
print(f"Loading {len(docs)} document(s)...\n")
for doc in docs:
try:
load_document(api, user, doc)
load_document(api, doc)
except Exception as e:
print(f" FAILED: {e}", file=sys.stderr)
print()
@ -166,8 +165,8 @@ def main():
help=f"TrustGraph API URL (default: {default_url})",
)
parser.add_argument(
"-U", "--user", default=default_user,
help=f"User ID (default: {default_user})",
"-w", "--workspace", default=default_workspace,
help=f"Workspace (default: {default_workspace})",
)
parser.add_argument(
"-t", "--token", default=default_token,
@ -212,22 +211,22 @@ def main():
return
# Load commands need the API
api = Api(args.url, token=args.token).library()
api = Api(args.url, token=args.token, workspace=args.workspace).library()
if args.command == "load-all":
load_documents(api, args.user, index)
load_documents(api, index)
elif args.command == "load-doc":
matches = [d for d in index if str(d.get("id")) == args.id]
if not matches:
print(f"No document with ID '{args.id}' found.", file=sys.stderr)
sys.exit(1)
load_documents(api, args.user, matches)
load_documents(api, matches)
elif args.command == "load-match":
results = search_index(index, args.query)
if results:
load_documents(api, args.user, results)
load_documents(api, results)
else:
print("No matches found.", file=sys.stderr)
sys.exit(1)

View file

@ -3,208 +3,278 @@
WebSocket Relay Test Harness
This script creates a relay server with two WebSocket endpoints:
- /in - for test clients to connect to
- /out - for reverse gateway to connect to
- /in - for test clients to connect to (speaks api-gateway protocol)
- /out - for reverse gateway to connect to (speaks rev-gateway protocol)
Messages are bidirectionally relayed between the two connections.
Clients on /in authenticate with a first-frame auth message:
{"type": "auth", "token": "..."}
The relay stores the token and injects it into each subsequent message
before forwarding to /out. Responses from /out are forwarded back to
the originating /in connection unchanged.
Usage:
python websocket_relay.py [--port PORT] [--host HOST]
"""
import asyncio
import json
import logging
import argparse
from aiohttp import web, WSMsgType
import weakref
from typing import Optional, Set
from typing import Dict, Optional
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("websocket_relay")
class InConnection:
    """State kept by the relay for one WebSocket client on the 'in' side.

    Holds the socket, the relay-assigned connection id, and the auth token
    received from the client (``None`` until one is stored).
    """

    def __init__(self, ws, conn_id):
        # A new connection starts out unauthenticated and without a token.
        self.authenticated = False
        self.token: Optional[str] = None
        self.conn_id = conn_id
        self.ws = ws
class WebSocketRelay:
"""WebSocket relay that forwards messages between 'in' and 'out' connections"""
def __init__(self):
self.in_connections: Set = weakref.WeakSet()
self.out_connections: Set = weakref.WeakSet()
self.in_connections: Dict[str, InConnection] = {}
self.out_connections: set = set()
self._conn_counter = 0
def _next_conn_id(self):
self._conn_counter += 1
return f"conn-{self._conn_counter}"
async def handle_in_connection(self, request):
"""Handle incoming connections on /in endpoint"""
ws = web.WebSocketResponse()
await ws.prepare(request)
self.in_connections.add(ws)
logger.info(f"New 'in' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")
conn_id = self._next_conn_id()
conn = InConnection(ws, conn_id)
self.in_connections[conn_id] = conn
logger.info(
f"New 'in' connection {conn_id}. "
f"Total in: {len(self.in_connections)}, "
f"out: {len(self.out_connections)}"
)
try:
async for msg in ws:
if msg.type == WSMsgType.TEXT:
data = msg.data
logger.info(f"IN → OUT: {data}")
await self._forward_to_out(data)
elif msg.type == WSMsgType.BINARY:
data = msg.data
logger.info(f"IN → OUT: {len(data)} bytes (binary)")
await self._forward_to_out(data, binary=True)
await self._handle_in_message(conn, msg.data)
elif msg.type == WSMsgType.ERROR:
logger.error(f"WebSocket error on 'in' connection: {ws.exception()}")
logger.error(
f"WebSocket error on 'in' connection "
f"{conn_id}: {ws.exception()}"
)
break
else:
break
except Exception as e:
logger.error(f"Error in 'in' connection handler: {e}")
logger.error(
f"Error in 'in' connection {conn_id}: {e}"
)
finally:
logger.info(f"'in' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")
del self.in_connections[conn_id]
logger.info(
f"'in' connection {conn_id} closed. "
f"Remaining in: {len(self.in_connections)}, "
f"out: {len(self.out_connections)}"
)
return ws
async def handle_out_connection(self, request):
"""Handle outgoing connections on /out endpoint"""
ws = web.WebSocketResponse()
await ws.prepare(request)
self.out_connections.add(ws)
logger.info(f"New 'out' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")
async def _handle_in_message(self, conn, data):
try:
async for msg in ws:
if msg.type == WSMsgType.TEXT:
data = msg.data
logger.info(f"OUT → IN: {data}")
await self._forward_to_in(data)
elif msg.type == WSMsgType.BINARY:
data = msg.data
logger.info(f"OUT → IN: {len(data)} bytes (binary)")
await self._forward_to_in(data, binary=True)
elif msg.type == WSMsgType.ERROR:
logger.error(f"WebSocket error on 'out' connection: {ws.exception()}")
break
else:
break
except Exception as e:
logger.error(f"Error in 'out' connection handler: {e}")
finally:
logger.info(f"'out' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")
return ws
async def _forward_to_out(self, data, binary=False):
"""Forward message from 'in' to all 'out' connections"""
if not self.out_connections:
logger.warning("No 'out' connections available to forward message")
message = json.loads(data)
except json.JSONDecodeError:
logger.warning(
f"{conn.conn_id}: received non-JSON message"
)
return
closed_connections = []
if isinstance(message, dict) and message.get("type") == "auth":
conn.token = message.get("token", "")
conn.authenticated = True
logger.info(f"{conn.conn_id}: authenticated")
await conn.ws.send_json({
"type": "auth-ok",
"workspace": "relayed",
})
return
if not conn.authenticated:
await conn.ws.send_json({
"error": {
"message": "auth required",
"type": "auth-required",
},
"complete": True,
})
return
message["token"] = conn.token
message["_relay_conn"] = conn.conn_id
forwarded = json.dumps(message)
logger.info(f"IN {conn.conn_id} → OUT: {forwarded}")
await self._forward_to_out(forwarded)
async def handle_out_connection(self, request):
    """Serve one WebSocket on the 'out' side until it closes.

    The socket stays in ``self.out_connections`` for as long as this handler
    runs. Every TEXT frame is handed to ``_handle_out_message``; an ERROR
    frame, or a frame of any other type, ends the read loop.

    Returns the ``WebSocketResponse`` for the request.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)
    self.out_connections.add(ws)

    logger.info(
        f"New 'out' connection. "
        f"Total in: {len(self.in_connections)}, "
        f"out: {len(self.out_connections)}"
    )

    try:
        async for frame in ws:
            if frame.type == WSMsgType.ERROR:
                logger.error(
                    f"WebSocket error on 'out' connection: "
                    f"{ws.exception()}"
                )
                break
            if frame.type != WSMsgType.TEXT:
                # Anything that is neither TEXT nor ERROR ends the session.
                break
            await self._handle_out_message(frame.data)
    except Exception as e:
        logger.error(f"Error in 'out' connection: {e}")
    finally:
        # Always deregister, whichever way the loop ended.
        self.out_connections.discard(ws)
        logger.info(
            f"'out' connection closed. "
            f"Remaining in: {len(self.in_connections)}, "
            f"out: {len(self.out_connections)}"
        )

    return ws
async def _handle_out_message(self, data):
    """Route one text frame from an 'out' connection back to 'in' clients.

    ``data`` is parsed as JSON. The ``_relay_conn`` key, when present, is
    removed and names the 'in' connection that receives the message; a
    message without it is broadcast to every 'in' connection.

    Frames that are not JSON, or whose JSON value is not an object, are
    dropped with a warning. A message addressed to a connection that is no
    longer registered is also dropped instead of being broadcast, so one
    client's reply is never delivered to the other clients.
    """
    try:
        message = json.loads(data)
    except json.JSONDecodeError:
        logger.warning("OUT: received non-JSON message")
        return

    # Only a JSON object can carry the routing key; .pop(key, default) on a
    # list or scalar would raise instead of returning the default.
    if not isinstance(message, dict):
        logger.warning("OUT: received JSON message that is not an object")
        return

    conn_id = message.pop("_relay_conn", None)
    forwarded = json.dumps(message)
    logger.info(f"OUT → IN {conn_id or 'broadcast'}: {forwarded}")

    if not conn_id:
        await self._broadcast_to_in(forwarded)
        return

    # Non-string ids can never match a registered connection (and may be
    # unhashable), so treat them as unknown.
    conn = (
        self.in_connections.get(conn_id)
        if isinstance(conn_id, str) else None
    )
    if conn is None:
        logger.warning(
            f"OUT: dropping message for unknown 'in' connection {conn_id}"
        )
        return

    try:
        if not conn.ws.closed:
            await conn.ws.send_str(forwarded)
    except Exception as e:
        logger.error(
            f"Error forwarding to 'in' {conn_id}: {e}"
        )
async def _broadcast_to_in(self, data):
closed = []
for conn_id, conn in list(self.in_connections.items()):
try:
if conn.ws.closed:
closed.append(conn_id)
continue
await conn.ws.send_str(data)
except Exception as e:
logger.error(
f"Error broadcasting to 'in' {conn_id}: {e}"
)
closed.append(conn_id)
for conn_id in closed:
self.in_connections.pop(conn_id, None)
async def _forward_to_out(self, data):
closed = []
for ws in list(self.out_connections):
try:
if ws.closed:
closed_connections.append(ws)
closed.append(ws)
continue
if binary:
await ws.send_bytes(data)
else:
await ws.send_str(data)
await ws.send_str(data)
except Exception as e:
logger.error(f"Error forwarding to 'out' connection: {e}")
closed_connections.append(ws)
# Clean up closed connections
for ws in closed_connections:
if ws in self.out_connections:
self.out_connections.discard(ws)
async def _forward_to_in(self, data, binary=False):
"""Forward message from 'out' to all 'in' connections"""
if not self.in_connections:
logger.warning("No 'in' connections available to forward message")
return
closed_connections = []
for ws in list(self.in_connections):
try:
if ws.closed:
closed_connections.append(ws)
continue
if binary:
await ws.send_bytes(data)
else:
await ws.send_str(data)
except Exception as e:
logger.error(f"Error forwarding to 'in' connection: {e}")
closed_connections.append(ws)
# Clean up closed connections
for ws in closed_connections:
if ws in self.in_connections:
self.in_connections.discard(ws)
logger.error(f"Error forwarding to 'out': {e}")
closed.append(ws)
for ws in closed:
self.out_connections.discard(ws)
async def create_app(relay):
"""Create the web application with routes"""
app = web.Application()
# Add routes
app.router.add_get('/in', relay.handle_in_connection)
app.router.add_get('/in/api/v1/socket', relay.handle_in_connection)
app.router.add_get('/out', relay.handle_out_connection)
# Add a simple status endpoint
async def status(request):
status_info = {
return web.json_response({
'in_connections': len(relay.in_connections),
'out_connections': len(relay.out_connections),
'status': 'running'
}
return web.json_response(status_info)
'status': 'running',
})
app.router.add_get('/status', status)
app.router.add_get('/', status) # Root also shows status
app.router.add_get('/', status)
return app
def main():
parser = argparse.ArgumentParser(
description="WebSocket Relay Test Harness"
)
parser.add_argument(
'--host',
'--host',
default='localhost',
help='Host to bind to (default: localhost)'
help='Host to bind to (default: localhost)',
)
parser.add_argument(
'--port',
type=int,
'--port',
type=int,
default=8080,
help='Port to bind to (default: 8080)'
help='Port to bind to (default: 8080)',
)
parser.add_argument(
'--verbose', '-v',
action='store_true',
help='Enable verbose logging'
help='Enable verbose logging',
)
args = parser.parse_args()
if args.verbose:
logging.getLogger().setLevel(logging.DEBUG)
relay = WebSocketRelay()
print(f"Starting WebSocket Relay on {args.host}:{args.port}")
print(f" 'in' endpoint: ws://{args.host}:{args.port}/in")
print(f" 'in' endpoint: ws://{args.host}:{args.port}/in/api/v1/socket")
print(f" 'out' endpoint: ws://{args.host}:{args.port}/out")
print(f" Status: http://{args.host}:{args.port}/status")
print()
print("Usage:")
print(f" Test client connects to: ws://{args.host}:{args.port}/in")
print(f" Reverse gateway connects to: ws://{args.host}:{args.port}/out")
print("Client protocol (same as api-gateway):")
print(' 1. Connect to /in/api/v1/socket')
print(' 2. Send: {"type": "auth", "token": "tg_..."}')
print(' 3. Receive: {"type": "auth-ok", "workspace": "relayed"}')
print(' 4. Send requests as normal')
web.run_app(create_app(relay), host=args.host, port=args.port)
if __name__ == "__main__":
main()
main()

View file

@ -25,16 +25,17 @@ class TestSemaphoreEnforcement:
max_concurrent = 0
processing_event = asyncio.Event()
async def slow_process(message):
async def slow_process(message, sender):
nonlocal concurrent_count, max_concurrent
concurrent_count += 1
max_concurrent = max(max_concurrent, concurrent_count)
await asyncio.sleep(0.05)
concurrent_count -= 1
return {"id": message.get("id"), "response": {"ok": True}}
dispatcher._process_message = slow_process
sender = AsyncMock()
# Launch more tasks than max_workers
messages = [
{"id": f"msg-{i}", "service": "test", "request": {}}
@ -42,7 +43,7 @@ class TestSemaphoreEnforcement:
]
tasks = [
asyncio.create_task(dispatcher.handle_message(m))
asyncio.create_task(dispatcher.handle_message(m, sender))
for m in messages
]
@ -66,17 +67,17 @@ class TestSemaphoreEnforcement:
original_process = dispatcher._process_message
async def tracking_process(message):
async def tracking_process(message, sender):
nonlocal task_was_tracked
# During processing, our task should be in active_tasks
if len(dispatcher.active_tasks) > 0:
task_was_tracked = True
return {"id": message.get("id"), "response": {"ok": True}}
dispatcher._process_message = tracking_process
await dispatcher.handle_message(
{"id": "test", "service": "test", "request": {}}
{"id": "test", "service": "test", "request": {}},
AsyncMock(),
)
assert task_was_tracked
@ -88,7 +89,7 @@ class TestSemaphoreEnforcement:
"""Semaphore should be released even if processing raises."""
dispatcher = MessageDispatcher(max_workers=2)
async def failing_process(message):
async def failing_process(message, sender):
raise RuntimeError("process failed")
dispatcher._process_message = failing_process
@ -96,7 +97,8 @@ class TestSemaphoreEnforcement:
# Should not deadlock — semaphore must be released on error
with pytest.raises(RuntimeError):
await dispatcher.handle_message(
{"id": "test", "service": "test", "request": {}}
{"id": "test", "service": "test", "request": {}},
AsyncMock(),
)
# Semaphore should be back at max
@ -109,17 +111,18 @@ class TestSemaphoreEnforcement:
order = []
async def ordered_process(message):
async def ordered_process(message, sender):
msg_id = message["id"]
order.append(f"start-{msg_id}")
await asyncio.sleep(0.02)
order.append(f"end-{msg_id}")
return {"id": msg_id, "response": {"ok": True}}
dispatcher._process_message = ordered_process
sender = AsyncMock()
messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)]
tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages]
tasks = [asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages]
await asyncio.gather(*tasks)
# With semaphore=1, each message should complete before next starts

View file

@ -0,0 +1,56 @@
"""
Tests for ontology monitoring metrics.
"""
from trustgraph.query.ontology.monitoring import (
PerformanceMonitor,
_extract_metric_label,
)
def test_extract_metric_label_reads_unquoted_label_value():
    """An unquoted ``cache_type=entity`` pair is read back as ``entity``."""
    label = _extract_metric_label(
        "cache_requests_total{cache_type=entity,component=ontology}",
        "cache_type",
    )

    assert label == "entity"
def test_extract_metric_label_reads_quoted_label_value():
    """A double-quoted label value is returned without its quotes."""
    label = _extract_metric_label(
        'cache_requests_total{cache_type="entity",component="ontology"}',
        "cache_type",
    )

    assert label == "entity"
def test_extract_metric_label_returns_none_when_label_missing():
    """Asking for a label the metric name does not carry yields ``None``."""
    label = _extract_metric_label(
        "cache_requests_total{component=ontology}",
        "cache_type",
    )

    assert label is None
def test_performance_report_ignores_counters_without_cache_type_label():
    """Only counters labelled with ``cache_type`` appear in the cache report."""
    monitor = PerformanceMonitor({"enabled": False})

    # The first two counters carry no cache_type label; the second one even
    # has "cache_type=" in its *name*, which must not be taken for a label.
    increments = (
        ("cache_requests_total", {"component": "ontology"}),
        ("cache_type=not_a_label", {"component": "ontology"}),
        ("cache_requests_total", {"cache_type": "entity"}),
        ("cache_hits_total", {"cache_type": "entity"}),
    )
    for counter_name, counter_labels in increments:
        monitor.metrics_collector.increment(
            counter_name,
            labels=counter_labels,
        )

    expected_cache_performance = {
        "entity": {
            "hit_rate": 1.0,
            "total_requests": 1.0,
            "total_hits": 1.0,
        }
    }
    report = monitor.get_performance_report()

    assert report["cache_performance"] == expected_cache_performance

View file

@ -14,7 +14,7 @@ from rdflib.plugins.sparql.parserutils import CompValue
from trustgraph.schema import Term, IRI, LITERAL
from trustgraph.query.sparql.algebra import (
evaluate, _query_pattern, _eval_bgp,
evaluate, materialise, _query_pattern, _eval_bgp,
)
@ -28,6 +28,32 @@ def lit(v):
return Term(type=LITERAL, value=v)
def make_tc(query_return=None, query_side_effect=None):
    """Build a mock TriplesClient exposing both ``query()`` and ``query_gen()``.

    When ``query_side_effect`` is given, ``query()`` uses it as its side
    effect and ``query_gen()`` awaits it and yields each item it returns.
    Otherwise both serve ``query_return`` (an empty list when it is falsy).
    """
    client = AsyncMock()

    if query_side_effect is None:
        canned = query_return or []
        client.query.return_value = canned

        async def stream_canned(**kwargs):
            for entry in canned:
                yield entry

        client.query_gen = stream_canned
        return client

    client.query.side_effect = query_side_effect

    async def stream_side_effect(**kwargs):
        fetched = await query_side_effect(**kwargs)
        for entry in fetched:
            yield entry

    client.query_gen = stream_side_effect
    return client
def make_triple(s, p, o):
t = MagicMock()
t.s = s
@ -84,6 +110,20 @@ def make_distinct(inner):
return node
def make_filter(inner, expr):
    """Build a ``Filter`` node with ``inner`` as its pattern and ``expr`` as its expression."""
    filter_node = CompValue("Filter")
    filter_node.p = inner
    filter_node.expr = expr
    return filter_node
def make_minus(left, right):
    """Build a ``Minus`` node with ``left`` as ``p1`` and ``right`` as ``p2``."""
    minus_node = CompValue("Minus")
    minus_node.p1 = left
    minus_node.p2 = right
    return minus_node
class TestQueryPattern:
"""Tests for _query_pattern — the leaf that calls TriplesClient."""
@ -136,15 +176,14 @@ class TestEvalBgp:
@pytest.mark.asyncio
async def test_single_pattern_all_variables(self):
tc = AsyncMock()
triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
tc.query.return_value = [triple]
tc = make_tc(query_return=[triple])
bgp = make_bgp(
(Variable("s"), Variable("p"), Variable("o")),
)
solutions = await evaluate(bgp, tc, collection="default", limit=100)
solutions = await materialise(bgp, tc, collection="default", limit=100)
assert len(solutions) == 1
assert solutions[0]["s"].iri == "http://s"
@ -153,43 +192,37 @@ class TestEvalBgp:
@pytest.mark.asyncio
async def test_single_pattern_bound_subject(self):
tc = AsyncMock()
tc.query.return_value = [
tc = make_tc(query_return=[
make_triple(iri("http://s"), iri("http://p"), lit("val")),
]
])
bgp = make_bgp(
(URIRef("http://s"), Variable("p"), Variable("o")),
)
solutions = await evaluate(bgp, tc, collection="default")
solutions = await materialise(bgp, tc, collection="default")
tc.query.assert_called_once()
kwargs = tc.query.call_args.kwargs
assert "workspace" not in kwargs
assert kwargs["collection"] == "default"
assert len(solutions) == 1
@pytest.mark.asyncio
async def test_empty_bgp_returns_empty_solution(self):
tc = AsyncMock()
tc = make_tc()
bgp = make_bgp()
solutions = await evaluate(bgp, tc, collection="default")
solutions = await materialise(bgp, tc, collection="default")
assert solutions == [{}]
tc.query.assert_not_called()
@pytest.mark.asyncio
async def test_no_results_returns_empty(self):
tc = AsyncMock()
tc.query.return_value = []
tc = make_tc(query_return=[])
bgp = make_bgp(
(Variable("s"), Variable("p"), Variable("o")),
)
solutions = await evaluate(bgp, tc, collection="default")
solutions = await materialise(bgp, tc, collection="default")
assert solutions == []
@ -199,17 +232,16 @@ class TestEvaluate:
@pytest.mark.asyncio
async def test_select_query_node(self):
tc = AsyncMock()
tc.query.return_value = [
tc = make_tc(query_return=[
make_triple(iri("http://s"), iri("http://p"), lit("o")),
]
])
bgp = make_bgp(
(Variable("s"), Variable("p"), Variable("o")),
)
select = make_select(make_project(bgp, ["s", "p"]))
solutions = await evaluate(select, tc, collection="default")
solutions = await materialise(select, tc, collection="default")
assert len(solutions) == 1
assert "s" in solutions[0]
@ -220,10 +252,9 @@ class TestEvaluate:
async def test_workspace_never_in_query_calls(self):
"""Verify that no matter the algebra structure, workspace is never
passed to TriplesClient.query()."""
tc = AsyncMock()
tc.query.return_value = [
tc = make_tc(query_return=[
make_triple(iri("http://s"), iri("http://p"), lit("o")),
]
])
bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o")))
bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c")))
@ -231,72 +262,319 @@ class TestEvaluate:
make_union(bgp1, bgp2), ["s", "p", "o"]
))
await evaluate(tree, tc, collection="test-coll")
for c in tc.query.call_args_list:
assert "workspace" not in c.kwargs
await materialise(tree, tc, collection="test-coll")
@pytest.mark.asyncio
async def test_join(self):
tc = AsyncMock()
tc.query.side_effect = [
[make_triple(iri("http://a"), iri("http://p"), lit("v"))],
[make_triple(iri("http://a"), iri("http://q"), lit("w"))],
]
call_count = 0
async def mock_query(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return [make_triple(iri("http://a"), iri("http://p"), lit("v"))]
else:
return [make_triple(iri("http://a"), iri("http://q"), lit("w"))]
tc = make_tc(query_side_effect=mock_query)
bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1")))
bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2")))
tree = make_join(bgp1, bgp2)
solutions = await evaluate(tree, tc, collection="default")
solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 1
assert solutions[0]["s"].iri == "http://a"
@pytest.mark.asyncio
async def test_slice(self):
tc = AsyncMock()
triples = [
make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}"))
for i in range(5)
]
tc.query.return_value = triples
tc = make_tc(query_return=triples)
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
tree = make_slice(bgp, start=1, length=2)
solutions = await evaluate(tree, tc, collection="default")
solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 2
@pytest.mark.asyncio
async def test_distinct(self):
tc = AsyncMock()
triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
tc.query.return_value = [triple, triple]
tc = make_tc(query_return=[triple, triple])
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
tree = make_distinct(bgp)
solutions = await evaluate(tree, tc, collection="default")
solutions = await materialise(tree, tc, collection="default")
assert len(solutions) == 1
@pytest.mark.asyncio
async def test_minus_removes_matching(self):
    """MINUS leaves no solutions when both sides bind ``?s`` to the same term."""
    alice = iri("http://example.com/alice")
    bob = iri("http://example.com/bob")
    knows = iri("http://example.com/knows")
    hates = iri("http://example.com/hates")
    charlie = iri("http://example.com/charlie")

    knows_triple = make_triple(alice, knows, bob)
    hates_triple = make_triple(alice, hates, charlie)

    async def mock_query(**kwargs):
        # Answer by predicate: one triple each for "knows" and "hates".
        predicate = kwargs.get("p")
        if not predicate:
            return []
        if predicate.iri == "http://example.com/knows":
            return [knows_triple]
        if predicate.iri == "http://example.com/hates":
            return [hates_triple]
        return []

    tc = make_tc(query_side_effect=mock_query)

    kept_side = make_bgp(
        (Variable("s"), URIRef("http://example.com/knows"), Variable("o"))
    )
    removed_side = make_bgp(
        (Variable("s"), URIRef("http://example.com/hates"), Variable("r"))
    )
    minus = make_minus(kept_side, removed_side)
    tree = make_select(make_project(minus, ["s", "o"]))

    solutions = await materialise(tree, tc, collection="default")

    assert len(solutions) == 0
@pytest.mark.asyncio
async def test_minus_no_shared_vars_preserves_all(self):
    """MINUS keeps every left solution when the two sides share no variable."""
    alice = iri("http://example.com/alice")
    bob = iri("http://example.com/bob")
    only_triple = make_triple(alice, iri("http://example.com/p"), bob)

    async def mock_query(**kwargs):
        # Only the left-hand predicate has data; everything else is empty.
        predicate = kwargs.get("p")
        if not predicate:
            return []
        if predicate.iri == "http://example.com/p":
            return [only_triple]
        return []

    tc = make_tc(query_side_effect=mock_query)

    kept_side = make_bgp(
        (Variable("s"), URIRef("http://example.com/p"), Variable("o"))
    )
    disjoint_side = make_bgp(
        (Variable("x"), URIRef("http://example.com/q"), Variable("y"))
    )
    minus = make_minus(kept_side, disjoint_side)
    tree = make_select(make_project(minus, ["s", "o"]))

    solutions = await materialise(tree, tc, collection="default")

    assert len(solutions) == 1
@pytest.mark.asyncio
async def test_filter_exists_keeps_matching(self):
    """FILTER EXISTS keeps bob (who has a "likes" triple) and drops charlie."""
    alice = iri("http://example.com/alice")
    bob = iri("http://example.com/bob")
    charlie = iri("http://example.com/charlie")

    knows_bob = make_triple(alice, iri("http://example.com/knows"), bob)
    knows_charlie = make_triple(alice, iri("http://example.com/knows"), charlie)
    bob_likes = make_triple(bob, iri("http://example.com/likes"), alice)

    async def mock_query(**kwargs):
        # Answer by predicate only, whatever subject or object was asked for.
        predicate = kwargs.get("p")
        if not predicate:
            return []
        if predicate.iri == "http://example.com/knows":
            return [knows_bob, knows_charlie]
        if predicate.iri == "http://example.com/likes":
            return [bob_likes]
        return []

    tc = make_tc(query_side_effect=mock_query)

    outer_bgp = make_bgp(
        (Variable("s"), URIRef("http://example.com/knows"), Variable("o"))
    )
    inner_bgp = make_bgp(
        (Variable("o"), URIRef("http://example.com/likes"), Variable("_any"))
    )

    exists_node = CompValue("Builtin_EXISTS")
    exists_node.graph = inner_bgp

    filtered = make_filter(outer_bgp, exists_node)
    tree = make_select(make_project(filtered, ["s", "o"]))

    solutions = await materialise(tree, tc, collection="default")

    object_iris = [solution["o"].iri for solution in solutions]
    assert "http://example.com/bob" in object_iris
    assert "http://example.com/charlie" not in object_iris
@pytest.mark.asyncio
async def test_filter_not_exists_removes_matching(self):
    """FILTER NOT EXISTS drops bob (who has a "likes" triple) and keeps charlie."""
    alice = iri("http://example.com/alice")
    bob = iri("http://example.com/bob")
    charlie = iri("http://example.com/charlie")

    knows_bob = make_triple(alice, iri("http://example.com/knows"), bob)
    knows_charlie = make_triple(alice, iri("http://example.com/knows"), charlie)
    bob_likes = make_triple(bob, iri("http://example.com/likes"), alice)

    async def mock_query(**kwargs):
        # Answer by predicate only, whatever subject or object was asked for.
        predicate = kwargs.get("p")
        if not predicate:
            return []
        if predicate.iri == "http://example.com/knows":
            return [knows_bob, knows_charlie]
        if predicate.iri == "http://example.com/likes":
            return [bob_likes]
        return []

    tc = make_tc(query_side_effect=mock_query)

    outer_bgp = make_bgp(
        (Variable("s"), URIRef("http://example.com/knows"), Variable("o"))
    )
    inner_bgp = make_bgp(
        (Variable("o"), URIRef("http://example.com/likes"), Variable("_any"))
    )

    not_exists_node = CompValue("Builtin_NOTEXISTS")
    not_exists_node.graph = inner_bgp

    filtered = make_filter(outer_bgp, not_exists_node)
    tree = make_select(make_project(filtered, ["s", "o"]))

    solutions = await materialise(tree, tc, collection="default")

    object_iris = [solution["o"].iri for solution in solutions]
    assert "http://example.com/charlie" in object_iris
    assert "http://example.com/bob" not in object_iris
@pytest.mark.asyncio
async def test_join_values_uses_bind_join(self):
    """Joining VALUES with a BGP must push the VALUES binding into the
    BGP's triple-store query, so that query is selective rather than a
    wildcard on the subject."""
    alice = iri("http://example.com/alice")
    bob = iri("http://example.com/bob")
    knows = iri("http://example.com/knows")

    queries_issued = []

    async def mock_query(**kwargs):
        # Record every call so the subject actually sent can be inspected.
        queries_issued.append(kwargs)
        subject, predicate = kwargs.get("s"), kwargs.get("p")
        if not (subject and subject.iri == "http://example.com/alice"):
            return []
        if not (predicate and predicate.iri == "http://example.com/knows"):
            return []
        return [make_triple(alice, knows, bob)]

    tc = make_tc(query_side_effect=mock_query)

    # VALUES ?s { <alice> }
    values = CompValue("values")
    values.var = [Variable("s")]
    values.value = [[URIRef("http://example.com/alice")]]
    values.res = None

    multiset = CompValue("ToMultiSet")
    multiset.p = values

    bgp = make_bgp(
        (Variable("s"), URIRef("http://example.com/knows"), Variable("o")),
    )
    tree = make_join(multiset, bgp)

    solutions = await materialise(tree, tc, collection="default")

    assert len(solutions) == 1
    only = solutions[0]
    assert only["s"].iri == "http://example.com/alice"
    assert only["o"].iri == "http://example.com/bob"

    # The key assertion: the BGP query must have been issued with
    # s=alice (bound from VALUES), not s=None (wildcard).
    assert len(queries_issued) == 1
    issued = queries_issued[0]
    assert issued["s"] is not None
    assert issued["s"].iri == "http://example.com/alice"
@pytest.mark.asyncio
async def test_join_values_multiple_bindings(self):
    """A bind join over two VALUES rows yields one solution per row."""
    alice = iri("http://example.com/alice")
    bob = iri("http://example.com/bob")
    knows = iri("http://example.com/knows")
    charlie = iri("http://example.com/charlie")

    async def mock_query(**kwargs):
        # Answer by subject: alice knows bob, bob knows charlie.
        subject = kwargs.get("s")
        if not subject:
            return []
        if subject.iri == "http://example.com/alice":
            return [make_triple(alice, knows, bob)]
        if subject.iri == "http://example.com/bob":
            return [make_triple(bob, knows, charlie)]
        return []

    tc = make_tc(query_side_effect=mock_query)

    # VALUES ?s { <alice> <bob> }
    values = CompValue("values")
    values.var = [Variable("s")]
    values.value = [
        [URIRef("http://example.com/alice")],
        [URIRef("http://example.com/bob")],
    ]
    values.res = None

    multiset = CompValue("ToMultiSet")
    multiset.p = values

    bgp = make_bgp(
        (Variable("s"), URIRef("http://example.com/knows"), Variable("o")),
    )
    tree = make_join(multiset, bgp)

    solutions = await materialise(tree, tc, collection="default")

    assert len(solutions) == 2
    subject_iris = {solution["s"].iri for solution in solutions}
    assert subject_iris == {
        "http://example.com/alice",
        "http://example.com/bob",
    }
@pytest.mark.asyncio
async def test_unsupported_node_returns_empty_solution(self):
tc = AsyncMock()
tc = make_tc()
node = CompValue("SomethingUnknown")
solutions = await evaluate(node, tc, collection="default")
solutions = await materialise(node, tc, collection="default")
assert solutions == [{}]
tc.query.assert_not_called()
@pytest.mark.asyncio
async def test_non_compvalue_returns_empty_solution(self):
tc = AsyncMock()
tc = make_tc()
solutions = await evaluate("not a node", tc, collection="default")
solutions = await materialise("not a node", tc, collection="default")
assert solutions == [{}]

View file

@ -300,6 +300,438 @@ class TestBuiltinFunctions:
flags=None)
assert evaluate_expression(expr, {"x": lit("hello")}) is False
def test_substr_three_args(self):
    """SUBSTR with start 1 and length 4 yields the first four characters."""
    from rdflib.term import Variable
    from rdflib import Literal
    substr_call = self._make_builtin(
        "SUBSTR",
        arg=Variable("x"),
        start=Literal(1),
        length=Literal(4),
    )
    bindings = {"x": lit("2024-03-15")}

    value = evaluate_expression(substr_call, bindings)

    assert value.type == LITERAL
    assert value.value == "2024"
def test_substr_two_args(self):
    """SUBSTR without a length runs from the start position to the end."""
    from rdflib.term import Variable
    from rdflib import Literal
    substr_call = self._make_builtin(
        "SUBSTR",
        arg=Variable("x"),
        start=Literal(6),
        length=None,
    )
    bindings = {"x": lit("2024-03-15")}

    value = evaluate_expression(substr_call, bindings)

    assert value.type == LITERAL
    assert value.value == "03-15"
def test_substr_middle(self):
    """SUBSTR with start 6 and length 2 extracts the month digits."""
    from rdflib.term import Variable
    from rdflib import Literal
    substr_call = self._make_builtin(
        "SUBSTR",
        arg=Variable("x"),
        start=Literal(6),
        length=Literal(2),
    )
    bindings = {"x": lit("2024-03-15")}

    value = evaluate_expression(substr_call, bindings)

    assert value.type == LITERAL
    assert value.value == "03"
def test_substr_null_start(self):
    """SUBSTR evaluates to ``None`` when its start variable is unbound."""
    from rdflib.term import Variable
    from rdflib import Literal
    substr_call = self._make_builtin(
        "SUBSTR",
        arg=Variable("x"),
        start=Variable("missing"),
        length=None,
    )
    # "missing" is deliberately absent from the bindings.
    bindings = {"x": lit("hello")}

    value = evaluate_expression(substr_call, bindings)

    assert value is None
def test_year(self):
    """YEAR of an xsd:date literal is its year component."""
    from rdflib.term import Variable
    year_call = self._make_builtin("YEAR", arg=Variable("x"))
    bindings = {"x": lit("2024-03-15", datatype=XSD + "date")}

    value = evaluate_expression(year_call, bindings)

    assert value == 2024
def test_month(self):
    """MONTH of an xsd:date literal is its month component."""
    from rdflib.term import Variable
    month_call = self._make_builtin("MONTH", arg=Variable("x"))
    bindings = {"x": lit("2024-03-15", datatype=XSD + "date")}

    value = evaluate_expression(month_call, bindings)

    assert value == 3
def test_day(self):
    """DAY of the date-typed literal 2024-03-15 is 15."""
    from rdflib.term import Variable
    bindings = {"x": lit("2024-03-15", datatype=XSD + "date")}
    node = self._make_builtin("DAY", arg=Variable("x"))
    value = evaluate_expression(node, bindings)
    assert value == 15
def test_hours(self):
    """HOURS of the dateTime literal 2024-03-15T10:30:45 is 10."""
    from rdflib.term import Variable
    moment = lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")
    node = self._make_builtin("HOURS", arg=Variable("x"))
    assert evaluate_expression(node, {"x": moment}) == 10
def test_minutes(self):
    """MINUTES of the dateTime literal 2024-03-15T10:30:45 is 30."""
    from rdflib.term import Variable
    moment = lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")
    node = self._make_builtin("MINUTES", arg=Variable("x"))
    assert evaluate_expression(node, {"x": moment}) == 30
def test_seconds(self):
    """SECONDS of the dateTime literal 2024-03-15T10:30:45 is 45."""
    from rdflib.term import Variable
    moment = lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")
    node = self._make_builtin("SECONDS", arg=Variable("x"))
    assert evaluate_expression(node, {"x": moment}) == 45
def test_year_from_datetime(self):
    """YEAR also works on a dateTime-typed literal, giving 2024."""
    from rdflib.term import Variable
    moment = lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")
    node = self._make_builtin("YEAR", arg=Variable("x"))
    assert evaluate_expression(node, {"x": moment}) == 2024
def test_hours_from_date_returns_zero(self):
    """HOURS of a date-only literal (no time part) is expected to be 0."""
    from rdflib.term import Variable
    day_only = lit("2024-03-15", datatype=XSD + "date")
    node = self._make_builtin("HOURS", arg=Variable("x"))
    assert evaluate_expression(node, {"x": day_only}) == 0
def test_year_invalid_date(self):
    """YEAR evaluates to None for the unparseable value 'not-a-date'."""
    from rdflib.term import Variable
    node = self._make_builtin("YEAR", arg=Variable("x"))
    outcome = evaluate_expression(node, {"x": lit("not-a-date")})
    assert outcome is None
def test_floor(self):
    """FLOOR of the literal '3.7' is 3."""
    from rdflib.term import Variable
    node = self._make_builtin("FLOOR", arg=Variable("x"))
    floored = evaluate_expression(node, {"x": lit("3.7")})
    assert floored == 3
def test_floor_negative(self):
    """FLOOR rounds towards negative infinity: '-2.3' becomes -3."""
    from rdflib.term import Variable
    node = self._make_builtin("FLOOR", arg=Variable("x"))
    floored = evaluate_expression(node, {"x": lit("-2.3")})
    assert floored == -3
def test_floor_none(self):
    """FLOOR evaluates to None for the non-numeric literal 'abc'."""
    from rdflib.term import Variable
    node = self._make_builtin("FLOOR", arg=Variable("x"))
    outcome = evaluate_expression(node, {"x": lit("abc")})
    assert outcome is None
def test_ceil(self):
    """CEIL of the literal '3.2' is 4."""
    from rdflib.term import Variable
    node = self._make_builtin("CEIL", arg=Variable("x"))
    raised = evaluate_expression(node, {"x": lit("3.2")})
    assert raised == 4
def test_ceil_negative(self):
    """CEIL rounds towards positive infinity: '-2.7' becomes -2."""
    from rdflib.term import Variable
    node = self._make_builtin("CEIL", arg=Variable("x"))
    raised = evaluate_expression(node, {"x": lit("-2.7")})
    assert raised == -2
def test_abs_positive(self):
    """ABS leaves the positive value '42' unchanged at 42."""
    from rdflib.term import Variable
    node = self._make_builtin("ABS", arg=Variable("x"))
    magnitude = evaluate_expression(node, {"x": lit("42")})
    assert magnitude == 42
def test_abs_negative(self):
    """ABS of the negative value '-42' is 42."""
    from rdflib.term import Variable
    node = self._make_builtin("ABS", arg=Variable("x"))
    magnitude = evaluate_expression(node, {"x": lit("-42")})
    assert magnitude == 42
def test_abs_none(self):
    """ABS evaluates to None for the non-numeric literal 'abc'."""
    from rdflib.term import Variable
    node = self._make_builtin("ABS", arg=Variable("x"))
    outcome = evaluate_expression(node, {"x": lit("abc")})
    assert outcome is None
def test_replace_simple(self):
    """REPLACE strips ' BC' from '500 BC', leaving the literal '500'."""
    from rdflib import Literal
    from rdflib.term import Variable
    node = self._make_builtin(
        "REPLACE",
        arg=Variable("x"),
        pattern=Literal(" BC"),
        replacement=Literal(""),
        flags=None,
    )
    replaced = evaluate_expression(node, {"x": lit("500 BC")})
    assert replaced.type == LITERAL
    assert replaced.value == "500"
def test_replace_regex(self):
    """REPLACE treats the pattern as a regex and replaces every match."""
    from rdflib import Literal
    from rdflib.term import Variable
    node = self._make_builtin(
        "REPLACE",
        arg=Variable("x"),
        pattern=Literal("[0-9]+"),
        replacement=Literal("X"),
        flags=None,
    )
    replaced = evaluate_expression(node, {"x": lit("abc123def456")})
    assert replaced.value == "abcXdefX"
def test_replace_case_insensitive(self):
    """REPLACE with the 'i' flag matches 'HELLO' against pattern 'hello'."""
    from rdflib import Literal
    from rdflib.term import Variable
    node = self._make_builtin(
        "REPLACE",
        arg=Variable("x"),
        pattern=Literal("hello"),
        replacement=Literal("world"),
        flags=Literal("i"),
    )
    replaced = evaluate_expression(node, {"x": lit("HELLO there")})
    assert replaced.value == "world there"
def test_round_up(self):
    """ROUND of the literal '3.7' is 4."""
    from rdflib.term import Variable
    node = self._make_builtin("ROUND", arg=Variable("x"))
    rounded = evaluate_expression(node, {"x": lit("3.7")})
    assert rounded == 4
def test_round_down(self):
    """ROUND of the literal '3.2' is 3."""
    from rdflib.term import Variable
    node = self._make_builtin("ROUND", arg=Variable("x"))
    rounded = evaluate_expression(node, {"x": lit("3.2")})
    assert rounded == 3
def test_round_none(self):
    """ROUND evaluates to None for the non-numeric literal 'abc'."""
    from rdflib.term import Variable
    node = self._make_builtin("ROUND", arg=Variable("x"))
    outcome = evaluate_expression(node, {"x": lit("abc")})
    assert outcome is None
def test_strbefore(self):
    """STRBEFORE returns the text before the first '-' in '2024-03-15'."""
    from rdflib import Literal
    from rdflib.term import Variable
    node = self._make_builtin(
        "STRBEFORE", arg1=Variable("x"), arg2=Literal("-"),
    )
    prefix = evaluate_expression(node, {"x": lit("2024-03-15")})
    assert prefix.value == "2024"
def test_strbefore_not_found(self):
    """STRBEFORE yields an empty string when the separator is absent."""
    from rdflib import Literal
    from rdflib.term import Variable
    node = self._make_builtin(
        "STRBEFORE", arg1=Variable("x"), arg2=Literal("/"),
    )
    prefix = evaluate_expression(node, {"x": lit("hello")})
    assert prefix.value == ""
def test_strafter(self):
    """STRAFTER returns the text after the first '-' in '2024-03-15'."""
    from rdflib import Literal
    from rdflib.term import Variable
    node = self._make_builtin(
        "STRAFTER", arg1=Variable("x"), arg2=Literal("-"),
    )
    suffix = evaluate_expression(node, {"x": lit("2024-03-15")})
    assert suffix.value == "03-15"
def test_strafter_not_found(self):
    """STRAFTER yields an empty string when the separator is absent."""
    from rdflib import Literal
    from rdflib.term import Variable
    node = self._make_builtin(
        "STRAFTER", arg1=Variable("x"), arg2=Literal("/"),
    )
    suffix = evaluate_expression(node, {"x": lit("hello")})
    assert suffix.value == ""
def test_encode_for_uri(self):
    """ENCODE_FOR_URI percent-encodes the space in 'hello world'."""
    from rdflib.term import Variable
    node = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x"))
    encoded = evaluate_expression(node, {"x": lit("hello world")})
    assert encoded.value == "hello%20world"
def test_encode_for_uri_special_chars(self):
    """ENCODE_FOR_URI percent-encodes '/', '?', '=' and '&'."""
    from rdflib.term import Variable
    node = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x"))
    encoded = evaluate_expression(node, {"x": lit("a/b?c=d&e")})
    assert encoded.value == "a%2Fb%3Fc%3Dd%26e"
def test_langmatches_basic(self):
    """LANGMATCHES is True when the tag 'en' equals the range 'en'."""
    # Both arguments are constants, so no Variable (and no binding) is needed.
    from rdflib import Literal
    expr = self._make_builtin("LANGMATCHES",
                              arg1=Literal("en"), arg2=Literal("en"))
    assert evaluate_expression(expr, {}) is True
def test_langmatches_subtag(self):
    """LANGMATCHES is True when the tag 'en-US' is tested against 'en'."""
    # Both arguments are constants, so no Variable (and no binding) is needed.
    from rdflib import Literal
    expr = self._make_builtin("LANGMATCHES",
                              arg1=Literal("en-US"), arg2=Literal("en"))
    assert evaluate_expression(expr, {}) is True
def test_langmatches_wildcard(self):
    """LANGMATCHES is True when the tag 'fr' is tested against '*'."""
    # Both arguments are constants, so no Variable (and no binding) is needed.
    from rdflib import Literal
    expr = self._make_builtin("LANGMATCHES",
                              arg1=Literal("fr"), arg2=Literal("*"))
    assert evaluate_expression(expr, {}) is True
def test_langmatches_wildcard_empty(self):
    """LANGMATCHES is False when an empty tag is tested against '*'."""
    # Both arguments are constants, so no Variable (and no binding) is needed.
    from rdflib import Literal
    expr = self._make_builtin("LANGMATCHES",
                              arg1=Literal(""), arg2=Literal("*"))
    assert evaluate_expression(expr, {}) is False
def test_langmatches_no_match(self):
    """LANGMATCHES is False when the tag 'fr' is tested against 'en'."""
    # Both arguments are constants, so no Variable (and no binding) is needed.
    from rdflib import Literal
    expr = self._make_builtin("LANGMATCHES",
                              arg1=Literal("fr"), arg2=Literal("en"))
    assert evaluate_expression(expr, {}) is False
def test_iri_constructor(self):
    """IRI(?x) turns a string literal into an IRI term with the same text."""
    from rdflib.term import Variable
    source = lit("http://example.com/test")
    node = self._make_builtin("IRI", arg=Variable("x"))
    term = evaluate_expression(node, {"x": source})
    assert term.type == IRI
    assert term.iri == "http://example.com/test"
def test_uri_constructor(self):
    """URI(?x) turns a string literal into an IRI term with the same text."""
    from rdflib.term import Variable
    source = lit("http://example.com/test")
    node = self._make_builtin("URI", arg=Variable("x"))
    term = evaluate_expression(node, {"x": source})
    assert term.type == IRI
    assert term.iri == "http://example.com/test"
def test_bnode_no_arg(self):
    """BNODE() with no argument yields a blank node with a non-empty id."""
    node = self._make_builtin("BNODE")
    blank = evaluate_expression(node, {})
    assert blank.type == BLANK
    assert len(blank.id) > 0
def test_bnode_with_label(self):
    """BNODE('mynode') yields a blank node whose id is 'mynode'."""
    from rdflib import Literal
    node = self._make_builtin("BNODE", arg=Literal("mynode"))
    blank = evaluate_expression(node, {})
    assert blank.type == BLANK
    assert blank.id == "mynode"
def test_now(self):
    """NOW() yields a dateTime literal that starts with an ISO timestamp."""
    import re as re_mod
    stamp = evaluate_expression(self._make_builtin("NOW"), {})
    assert stamp.type == LITERAL
    assert stamp.datatype == XSD + "dateTime"
    # Prefix match only: whatever follows the seconds field is not checked.
    assert re_mod.match(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}", stamp.value)
def test_tz_with_utc(self):
    """TZ of a dateTime carrying a +0000 offset is the literal '+00:00'."""
    from rdflib.term import Variable
    moment = lit("2024-03-15T10:30:45+0000", datatype=XSD + "dateTime")
    node = self._make_builtin("TZ", arg=Variable("x"))
    zone = evaluate_expression(node, {"x": moment})
    assert zone.type == LITERAL
    assert zone.value == "+00:00"
def test_tz_no_timezone(self):
    """TZ of a dateTime without an offset is the empty string."""
    from rdflib.term import Variable
    moment = lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")
    node = self._make_builtin("TZ", arg=Variable("x"))
    zone = evaluate_expression(node, {"x": moment})
    assert zone.value == ""
def test_rand(self):
    """RAND() returns a float in the half-open interval [0.0, 1.0)."""
    sample = evaluate_expression(self._make_builtin("RAND"), {})
    assert isinstance(sample, float)
    assert 0.0 <= sample < 1.0
def test_uuid(self):
    """UUID() returns an IRI of the form 'urn:uuid:' plus a hex UUID."""
    import re as re_mod
    expr = self._make_builtin("UUID")
    result = evaluate_expression(expr, {})
    assert result.type == IRI
    assert result.iri.startswith("urn:uuid:")
    uuid_part = result.iri[len("urn:uuid:"):]
    # fullmatch rather than match: match() anchors only the start, so a
    # valid UUID followed by trailing garbage would otherwise be accepted.
    assert re_mod.fullmatch(
        r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
        uuid_part
    )
def test_struuid(self):
    """STRUUID() returns a plain literal whose value is a hex UUID."""
    import re as re_mod
    expr = self._make_builtin("STRUUID")
    result = evaluate_expression(expr, {})
    assert result.type == LITERAL
    # fullmatch rather than match: match() anchors only the start, so a
    # valid UUID followed by trailing garbage would otherwise be accepted.
    assert re_mod.fullmatch(
        r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
        result.value
    )
def test_md5(self):
    """MD5 of 'hello' is returned as a literal holding the hex digest."""
    from rdflib.term import Variable
    node = self._make_builtin("MD5", arg=Variable("x"))
    digest = evaluate_expression(node, {"x": lit("hello")})
    assert digest.type == LITERAL
    assert digest.value == "5d41402abc4b2a76b9719d911017c592"
def test_sha1(self):
    """SHA1 of 'hello' is returned as a literal holding the hex digest."""
    from rdflib.term import Variable
    node = self._make_builtin("SHA1", arg=Variable("x"))
    digest = evaluate_expression(node, {"x": lit("hello")})
    assert digest.type == LITERAL
    assert digest.value == "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d"
def test_sha256(self):
    """SHA256 of 'hello' is returned as a literal holding the hex digest."""
    from rdflib.term import Variable
    # Expected digest split across two adjacent literals to keep lines short.
    expected = (
        "2cf24dba5fb0a30e26e83b2ac5b9e29e"
        "1b161e5c1fa7425e73043362938b9824"
    )
    node = self._make_builtin("SHA256", arg=Variable("x"))
    digest = evaluate_expression(node, {"x": lit("hello")})
    assert digest.type == LITERAL
    assert digest.value == expected
def test_sha512(self):
    """SHA512 of 'hello' must equal hashlib's hex digest of the same bytes."""
    import hashlib
    from rdflib.term import Variable
    expr = self._make_builtin("SHA512", arg=Variable("x"))
    result = evaluate_expression(expr, {"x": lit("hello")})
    assert result.type == LITERAL
    # Compare against the real digest: a length-only check would accept
    # any 128-character string, unlike the MD5/SHA1/SHA256 tests.
    assert result.value == hashlib.sha512(b"hello").hexdigest()
def test_exists_with_callback(self):
    """EXISTS evaluates to True when exists_cb returns True."""
    from rdflib.plugins.sparql.parserutils import CompValue

    def always_true(g, s):
        return True

    node = self._make_builtin("EXISTS", graph=CompValue("BGP"))
    outcome = evaluate_expression(node, {}, exists_cb=always_true)
    assert outcome is True
def test_exists_callback_false(self):
    """EXISTS evaluates to False when exists_cb returns False."""
    from rdflib.plugins.sparql.parserutils import CompValue

    def always_false(g, s):
        return False

    node = self._make_builtin("EXISTS", graph=CompValue("BGP"))
    outcome = evaluate_expression(node, {}, exists_cb=always_false)
    assert outcome is False
def test_notexists_with_callback(self):
    """NOTEXISTS negates the callback: a True callback gives False."""
    from rdflib.plugins.sparql.parserutils import CompValue

    def always_true(g, s):
        return True

    node = self._make_builtin("NOTEXISTS", graph=CompValue("BGP"))
    outcome = evaluate_expression(node, {}, exists_cb=always_true)
    assert outcome is False
def test_notexists_callback_false(self):
    """NOTEXISTS negates the callback: a False callback gives True."""
    from rdflib.plugins.sparql.parserutils import CompValue

    def always_false(g, s):
        return False

    node = self._make_builtin("NOTEXISTS", graph=CompValue("BGP"))
    outcome = evaluate_expression(node, {}, exists_cb=always_false)
    assert outcome is True
class TestEffectiveBoolean:

View file

@ -5,7 +5,7 @@ Tests for SPARQL solution sequence operations.
import pytest
from trustgraph.schema import Term, IRI, LITERAL
from trustgraph.query.sparql.solutions import (
hash_join, left_join, union, project, distinct,
hash_join, left_join, minus, union, project, distinct,
order_by, slice_solutions, _terms_equal, _compatible,
)
@ -311,6 +311,30 @@ class TestOrderBy:
result = order_by(solutions, [])
assert len(result) == 1
def test_order_by_numeric_literals(self):
    """order_by sorts numeric-looking literals by number, not as text."""
    unsorted_years = ["1950", "700", "2000", "450", "1200"]
    rows = [{"year": lit(year)} for year in unsorted_years]
    # Second tuple element True is the flag used for the ascending case.
    sort_keys = [(lambda sol: sol.get("year"), True)]
    ordered = order_by(rows, sort_keys)
    assert [row["year"].value for row in ordered] == [
        "450", "700", "1200", "1950", "2000",
    ]
def test_order_by_numeric_descending(self):
    """order_by with the flag set to False returns largest number first."""
    unsorted_years = ["1950", "700", "2000"]
    rows = [{"year": lit(year)} for year in unsorted_years]
    sort_keys = [(lambda sol: sol.get("year"), False)]
    ordered = order_by(rows, sort_keys)
    assert [row["year"].value for row in ordered] == ["2000", "1950", "700"]
class TestSlice:
@ -343,3 +367,37 @@ class TestSlice:
solutions = [{"s": alice}, {"s": bob}]
result = slice_solutions(solutions)
assert len(result) == 2
class TestMinus:
    """Tests for the minus() solution-sequence operation."""

    def test_removes_compatible(self, alice, bob):
        """A left solution that matches a right solution on ?s is dropped."""
        remaining = minus([{"s": alice}, {"s": bob}], [{"s": alice}])
        assert len(remaining) == 1
        assert remaining[0]["s"].iri == "http://example.com/bob"

    def test_empty_right_preserves_all(self, alice, bob):
        """With nothing on the right, every left solution is kept."""
        remaining = minus([{"s": alice}, {"s": bob}], [])
        assert len(remaining) == 2

    def test_no_shared_variables_preserves_all(self, alice, bob):
        """Solutions sharing no variable names do not remove each other."""
        remaining = minus([{"s": alice}], [{"t": bob}])
        assert len(remaining) == 1

    def test_all_removed(self, alice):
        """Identical left and right sequences leave nothing behind."""
        remaining = minus([{"s": alice}], [{"s": alice}])
        assert len(remaining) == 0

    def test_partial_shared_variables(self, alice, bob):
        """A right solution binding only ?s still removes the matching row."""
        left_rows = [
            {"s": alice, "p": lit("x")},
            {"s": bob, "p": lit("y")},
        ]
        remaining = minus(left_rows, [{"s": alice}])
        assert len(remaining) == 1
        assert remaining[0]["s"].iri == "http://example.com/bob"

View file

View file

@ -3,275 +3,279 @@ Tests for Reverse Gateway Dispatcher
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import asyncio
from unittest.mock import MagicMock, AsyncMock, patch, ANY
from trustgraph.rev_gateway.dispatcher import WebSocketResponder, MessageDispatcher
class TestWebSocketResponder:
    """Tests for WebSocketResponder state after construction and calls."""

    def test_websocket_responder_initialization(self):
        """A new responder has no response and is not completed."""
        fresh = WebSocketResponder()
        assert fresh.response is None
        assert fresh.completed is False

    @pytest.mark.asyncio
    async def test_websocket_responder_send_method(self):
        """send() stores the payload it was given on the responder."""
        payload = {"data": "test response"}
        target = WebSocketResponder()
        await target.send(payload)
        assert target.response == payload

    @pytest.mark.asyncio
    async def test_websocket_responder_call_method(self):
        """Calling the responder records both response and completed flag."""
        payload = {"result": "success"}
        done = True
        target = WebSocketResponder()
        await target(payload, done)
        assert target.response == payload
        assert target.completed == done

    @pytest.mark.asyncio
    async def test_websocket_responder_call_method_with_false_completion(self):
        """Calling with completed=False still leaves completed as True."""
        payload = {"partial": "data"}
        done = False
        target = WebSocketResponder()
        await target(payload, done)
        # completed ends up True regardless of the flag passed in
        # (send() always sets completed=True).
        assert target.response == payload
        assert target.completed is True
from trustgraph.rev_gateway.dispatcher import MessageDispatcher
class TestMessageDispatcher:
"""Test cases for MessageDispatcher class"""
def test_message_dispatcher_initialization_with_defaults(self):
"""Test MessageDispatcher initialization with default parameters"""
dispatcher = MessageDispatcher()
assert dispatcher.max_workers == 10
assert dispatcher.semaphore._value == 10
assert dispatcher.active_tasks == set()
assert dispatcher.backend is None
assert dispatcher.auth is None
assert dispatcher.dispatcher_manager is None
assert len(dispatcher.service_mapping) > 0
def test_message_dispatcher_initialization_with_custom_workers(self):
"""Test MessageDispatcher initialization with custom max_workers"""
dispatcher = MessageDispatcher(max_workers=5)
assert dispatcher.max_workers == 5
assert dispatcher.semaphore._value == 5
@patch('trustgraph.rev_gateway.dispatcher.DispatcherManager')
def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager):
"""Test MessageDispatcher initialization with pulsar_client and config_receiver"""
def test_message_dispatcher_initialization_with_backend(
self, mock_dispatcher_manager,
):
mock_backend = MagicMock()
mock_config_receiver = MagicMock()
mock_auth = MagicMock()
mock_dispatcher_instance = MagicMock()
mock_dispatcher_manager.return_value = mock_dispatcher_instance
dispatcher = MessageDispatcher(
max_workers=8,
config_receiver=mock_config_receiver,
backend=mock_backend
backend=mock_backend,
auth=mock_auth,
timeout=300,
)
assert dispatcher.max_workers == 8
assert dispatcher.backend == mock_backend
assert dispatcher.auth == mock_auth
assert dispatcher.dispatcher_manager == mock_dispatcher_instance
mock_dispatcher_manager.assert_called_once_with(
mock_backend, mock_config_receiver, prefix="rev-gateway"
mock_backend, mock_config_receiver,
auth=mock_auth, prefix="rev-gateway", timeout=300,
)
def test_message_dispatcher_service_mapping(self):
"""Test MessageDispatcher service mapping contains expected services"""
dispatcher = MessageDispatcher()
expected_services = [
"text-completion", "graph-rag", "agent", "embeddings",
"graph-embeddings", "triples", "document-load", "text-load",
"flow", "knowledge", "config", "librarian", "document-rag"
"flow", "knowledge", "config", "librarian", "document-rag",
]
for service in expected_services:
assert service in dispatcher.service_mapping
# Test specific mappings
assert dispatcher.service_mapping["text-completion"] == "text-completion"
assert dispatcher.service_mapping["document-load"] == "document"
assert dispatcher.service_mapping["text-load"] == "text-document"
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_without_dispatcher_manager(self):
"""Test MessageDispatcher handle_message without dispatcher manager"""
async def test_handle_message_without_dispatcher_manager(self):
dispatcher = MessageDispatcher()
test_message = {
"id": "test-123",
"service": "test-service",
"request": {"data": "test"}
}
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-123"
assert "error" in result["response"]
assert "DispatcherManager not available" in result["response"]["error"]
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="default")
)
sender = AsyncMock()
await dispatcher.handle_message(
{"id": "test-1", "service": "test", "request": {}},
sender,
)
sender.assert_called_once()
sent = sender.call_args[0][0]
assert sent["id"] == "test-1"
assert sent["error"]["message"] == "DispatcherManager not available"
assert sent["error"]["type"] == "error"
assert sent["complete"] is True
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_with_exception(self):
"""Test MessageDispatcher handle_message with exception during processing"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_global_service = AsyncMock(side_effect=Exception("Test error"))
async def test_handle_message_auth_failure(self):
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-456",
"service": "text-completion",
"request": {"prompt": "test"}
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-456"
assert "error" in result["response"]
assert "Test error" in result["response"]["error"]
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
side_effect=Exception("auth failure")
)
dispatcher.dispatcher_manager = MagicMock()
sender = AsyncMock()
await dispatcher.handle_message(
{"id": "test-2", "token": "bad", "service": "test", "request": {}},
sender,
)
sender.assert_called_once()
sent = sender.call_args[0][0]
assert sent["id"] == "test-2"
assert "auth failure" in sent["error"]["message"]
assert sent["complete"] is True
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_global_service(self):
"""Test MessageDispatcher handle_message with global service"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_global_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = True
mock_responder.response = {"result": "success"}
async def test_handle_message_global_service(self):
mock_dm = MagicMock()
mock_dm.invoke_global_service = AsyncMock()
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-789",
"service": "text-completion",
"request": {"prompt": "hello"}
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-789"
assert result["response"] == {"result": "success"}
mock_dispatcher_manager.invoke_global_service.assert_called_once()
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="ws1")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers',
{"text-completion": True},
):
await dispatcher.handle_message(
{
"id": "test-3",
"token": "tg_key",
"service": "text-completion",
"request": {"prompt": "hello"},
},
sender,
)
mock_dm.invoke_global_service.assert_called_once()
args, kwargs = mock_dm.invoke_global_service.call_args
assert args[0] == {"prompt": "hello"}
assert args[2] == "text-completion"
assert kwargs["workspace"] == "ws1"
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_flow_service(self):
"""Test MessageDispatcher handle_message with flow service"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = True
mock_responder.response = {"data": "flow_result"}
async def test_handle_message_flow_service(self):
mock_dm = MagicMock()
mock_dm.invoke_flow_service = AsyncMock()
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-flow-123",
"service": "document-rag",
"request": {"query": "test"},
"flow": "custom-flow"
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-flow-123"
assert result["response"] == {"data": "flow_result"}
mock_dispatcher_manager.invoke_flow_service.assert_called_once_with(
{"query": "test"}, mock_responder, "custom-flow", "document-rag"
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="ws2")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
):
await dispatcher.handle_message(
{
"id": "test-4",
"token": "tg_key",
"service": "document-rag",
"request": {"query": "test"},
"flow": "my-flow",
},
sender,
)
mock_dm.invoke_flow_service.assert_called_once_with(
{"query": "test"}, ANY, "ws2", "my-flow", "document-rag",
)
@pytest.mark.asyncio
async def test_message_dispatcher_handle_message_incomplete_response(self):
"""Test MessageDispatcher handle_message with incomplete response"""
mock_dispatcher_manager = MagicMock()
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
mock_responder = MagicMock()
mock_responder.completed = False
mock_responder.response = None
async def test_handle_message_responder_sends_frames(self):
mock_dm = MagicMock()
async def fake_invoke(data, responder, svc, workspace=None):
await responder({"partial": 1}, False)
await responder({"partial": 2}, True)
mock_dm.invoke_global_service = AsyncMock(side_effect=fake_invoke)
dispatcher = MessageDispatcher()
dispatcher.dispatcher_manager = mock_dispatcher_manager
test_message = {
"id": "test-incomplete",
"service": "agent",
"request": {"input": "test"}
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="ws1")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers',
{"text-completion": True},
):
await dispatcher.handle_message(
{
"id": "test-5",
"token": "tg_key",
"service": "text-completion",
"request": {"prompt": "hi"},
},
sender,
)
assert sender.call_count == 2
first = sender.call_args_list[0][0][0]
second = sender.call_args_list[1][0][0]
assert first == {
"id": "test-5", "response": {"partial": 1}, "complete": False,
}
assert second == {
"id": "test-5", "response": {"partial": 2}, "complete": True,
}
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
result = await dispatcher.handle_message(test_message)
assert result["id"] == "test-incomplete"
assert result["response"] == {"error": "No response received"}
@pytest.mark.asyncio
async def test_message_dispatcher_shutdown(self):
"""Test MessageDispatcher shutdown method"""
import asyncio
async def test_handle_message_workspace_from_identity(self):
mock_dm = MagicMock()
mock_dm.invoke_flow_service = AsyncMock()
dispatcher = MessageDispatcher()
# Create actual async tasks
dispatcher.dispatcher_manager = mock_dm
dispatcher.auth = MagicMock()
dispatcher.auth.authenticate = AsyncMock(
return_value=MagicMock(workspace="derived-ws")
)
sender = AsyncMock()
with patch(
'trustgraph.gateway.dispatch.manager.global_dispatchers', {},
):
await dispatcher.handle_message(
{
"id": "test-6",
"token": "tg_key",
"service": "agent",
"request": {"question": "test"},
"flow": "default",
},
sender,
)
args = mock_dm.invoke_flow_service.call_args[0]
assert args[2] == "derived-ws"
@pytest.mark.asyncio
async def test_shutdown(self):
dispatcher = MessageDispatcher()
async def dummy_task():
await asyncio.sleep(0.01)
return "done"
task1 = asyncio.create_task(dummy_task())
task2 = asyncio.create_task(dummy_task())
dispatcher.active_tasks = {task1, task2}
# Call shutdown
await dispatcher.shutdown()
# Verify tasks were completed
assert task1.done()
assert task2.done()
assert len(dispatcher.active_tasks) == 2 # Tasks remain in set but are completed
@pytest.mark.asyncio
async def test_message_dispatcher_shutdown_with_no_tasks(self):
"""Test MessageDispatcher shutdown with no active tasks"""
async def test_shutdown_with_no_tasks(self):
dispatcher = MessageDispatcher()
# Call shutdown with no active tasks
await dispatcher.shutdown()
# Should complete without error
assert dispatcher.active_tasks == set()
assert dispatcher.active_tasks == set()

View file

@ -8,22 +8,38 @@ from unittest.mock import MagicMock, AsyncMock, patch, Mock
from aiohttp import WSMsgType, ClientWebSocketResponse
import json
from trustgraph.rev_gateway.service import ReverseGateway, parse_args, run
from trustgraph.rev_gateway.service import ReverseGateway, run
MOCK_PATCHES = [
'trustgraph.rev_gateway.service.IamAuth',
'trustgraph.rev_gateway.service.ConfigReceiver',
'trustgraph.rev_gateway.service.MessageDispatcher',
'trustgraph.rev_gateway.service.get_pubsub',
]
def make_gateway(**overrides):
    """Build a ReverseGateway from a default websocket_uri plus overrides.

    Keyword arguments replace or extend the default settings before they
    are passed to the ReverseGateway constructor.
    """
    settings = {"websocket_uri": "ws://localhost:7650/out", **overrides}
    return ReverseGateway(**settings)
class TestReverseGateway:
"""Test cases for ReverseGateway class"""
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway initialization with default parameters"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_initialization_defaults(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
assert gateway.websocket_uri == "ws://localhost:7650/out"
assert gateway.host == "localhost"
assert gateway.port == 7650
@ -33,25 +49,22 @@ class TestReverseGateway:
assert gateway.max_workers == 10
assert gateway.running is False
assert gateway.reconnect_delay == 3.0
assert gateway.pulsar_host == "pulsar://pulsar:6650"
assert gateway.pulsar_api_key is None
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway initialization with custom parameters"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway(
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_initialization_custom_params(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway(
websocket_uri="wss://example.com:8080/websocket",
max_workers=20,
pulsar_host="pulsar://custom:6650",
pulsar_api_key="test-key",
pulsar_listener="test-listener"
)
assert gateway.websocket_uri == "wss://example.com:8080/websocket"
assert gateway.host == "example.com"
assert gateway.port == 8080
@ -59,340 +72,360 @@ class TestReverseGateway:
assert gateway.path == "/websocket"
assert gateway.url == "wss://example.com:8080/websocket"
assert gateway.max_workers == 20
assert gateway.pulsar_host == "pulsar://custom:6650"
assert gateway.pulsar_api_key == "test-key"
assert gateway.pulsar_listener == "test-listener"
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_initialization_with_missing_path(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway initialization with WebSocket URI missing path"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway(websocket_uri="ws://example.com")
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_initialization_with_missing_path(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway(websocket_uri="ws://example.com")
assert gateway.path == "/ws"
assert gateway.url == "ws://example.com/ws"
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway initialization with invalid WebSocket scheme"""
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_initialization_invalid_scheme(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"):
ReverseGateway(websocket_uri="http://example.com")
make_gateway(websocket_uri="http://example.com")
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway initialization with missing hostname"""
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_initialization_missing_hostname(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
with pytest.raises(ValueError, match="WebSocket URI must include hostname"):
ReverseGateway(websocket_uri="ws://")
make_gateway(websocket_uri="ws://")
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway creates backend with authentication"""
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_iam_auth_created(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway(
pulsar_api_key="test-key",
pulsar_listener="test-listener"
gateway = make_gateway(id="test-rev-gw")
mock_iam_auth.assert_called_once_with(
backend=mock_backend,
id="test-rev-gw",
)
# Verify get_pubsub was called with the correct parameters
mock_get_pubsub.assert_called_once_with(
pulsar_host="pulsar://pulsar:6650",
pulsar_api_key="test-key",
pulsar_listener="test-listener"
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_config_receiver_gets_auth(
    self, mock_get_pubsub, mock_dispatcher,
    mock_config_receiver, mock_iam_auth,
):
    """ConfigReceiver is built from the pub/sub backend and the auth object."""
    backend = MagicMock()
    auth_instance = MagicMock()
    mock_get_pubsub.return_value = backend
    mock_iam_auth.return_value = auth_instance

    # Only the constructor's side effects matter here; the gateway
    # object itself is not inspected.
    make_gateway()

    mock_config_receiver.assert_called_once_with(
        backend, auth=auth_instance,
    )
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@patch('trustgraph.rev_gateway.service.ClientSession')
@pytest.mark.asyncio
async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway successful connection"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
async def test_reverse_gateway_connect_success(
self, mock_session_class, mock_get_pubsub,
mock_dispatcher, mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
mock_session = AsyncMock()
mock_ws = AsyncMock()
mock_session.ws_connect.return_value = mock_ws
mock_session_class.return_value = mock_session
gateway = ReverseGateway()
gateway = make_gateway()
result = await gateway.connect()
assert result is True
assert gateway.session == mock_session
assert gateway.ws == mock_ws
mock_session.ws_connect.assert_called_once_with(gateway.url)
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@patch('trustgraph.rev_gateway.service.ClientSession')
@pytest.mark.asyncio
async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway connection failure"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
async def test_reverse_gateway_connect_failure(
self, mock_session_class, mock_get_pubsub,
mock_dispatcher, mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
mock_session = AsyncMock()
mock_session.ws_connect.side_effect = Exception("Connection failed")
mock_session_class.return_value = mock_session
gateway = ReverseGateway()
gateway = make_gateway()
result = await gateway.connect()
assert result is False
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway disconnect"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
# Mock websocket and session
async def test_reverse_gateway_disconnect(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
mock_ws = AsyncMock()
mock_ws.closed = False
mock_session = AsyncMock()
mock_session.closed = False
gateway.ws = mock_ws
gateway.session = mock_session
await gateway.disconnect()
mock_ws.close.assert_called_once()
mock_session.close.assert_called_once()
assert gateway.ws is None
assert gateway.session is None
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway send message"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
# Mock websocket
async def test_reverse_gateway_send_message(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
mock_ws = AsyncMock()
mock_ws.closed = False
gateway.ws = mock_ws
test_message = {"id": "test", "data": "hello"}
await gateway.send_message(test_message)
mock_ws.send_str.assert_called_once_with(json.dumps(test_message))
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway send message with closed connection"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
# Mock closed websocket
async def test_reverse_gateway_send_message_closed_connection(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
mock_ws = AsyncMock()
mock_ws.closed = True
gateway.ws = mock_ws
test_message = {"id": "test", "data": "hello"}
await gateway.send_message(test_message)
# Should not call send_str on closed connection
mock_ws.send_str.assert_not_called()
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway handle message"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
mock_dispatcher_instance = AsyncMock()
mock_dispatcher_instance.handle_message.return_value = {"response": "success"}
mock_dispatcher.return_value = mock_dispatcher_instance
gateway = ReverseGateway()
# Mock send_message
gateway.send_message = AsyncMock()
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
await gateway.handle_message(test_message)
mock_dispatcher_instance.handle_message.assert_called_once_with({
"id": "test",
"service": "test-service",
"request": {"data": "test"}
})
gateway.send_message.assert_called_once_with({"response": "success"})
async def test_reverse_gateway_handle_message(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
mock_dispatcher_instance = AsyncMock()
mock_dispatcher.return_value = mock_dispatcher_instance
gateway = make_gateway()
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@pytest.mark.asyncio
async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway handle message with invalid JSON"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
# Mock send_message
gateway.send_message = AsyncMock()
test_message = 'invalid json'
# Should not raise exception
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
await gateway.handle_message(test_message)
# Should not call send_message due to error
mock_dispatcher_instance.handle_message.assert_called_once_with(
{
"id": "test",
"service": "test-service",
"request": {"data": "test"},
},
gateway.send_message,
)
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_handle_message_invalid_json(
    self, mock_get_pubsub, mock_dispatcher,
    mock_config_receiver, mock_iam_auth,
):
    """Unparseable input must not raise and must not produce a reply."""
    mock_get_pubsub.return_value = MagicMock()
    gw = make_gateway()
    # Replace the sender so any attempted reply can be detected.
    gw.send_message = AsyncMock()

    await gw.handle_message('invalid json')

    gw.send_message.assert_not_called()
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway listen with text message"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
async def test_reverse_gateway_listen_text_message(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.running = True
# Mock websocket
mock_ws = AsyncMock()
mock_ws.closed = False
gateway.ws = mock_ws
# Mock handle_message
gateway.handle_message = AsyncMock()
# Mock message
mock_msg = MagicMock()
mock_msg.type = WSMsgType.TEXT
mock_msg.data = '{"test": "message"}'
# Mock receive to return message once, then raise exception to stop loop
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
# listen() catches exceptions and breaks, so no exception should be raised
await gateway.listen()
gateway.handle_message.assert_called_once_with('{"test": "message"}')
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway listen with binary message"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
async def test_reverse_gateway_listen_binary_message(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.running = True
# Mock websocket
mock_ws = AsyncMock()
mock_ws.closed = False
gateway.ws = mock_ws
# Mock handle_message
gateway.handle_message = AsyncMock()
# Mock message
mock_msg = MagicMock()
mock_msg.type = WSMsgType.BINARY
mock_msg.data = b'{"test": "binary"}'
# Mock receive to return message once, then raise exception to stop loop
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
# listen() catches exceptions and breaks, so no exception should be raised
await gateway.listen()
gateway.handle_message.assert_called_once_with('{"test": "binary"}')
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway listen with close message"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
async def test_reverse_gateway_listen_close_message(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
gateway = make_gateway()
gateway.running = True
# Mock websocket
mock_ws = AsyncMock()
mock_ws.closed = False
gateway.ws = mock_ws
# Mock handle_message
gateway.handle_message = AsyncMock()
# Mock message
mock_msg = MagicMock()
mock_msg.type = WSMsgType.CLOSE
# Mock receive to return close message
mock_ws.receive.return_value = mock_msg
await gateway.listen()
# Should not call handle_message for close message
gateway.handle_message.assert_not_called()
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway shutdown"""
async def test_reverse_gateway_shutdown(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
mock_dispatcher_instance = AsyncMock()
mock_dispatcher.return_value = mock_dispatcher_instance
gateway = ReverseGateway()
gateway = make_gateway()
gateway.running = True
# Mock disconnect
gateway.disconnect = AsyncMock()
await gateway.shutdown()
@ -402,46 +435,50 @@ class TestReverseGateway:
gateway.disconnect.assert_called_once()
mock_backend.close.assert_called_once()
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway stop"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
gateway = ReverseGateway()
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
def test_reverse_gateway_stop(
    self, mock_get_pubsub, mock_dispatcher,
    mock_config_receiver, mock_iam_auth,
):
    """stop() clears the running flag."""
    mock_get_pubsub.return_value = MagicMock()
    gw = make_gateway()
    gw.running = True

    gw.stop()

    assert gw.running is False
class TestReverseGatewayRun:
"""Test cases for ReverseGateway run method"""
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
@patch('trustgraph.rev_gateway.service.get_pubsub')
@patch(*MOCK_PATCHES[0:1])
@patch(*MOCK_PATCHES[1:2])
@patch(*MOCK_PATCHES[2:3])
@patch(*MOCK_PATCHES[3:4])
@pytest.mark.asyncio
async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
"""Test ReverseGateway run method with successful connect/listen cycle"""
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
async def test_reverse_gateway_run_successful_cycle(
self, mock_get_pubsub, mock_dispatcher,
mock_config_receiver, mock_iam_auth,
):
mock_get_pubsub.return_value = MagicMock()
mock_auth_instance = AsyncMock()
mock_iam_auth.return_value = mock_auth_instance
mock_config_receiver_instance = AsyncMock()
mock_config_receiver.return_value = mock_config_receiver_instance
gateway = ReverseGateway()
# Mock methods
gateway.connect = AsyncMock(return_value=True)
gateway = make_gateway()
gateway.listen = AsyncMock()
gateway.disconnect = AsyncMock()
gateway.shutdown = AsyncMock()
# Stop after one iteration
call_count = 0
async def mock_connect():
nonlocal call_count
@ -451,91 +488,13 @@ class TestReverseGatewayRun:
else:
gateway.running = False
return False
gateway.connect = mock_connect
await gateway.run()
mock_auth_instance.start.assert_called_once()
mock_config_receiver_instance.start.assert_called_once()
gateway.listen.assert_called_once()
# disconnect is called twice: once in the main loop, once in shutdown
assert gateway.disconnect.call_count == 2
gateway.shutdown.assert_called_once()
class TestReverseGatewayArgs:
"""Test cases for argument parsing and run function"""
def test_parse_args_defaults(self):
"""Test parse_args with default values"""
import sys
# Mock sys.argv
original_argv = sys.argv
sys.argv = ['reverse-gateway']
try:
args = parse_args()
assert args.websocket_uri is None
assert args.max_workers == 10
assert args.pulsar_host is None
assert args.pulsar_api_key is None
assert args.pulsar_listener is None
finally:
sys.argv = original_argv
def test_parse_args_custom_values(self):
"""Test parse_args with custom values"""
import sys
# Mock sys.argv
original_argv = sys.argv
sys.argv = [
'reverse-gateway',
'--websocket-uri', 'ws://custom:8080/ws',
'--max-workers', '20',
'--pulsar-host', 'pulsar://custom:6650',
'--pulsar-api-key', 'test-key',
'--pulsar-listener', 'test-listener'
]
try:
args = parse_args()
assert args.websocket_uri == 'ws://custom:8080/ws'
assert args.max_workers == 20
assert args.pulsar_host == 'pulsar://custom:6650'
assert args.pulsar_api_key == 'test-key'
assert args.pulsar_listener == 'test-listener'
finally:
sys.argv = original_argv
@patch('trustgraph.rev_gateway.service.ReverseGateway')
@patch('asyncio.run')
def test_run_function(self, mock_asyncio_run, mock_gateway_class):
"""Test run function"""
import sys
# Mock sys.argv
original_argv = sys.argv
sys.argv = ['reverse-gateway', '--max-workers', '15']
try:
mock_gateway_instance = MagicMock()
mock_gateway_instance.url = "ws://localhost:7650/out"
mock_gateway_instance.pulsar_host = "pulsar://pulsar:6650"
mock_gateway_class.return_value = mock_gateway_instance
run()
mock_gateway_class.assert_called_once_with(
websocket_uri=None,
max_workers=15,
pulsar_host=None,
pulsar_api_key=None,
pulsar_listener=None
)
mock_asyncio_run.assert_called_once_with(mock_gateway_instance.run())
finally:
sys.argv = original_argv

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
from typing import Any
from . request_response_spec import RequestResponse, RequestResponseSpec
@ -44,6 +45,60 @@ def from_value(x: Any) -> Any:
return Term(type=LITERAL, value=str(x))
class TriplesClient(RequestResponse):
async def query_gen(self, s=None, p=None, o=None, limit=20,
                    collection="default",
                    batch_size=20, timeout=30, g=None):
    """Async generator yielding Triple objects as batches arrive.

    Issues a streaming TriplesQueryRequest in a background task and
    yields each triple of each response batch as soon as the batch has
    been received.

    Args:
        s, p, o: subject / predicate / object constraints, converted
            with from_value() before being sent.
        limit: forwarded as the request's ``limit``.
        collection: forwarded as the request's ``collection``.
        batch_size: forwarded as the request's ``batch_size``.
        timeout: forwarded to ``self.request``.
        g: forwarded as the request's ``g``.

    Raises:
        Exception: any exception raised by ``self.request`` is re-raised
            from the generator.  NOTE(review): presumably this includes
            the RuntimeError raised by the recipient on an error
            response and request timeouts -- confirm against
            RequestResponse.request.
    """
    # Items on the queue are one of:
    #   list      - a batch of Triple objects
    #   Exception - the request failed; re-raised to the consumer
    #   None      - end of stream
    queue = asyncio.Queue()

    async def recipient(resp):
        if resp.error:
            raise RuntimeError(resp.error.message)

        batch = [
            Triple(to_value(v.s), to_value(v.p), to_value(v.o))
            for v in resp.triples
        ]
        await queue.put(batch)

        if resp.is_final:
            await queue.put(None)

        return resp.is_final

    async def run_request():
        # Always leave a terminal item on the queue.  Without this, a
        # failed request (error response, timeout, ...) would leave the
        # consumer below blocked on queue.get() forever.
        try:
            await self.request(
                TriplesQueryRequest(
                    s=from_value(s),
                    p=from_value(p),
                    o=from_value(o),
                    limit=limit,
                    collection=collection,
                    streaming=True,
                    batch_size=batch_size,
                    g=g,
                ),
                timeout=timeout,
                recipient=recipient,
            )
        except Exception as e:
            await queue.put(e)
        else:
            # Harmless if the recipient already queued the sentinel:
            # the consumer stops at the first one.
            await queue.put(None)

    # Launch the streaming request as a background task
    task = asyncio.ensure_future(run_request())

    try:
        while True:
            item = await queue.get()
            if item is None:
                break
            if isinstance(item, Exception):
                raise item
            for triple in item:
                yield triple
    finally:
        # Runs on normal completion, on error, and when the consumer
        # abandons the generator early: make sure the background
        # request does not outlive the generator.
        if not task.done():
            task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
async def query(self, s=None, p=None, o=None, limit=20,
collection="default",
timeout=30, g=None):

View file

@ -7,7 +7,7 @@ Provides semantic query understanding, ontology matching, and answer generation.
from .query_service import OntoRAGQueryService, QueryRequest, QueryResponse
from .question_analyzer import QuestionAnalyzer, QuestionComponents, QuestionType
from .ontology_matcher import OntologyMatcher, QueryOntologySubset
from .ontology_matcher import OntologyMatcherForQueries, QueryOntologySubset
from .backend_router import BackendRouter, BackendType, QueryRoute
from .sparql_generator import SPARQLGenerator, SPARQLQuery
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
@ -27,7 +27,7 @@ __all__ = [
'QuestionType',
# Ontology matching
'OntologyMatcher',
'OntologyMatcherForQueries',
'QueryOntologySubset',
# Backend routing

View file

@ -4,6 +4,7 @@ Provides comprehensive monitoring of system performance, query patterns, and res
"""
import logging
import re
import time
import asyncio
import inspect
@ -276,6 +277,26 @@ class MetricsCollector:
return f"{name}{{{label_str}}}"
def _extract_metric_label(metric_name: str, label: str) -> Optional[str]:
"""Extract a label value from an internal metric key."""
labels_start = metric_name.find('{')
labels_end = metric_name.find('}', labels_start + 1)
if labels_start == -1 or labels_end == -1:
return None
labels = metric_name[labels_start + 1:labels_end]
label_match = re.search(
rf'(?:^|,){re.escape(label)}=(?:"([^"]*)"|([^,]*))',
labels,
)
if not label_match:
return None
quoted_value, unquoted_value = label_match.groups()
return quoted_value if quoted_value is not None else unquoted_value
class PerformanceMonitor:
"""Monitors system performance and component health."""
@ -474,8 +495,8 @@ class PerformanceMonitor:
# Cache performance
cache_types = set()
for metric_name in self.metrics_collector.counters.keys():
if 'cache_type=' in metric_name:
cache_type = metric_name.split('cache_type=')[1].split(',')[0].split('}')[0]
cache_type = _extract_metric_label(metric_name, 'cache_type')
if cache_type is not None:
cache_types.add(cache_type)
for cache_type in cache_types:

View file

@ -7,10 +7,10 @@ import logging
from typing import List, Dict, Any, Set, Optional
from dataclasses import dataclass
from ...extract.kg.ontology.ontology_loader import Ontology, OntologyLoader
from ...extract.kg.ontology.ontology_embedder import OntologyEmbedder
from ...extract.kg.ontology.text_processor import TextSegment
from ...extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset
from trustgraph.extract.kg.ontology.ontology_loader import Ontology, OntologyLoader
from trustgraph.extract.kg.ontology.ontology_embedder import OntologyEmbedder
from trustgraph.extract.kg.ontology.text_processor import TextSegment
from trustgraph.extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset
from .question_analyzer import QuestionComponents, QuestionType
logger = logging.getLogger(__name__)

View file

@ -8,13 +8,13 @@ from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass
from datetime import datetime
from ....flow.flow_processor import FlowProcessor
from ....tables.config import ConfigTableStore
from ...extract.kg.ontology.ontology_loader import OntologyLoader
from ...extract.kg.ontology.vector_store import InMemoryVectorStore
from trustgraph.base.flow_processor import FlowProcessor
from trustgraph.tables.config import ConfigTableStore
from trustgraph.extract.kg.ontology.ontology_loader import OntologyLoader
from trustgraph.extract.kg.ontology.vector_store import InMemoryVectorStore
from .question_analyzer import QuestionAnalyzer, QuestionComponents
from .ontology_matcher import OntologyMatcher, QueryOntologySubset
from .ontology_matcher import OntologyMatcherForQueries, QueryOntologySubset
from .backend_router import BackendRouter, QueryRoute, BackendType
from .sparql_generator import SPARQLGenerator, SPARQLQuery
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
@ -105,7 +105,7 @@ class OntoRAGQueryService(FlowProcessor):
# Initialize ontology matcher
matcher_config = self.config.get('ontology_matcher', {})
self.ontology_matcher = OntologyMatcher(
self.ontology_matcher = OntologyMatcherForQueries(
vector_store=self.vector_store,
embedding_service=self.embedding_service,
config=matcher_config

View file

@ -28,7 +28,7 @@ try:
except ImportError:
CASSANDRA_AVAILABLE = False
from ....tables.config import ConfigTableStore
from trustgraph.tables.config import ConfigTableStore
logger = logging.getLogger(__name__)

View file

@ -4,6 +4,10 @@ SPARQL algebra evaluator.
Recursively evaluates an rdflib SPARQL algebra tree by issuing triple
pattern queries via TriplesClient (streaming) and performing in-memory
joins, filters, and projections.
Handlers are async generators that yield solutions incrementally.
Blocking operators (joins, sort, group, distinct) materialise their
upstream into a list at the boundary, then yield results.
"""
import logging
@ -17,7 +21,7 @@ from ... knowledge import Uri
from ... knowledge import Literal as KgLiteral
from . parser import rdflib_term_to_term
from . solutions import (
hash_join, left_join, union, project, distinct,
hash_join, left_join, minus, union, project, distinct,
order_by, slice_solutions, _term_key,
)
from . expressions import evaluate_expression, _effective_boolean
@ -34,56 +38,56 @@ async def evaluate(node, triples_client, collection, limit=10000):
"""
Evaluate a SPARQL algebra node.
Args:
node: rdflib CompValue algebra node
triples_client: TriplesClient instance for triple pattern queries
collection: collection identifier
limit: safety limit on results
Returns:
list of solutions (dicts mapping variable names to Term values)
Yields solutions (dicts mapping variable names to Term values)
incrementally as an async generator.
"""
if not isinstance(node, CompValue):
logger.warning(f"Expected CompValue, got {type(node)}: {node}")
return [{}]
yield {}
return
name = node.name
handler = _HANDLERS.get(name)
if handler is None:
logger.warning(f"Unsupported algebra node: {name}")
return [{}]
yield {}
return
return await handler(node, triples_client, collection, limit)
async for sol in handler(node, triples_client, collection, limit):
yield sol
# --- Node handlers ---
async def materialise(node, triples_client, collection, limit=10000):
    """Drain evaluate() for ``node`` and return its solutions as a list."""
    collected = []
    async for solution in evaluate(node, triples_client, collection, limit):
        collected.append(solution)
    return collected
# --- Node handlers (async generators) ---
async def _eval_select_query(node, tc, collection, limit):
"""Evaluate a SelectQuery node."""
return await evaluate(node.p, tc, collection, limit)
async for sol in evaluate(node.p, tc, collection, limit):
yield sol
async def _eval_project(node, tc, collection, limit):
"""Evaluate a Project node (SELECT variable projection)."""
solutions = await evaluate(node.p, tc, collection, limit)
variables = [str(v) for v in node.PV]
return project(solutions, variables)
async for sol in evaluate(node.p, tc, collection, limit):
yield {v: sol[v] for v in variables if v in sol}
async def _eval_bgp(node, tc, collection, limit):
"""
Evaluate a Basic Graph Pattern.
Issues streaming triple pattern queries and joins results. Patterns
are ordered by selectivity (more bound terms first) and evaluated
sequentially with bound-variable substitution.
Patterns are ordered by selectivity and evaluated sequentially.
For the final pattern, results stream directly from the triple store.
"""
triples = node.triples
if not triples:
return [{}]
yield {}
return
# Sort patterns by selectivity: more bound terms = more selective
def selectivity(pattern):
return sum(1 for t in pattern if not isinstance(t, Variable))
@ -91,55 +95,222 @@ async def _eval_bgp(node, tc, collection, limit):
enumerate(triples), key=lambda x: -selectivity(x[1])
)
# For all patterns except the last, we must materialise intermediate
# solutions because each pattern depends on bindings from prior ones.
# The last pattern streams directly.
solutions = [{}]
for _, pattern in sorted_patterns:
for pattern_idx, (_, pattern) in enumerate(sorted_patterns):
s_tmpl, p_tmpl, o_tmpl = pattern
is_last = (pattern_idx == len(sorted_patterns) - 1)
new_solutions = []
if is_last:
# Stream the final pattern — yield as triples arrive
count = 0
for sol in solutions:
s_val = _resolve_term(s_tmpl, sol)
p_val = _resolve_term(p_tmpl, sol)
o_val = _resolve_term(o_tmpl, sol)
for sol in solutions:
# Substitute known bindings into the pattern
s_val = _resolve_term(s_tmpl, sol)
p_val = _resolve_term(p_tmpl, sol)
o_val = _resolve_term(o_tmpl, sol)
async for triple in tc.query_gen(
s=s_val, p=p_val, o=o_val,
limit=limit, collection=collection,
):
binding = dict(sol)
if isinstance(s_tmpl, Variable):
binding[str(s_tmpl)] = _to_term(triple.s)
if isinstance(p_tmpl, Variable):
binding[str(p_tmpl)] = _to_term(triple.p)
if isinstance(o_tmpl, Variable):
binding[str(o_tmpl)] = _to_term(triple.o)
yield binding
count += 1
if count >= limit:
return
else:
# Materialise intermediate patterns
new_solutions = []
for sol in solutions:
s_val = _resolve_term(s_tmpl, sol)
p_val = _resolve_term(p_tmpl, sol)
o_val = _resolve_term(o_tmpl, sol)
# Query the triples store
results = await _query_pattern(
tc, s_val, p_val, o_val, collection, limit
)
async for triple in tc.query_gen(
s=s_val, p=p_val, o=o_val,
limit=limit, collection=collection,
):
binding = dict(sol)
if isinstance(s_tmpl, Variable):
binding[str(s_tmpl)] = _to_term(triple.s)
if isinstance(p_tmpl, Variable):
binding[str(p_tmpl)] = _to_term(triple.p)
if isinstance(o_tmpl, Variable):
binding[str(o_tmpl)] = _to_term(triple.o)
new_solutions.append(binding)
# Map results back to variable bindings,
# converting Uri/Literal to Term objects
for triple in results:
binding = dict(sol)
if isinstance(s_tmpl, Variable):
binding[str(s_tmpl)] = _to_term(triple.s)
if isinstance(p_tmpl, Variable):
binding[str(p_tmpl)] = _to_term(triple.p)
if isinstance(o_tmpl, Variable):
binding[str(o_tmpl)] = _to_term(triple.o)
new_solutions.append(binding)
solutions = new_solutions
if not solutions:
return
solutions = new_solutions
if not solutions:
break
# --- Blocking operators: materialise upstream, then yield ---
return solutions[:limit]
def _is_small_node(node):
    """Heuristic: is this node likely to produce only a few solutions?

    True for "values" and "ToMultiSet" nodes, and for an "Extend" node
    whose ``p`` child is (transitively) one of those.  Anything else,
    including non-CompValue input, is False.
    """
    # Walk down a chain of Extend wrappers instead of recursing.
    current = node
    while isinstance(current, CompValue):
        if current.name in ("values", "ToMultiSet"):
            return True
        if current.name != "Extend" or not hasattr(current, "p"):
            return False
        current = current.p
    return False
async def _eval_join(node, tc, collection, limit):
"""Evaluate a Join node."""
left = await evaluate(node.p1, tc, collection, limit)
right = await evaluate(node.p2, tc, collection, limit)
return hash_join(left, right)[:limit]
# Bind join: if one side is small (e.g. VALUES), materialise it and
# substitute its bindings into the other side's evaluation. This
# turns wildcard BGP queries into selective ones.
if _is_small_node(node.p1):
yield_from = _bind_join(node.p1, node.p2, tc, collection, limit)
elif _is_small_node(node.p2):
yield_from = _bind_join(node.p2, node.p1, tc, collection, limit)
else:
yield_from = _hash_join(node, tc, collection, limit)
async for sol in yield_from:
yield sol
async def _hash_join(node, tc, collection, limit):
    """Join both operands of *node* after materialising each side.

    Both ``node.p1`` and ``node.p2`` are fully evaluated first; the joined
    result is truncated to ``limit`` rows before being yielded.
    """
    lhs = await materialise(node.p1, tc, collection, limit)
    rhs = await materialise(node.p2, tc, collection, limit)
    joined = hash_join(lhs, rhs)
    for row in joined[:limit]:
        yield row
async def _bind_join(small_node, big_node, tc, collection, limit):
    """Drive evaluation of *big_node* from the solutions of *small_node*.

    The small side is materialised once; each of its solutions is then used
    as a set of pre-seeded bindings when evaluating the big side.  Stops
    as soon as ``limit`` solutions have been yielded.
    """
    seeds = await materialise(small_node, tc, collection, limit)
    emitted = 0
    for seed in seeds:
        stream = _evaluate_with_bindings(
            big_node, seed, tc, collection, limit
        )
        async for row in stream:
            yield row
            emitted += 1
            if emitted >= limit:
                return
def _merge_compatible(left, right):
    """Return the union of two solutions, or ``None`` if they conflict.

    Two solutions conflict when a variable present in both maps to values
    whose ``_term_key`` differs.  Neither input is modified.
    """
    combined = dict(left)
    for name, term in right.items():
        if name not in combined:
            combined[name] = term
        elif _term_key(combined[name]) != _term_key(term):
            return None
    return combined
async def _evaluate_with_bindings(node, bindings, tc, collection, limit):
    """Evaluate *node* starting from an existing set of variable bindings.

    A BGP node receives the bindings directly, so its triple patterns are
    resolved against them before querying.  Every other node type is
    evaluated as usual and each result is kept only if it merges cleanly
    with ``bindings``.
    """
    is_bgp = isinstance(node, CompValue) and node.name == "BGP"
    if not is_bgp:
        async for candidate in evaluate(node, tc, collection, limit):
            combined = _merge_compatible(bindings, candidate)
            if combined is not None:
                yield combined
        return

    async for row in _eval_bgp_with_bindings(
        node, bindings, tc, collection, limit
    ):
        yield row
async def _bgp_pattern_matches(pattern, sol, tc, collection, limit):
    """Yield copies of *sol* extended by each stored triple matching *pattern*."""
    s_tmpl, p_tmpl, o_tmpl = pattern
    s_val = _resolve_term(s_tmpl, sol)
    p_val = _resolve_term(p_tmpl, sol)
    o_val = _resolve_term(o_tmpl, sol)
    async for triple in tc.query_gen(
        s=s_val, p=p_val, o=o_val,
        limit=limit, collection=collection,
    ):
        binding = dict(sol)
        if isinstance(s_tmpl, Variable):
            binding[str(s_tmpl)] = _to_term(triple.s)
        if isinstance(p_tmpl, Variable):
            binding[str(p_tmpl)] = _to_term(triple.p)
        if isinstance(o_tmpl, Variable):
            binding[str(o_tmpl)] = _to_term(triple.o)
        yield binding


async def _eval_bgp_with_bindings(node, bindings, tc, collection, limit):
    """Evaluate a BGP with pre-seeded bindings so variables resolve to terms.

    Patterns are ordered most-selective first, where a position counts as
    selective if it is not a variable or is a variable already present in
    ``bindings``.  All patterns but the last are materialised into
    intermediate solution lists; matches of the last pattern are streamed,
    stopping once ``limit`` solutions have been yielded.

    An empty BGP yields a single copy of ``bindings``.
    """
    triples = node.triples
    if not triples:
        yield dict(bindings)
        return

    def selectivity(pattern):
        return sum(
            1 for t in pattern
            if not isinstance(t, Variable) or str(t) in bindings
        )

    # sorted() is stable, so equally selective patterns keep query order.
    ordered = sorted(triples, key=lambda pattern: -selectivity(pattern))

    solutions = [dict(bindings)]

    # Intermediate patterns: materialise the extended solutions.
    for pattern in ordered[:-1]:
        extended = []
        for sol in solutions:
            async for binding in _bgp_pattern_matches(
                pattern, sol, tc, collection, limit
            ):
                extended.append(binding)
        solutions = extended
        if not solutions:
            return

    # Final pattern: stream results straight to the caller.
    count = 0
    for sol in solutions:
        async for binding in _bgp_pattern_matches(
            ordered[-1], sol, tc, collection, limit
        ):
            yield binding
            count += 1
            if count >= limit:
                return
async def _eval_left_join(node, tc, collection, limit):
"""Evaluate a LeftJoin node (OPTIONAL)."""
left_sols = await evaluate(node.p1, tc, collection, limit)
right_sols = await evaluate(node.p2, tc, collection, limit)
# Buffer right side for hash index; stream left through probe
left_sols = await materialise(node.p1, tc, collection, limit)
right_sols = await materialise(node.p2, tc, collection, limit)
filter_fn = None
if hasattr(node, "expr") and node.expr is not None:
@ -149,42 +320,35 @@ async def _eval_left_join(node, tc, collection, limit):
evaluate_expression(expr, sol)
)
return left_join(left_sols, right_sols, filter_fn)[:limit]
for sol in left_join(left_sols, right_sols, filter_fn)[:limit]:
yield sol
async def _eval_union(node, tc, collection, limit):
"""Evaluate a Union node."""
left = await evaluate(node.p1, tc, collection, limit)
right = await evaluate(node.p2, tc, collection, limit)
return union(left, right)[:limit]
async def _eval_filter(node, tc, collection, limit):
"""Evaluate a Filter node."""
solutions = await evaluate(node.p, tc, collection, limit)
expr = node.expr
return [
sol for sol in solutions
if _effective_boolean(evaluate_expression(expr, sol))
]
async def _eval_minus(node, tc, collection, limit):
    """Evaluate a Minus node by materialising both sides and applying ``minus``."""
    kept_side = await materialise(node.p1, tc, collection, limit)
    removed_side = await materialise(node.p2, tc, collection, limit)
    for row in minus(kept_side, removed_side):
        yield row
async def _eval_distinct(node, tc, collection, limit):
"""Evaluate a Distinct node."""
solutions = await evaluate(node.p, tc, collection, limit)
return distinct(solutions)
seen = set()
async for sol in evaluate(node.p, tc, collection, limit):
key = tuple(sorted(
(k, _term_key(v)) for k, v in sol.items()
))
if key not in seen:
seen.add(key)
yield sol
async def _eval_reduced(node, tc, collection, limit):
"""Evaluate a Reduced node (like Distinct but implementation-defined)."""
# Treat same as Distinct
solutions = await evaluate(node.p, tc, collection, limit)
return distinct(solutions)
async for sol in _eval_distinct(node, tc, collection, limit):
yield sol
async def _eval_order_by(node, tc, collection, limit):
"""Evaluate an OrderBy node."""
solutions = await evaluate(node.p, tc, collection, limit)
solutions = await materialise(node.p, tc, collection, limit)
key_fns = []
for cond in node.expr:
@ -196,36 +360,104 @@ async def _eval_order_by(node, tc, collection, limit):
ascending,
))
else:
# Simple variable or expression
key_fns.append((
lambda sol, e=cond: evaluate_expression(e, sol),
True,
))
return order_by(solutions, key_fns)
for sol in order_by(solutions, key_fns):
yield sol
# --- Streamable operators ---
async def _eval_slice(node, tc, collection, limit):
"""Evaluate a Slice node (LIMIT/OFFSET)."""
# Pass tighter limit downstream if possible
inner_limit = limit
if node.length is not None:
offset = node.start or 0
inner_limit = min(limit, offset + node.length)
offset = node.start or 0
length = node.length
skipped = 0
emitted = 0
solutions = await evaluate(node.p, tc, collection, inner_limit)
return slice_solutions(solutions, node.start or 0, node.length)
async for sol in evaluate(node.p, tc, collection, limit):
if skipped < offset:
skipped += 1
continue
yield sol
emitted += 1
if length is not None and emitted >= length:
return
async def _eval_union(node, tc, collection, limit):
    """Evaluate a Union node, streaming the left branch then the right.

    At most ``limit`` solutions are yielded in total, matching the other
    join operators in this module.  The right branch is not evaluated at
    all when the left branch already fills the limit.
    """
    emitted = 0
    for branch in (node.p1, node.p2):
        if emitted >= limit:
            return
        async for sol in evaluate(branch, tc, collection, limit):
            yield sol
            emitted += 1
            if emitted >= limit:
                return
async def _check_exists(graph_node, sol, tc, collection, limit):
    """Return True if *graph_node* has a solution that agrees with *sol*.

    A candidate agrees when every variable bound (non-``None``) in both
    solutions has the same ``_term_key``.  Candidates sharing no variables
    with ``sol`` therefore always agree.
    """
    async for candidate in evaluate(graph_node, tc, collection, limit):
        for var in set(sol.keys()) & set(candidate.keys()):
            ours = sol[var]
            theirs = candidate[var]
            if ours is None or theirs is None:
                continue
            if _term_key(ours) != _term_key(theirs):
                break
        else:
            # No shared variable disagreed.
            return True
    return False
async def _pre_eval_exists(expr, sol, tc, collection, limit, cache):
    """Resolve every EXISTS / NOT EXISTS inside *expr* and store it in *cache*.

    The expression tree is walked through the ``expr``, ``other``, ``arg``,
    ``arg1``, ``arg2`` and ``arg3`` attributes.  Each EXISTS node found is
    evaluated with ``_check_exists`` and the boolean is stored under
    ``(id(expr.graph), id(sol))``.

    NOTE: the cache key is built from object identities, so an entry is
    only meaningful while that exact ``sol`` object is still alive.
    """
    if not isinstance(expr, CompValue):
        return

    if expr.name in ("Builtin_EXISTS", "Builtin_NOTEXISTS"):
        cache_key = (id(expr.graph), id(sol))
        if cache_key not in cache:
            cache[cache_key] = await _check_exists(
                expr.graph, sol, tc, collection, limit
            )
        return

    for attr in ("expr", "other", "arg", "arg1", "arg2", "arg3"):
        child = getattr(expr, attr, None)
        if isinstance(child, CompValue):
            candidates = (child,)
        elif isinstance(child, (list, tuple)):
            candidates = child
        else:
            # Missing attribute or a leaf value: nothing to descend into.
            continue
        for item in candidates:
            if isinstance(item, CompValue):
                await _pre_eval_exists(
                    item, sol, tc, collection, limit, cache
                )
async def _eval_filter(node, tc, collection, limit):
    """Evaluate a Filter node, streaming solutions that satisfy ``node.expr``.

    EXISTS / NOT EXISTS sub-patterns cannot be awaited from inside the
    synchronous expression evaluator, so they are resolved up front by
    ``_pre_eval_exists`` and served to ``evaluate_expression`` through the
    ``exists_cb`` callback.
    """
    expr = node.expr
    exists_cache = {}

    def exists_cb(graph_node, sol):
        return exists_cache.get((id(graph_node), id(sol)), False)

    async for sol in evaluate(node.p, tc, collection, limit):
        # Cache keys contain id(sol).  Once a solution dict is freed its id
        # can be reused by a later solution, which would then pick up a
        # stale EXISTS result.  Resetting per solution prevents that and
        # keeps the cache from growing with the size of the result set.
        exists_cache.clear()
        await _pre_eval_exists(
            expr, sol, tc, collection, limit, exists_cache
        )
        if _effective_boolean(
            evaluate_expression(expr, sol, exists_cb=exists_cb)
        ):
            yield sol
async def _eval_extend(node, tc, collection, limit):
"""Evaluate an Extend node (BIND)."""
solutions = await evaluate(node.p, tc, collection, limit)
var_name = str(node.var)
expr = node.expr
exists_cache = {}
result = []
for sol in solutions:
val = evaluate_expression(expr, sol)
def exists_cb(graph_node, sol):
key = id(graph_node), id(sol)
return exists_cache.get(key, False)
async for sol in evaluate(node.p, tc, collection, limit):
await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache)
val = evaluate_expression(expr, sol, exists_cb=exists_cb)
new_sol = dict(sol)
if isinstance(val, Term):
new_sol[var_name] = val
@ -240,16 +472,14 @@ async def _eval_extend(node, tc, collection, limit):
)
elif val is not None:
new_sol[var_name] = Term(type=LITERAL, value=str(val))
result.append(new_sol)
yield new_sol
return result
# --- Aggregation (blocking) ---
async def _eval_group(node, tc, collection, limit):
"""Evaluate a Group node (GROUP BY with aggregation)."""
solutions = await evaluate(node.p, tc, collection, limit)
solutions = await materialise(node.p, tc, collection, limit)
# Extract grouping expressions
group_exprs = []
if hasattr(node, "expr") and node.expr:
for expr in node.expr:
@ -260,7 +490,6 @@ async def _eval_group(node, tc, collection, limit):
else:
group_exprs.append((expr, None))
# Group solutions
groups = defaultdict(list)
for sol in solutions:
key_parts = []
@ -270,81 +499,72 @@ async def _eval_group(node, tc, collection, limit):
groups[tuple(key_parts)].append(sol)
if not group_exprs:
# No GROUP BY - entire result is one group
groups[()].extend(solutions)
# Build grouped solutions (one per group)
result = []
for key, group_sols in groups.items():
sol = {}
# Include group key variables
if group_sols:
for (expr, var_name), k in zip(group_exprs, key):
if var_name and group_sols:
sol[var_name] = evaluate_expression(expr, group_sols[0])
sol["__group__"] = group_sols
result.append(sol)
return result
yield sol
async def _eval_aggregate_join(node, tc, collection, limit):
"""Evaluate an AggregateJoin (aggregation functions after GROUP BY)."""
solutions = await evaluate(node.p, tc, collection, limit)
result = []
for sol in solutions:
async for sol in evaluate(node.p, tc, collection, limit):
group = sol.get("__group__", [sol])
new_sol = {k: v for k, v in sol.items() if k != "__group__"}
# Apply aggregate functions
if hasattr(node, "A") and node.A:
for agg in node.A:
var_name = str(agg.res)
agg_val = _compute_aggregate(agg, group)
new_sol[var_name] = agg_val
result.append(new_sol)
return result
yield new_sol
async def _eval_graph(node, tc, collection, limit):
"""Evaluate a Graph node (GRAPH clause)."""
term = node.term
if isinstance(term, URIRef):
# GRAPH <uri> { ... } — fixed graph
# We'd need to pass graph to triples queries
# For now, evaluate inner pattern normally
logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired")
return await evaluate(node.p, tc, collection, limit)
elif isinstance(term, Variable):
# GRAPH ?g { ... } — variable graph
logger.info(f"GRAPH ?{term} clause - variable graph not yet wired")
return await evaluate(node.p, tc, collection, limit)
else:
return await evaluate(node.p, tc, collection, limit)
async for sol in evaluate(node.p, tc, collection, limit):
yield sol
async def _eval_values(node, tc, collection, limit):
"""Evaluate a VALUES clause (inline data)."""
variables = [str(v) for v in node.var]
solutions = []
# rdflib has two representations for VALUES:
# 1. var=[Variable...], value=[[val, ...], ...] — positional
# 2. var=None, res=[{Variable: val, ...}, ...] — dict-based
if hasattr(node, "res") and node.res:
for row in node.res:
sol = {}
for var, val in row.items():
if val is not None and str(val) != "UNDEF":
sol[str(var)] = rdflib_term_to_term(val)
yield sol
return
if not node.var or not node.value:
yield {}
return
variables = [str(v) for v in node.var]
for row in node.value:
sol = {}
for var_name, val in zip(variables, row):
if val is not None and str(val) != "UNDEF":
sol[var_name] = rdflib_term_to_term(val)
solutions.append(sol)
return solutions
yield sol
async def _eval_to_multiset(node, tc, collection, limit):
"""Evaluate a ToMultiSet node (subquery)."""
return await evaluate(node.p, tc, collection, limit)
async for sol in evaluate(node.p, tc, collection, limit):
yield sol
# --- Aggregate computation ---
@ -353,7 +573,6 @@ def _compute_aggregate(agg, group):
"""Compute a single aggregate function over a group of solutions."""
agg_name = agg.name if hasattr(agg, "name") else ""
# Get the expression to aggregate
expr = agg.vars if hasattr(agg, "vars") else None
if agg_name == "Aggregate_Count":
@ -525,6 +744,7 @@ _HANDLERS = {
"Join": _eval_join,
"LeftJoin": _eval_left_join,
"Union": _eval_union,
"Minus": _eval_minus,
"Filter": _eval_filter,
"Distinct": _eval_distinct,
"Reduced": _eval_reduced,

View file

@ -5,9 +5,15 @@ Evaluates rdflib algebra expression nodes against a solution (variable
binding) to produce a value or boolean result.
"""
import hashlib
import math
import random
import re
import logging
import operator
import uuid
from datetime import datetime, date, timezone
from urllib.parse import quote
from rdflib.term import Variable, URIRef, Literal, BNode
from rdflib.plugins.sparql.parserutils import CompValue
@ -17,23 +23,31 @@ from . parser import rdflib_term_to_term
logger = logging.getLogger(__name__)
_exists_callback = None
class ExpressionError(Exception):
"""Raised when a SPARQL expression cannot be evaluated."""
pass
def evaluate_expression(expr, solution):
def evaluate_expression(expr, solution, exists_cb=None):
"""
Evaluate a SPARQL expression against a solution binding.
Args:
expr: rdflib algebra expression node
solution: dict mapping variable names to Term values
exists_cb: optional callback(graph_node, solution) -> bool for
EXISTS/NOT EXISTS evaluation; provided by algebra.py
Returns:
The result value (Term, bool, number, string, or None)
"""
global _exists_callback
if exists_cb is not None:
_exists_callback = exists_cb
if expr is None:
return True
@ -111,6 +125,13 @@ def _evaluate_comp_value(node, solution):
if name == "MultiplicativeExpression":
return _eval_multiplicative(node, solution)
# IN / NOT IN — must be checked before the generic Builtin_ dispatch
if name == "Builtin_IN":
return _eval_in(node, solution)
if name == "Builtin_NOTIN":
return not _eval_in(node, solution)
# SPARQL built-in functions
if name.startswith("Builtin_"):
return _eval_builtin(name, node, solution)
@ -119,27 +140,10 @@ def _evaluate_comp_value(node, solution):
if name == "Function":
return _eval_function(node, solution)
# Exists / NotExists
if name == "Builtin_EXISTS":
# EXISTS requires graph pattern evaluation - not handled here
logger.warning("EXISTS not supported in filter expressions")
return True
if name == "Builtin_NOTEXISTS":
logger.warning("NOT EXISTS not supported in filter expressions")
return True
# TrueFilter (used with OPTIONAL)
if name == "TrueFilter":
return True
# IN / NOT IN
if name == "Builtin_IN":
return _eval_in(node, solution)
if name == "Builtin_NOTIN":
return not _eval_in(node, solution)
logger.warning(f"Unknown CompValue expression: {name}")
return None
@ -165,6 +169,22 @@ def _eval_relational(node, solution):
">=": operator.ge,
}
if str(op) == "IN":
items = node.other if isinstance(node.other, list) else [node.other]
for item in items:
other_val = evaluate_expression(item, solution)
if _comparable_value(left) == _comparable_value(other_val):
return True
return False
if str(op) == "NOT IN":
items = node.other if isinstance(node.other, list) else [node.other]
for item in items:
other_val = evaluate_expression(item, solution)
if _comparable_value(left) == _comparable_value(other_val):
return False
return True
op_fn = ops.get(str(op))
if op_fn is None:
logger.warning(f"Unknown relational operator: {op}")
@ -335,6 +355,197 @@ def _eval_builtin(name, node, solution):
return val
return None
if builtin == "YEAR":
dt = _to_datetime(evaluate_expression(node.arg, solution))
return dt.year if dt is not None else None
if builtin == "MONTH":
dt = _to_datetime(evaluate_expression(node.arg, solution))
return dt.month if dt is not None else None
if builtin == "DAY":
dt = _to_datetime(evaluate_expression(node.arg, solution))
return dt.day if dt is not None else None
if builtin == "HOURS":
dt = _to_datetime(evaluate_expression(node.arg, solution))
if dt is None:
return None
return dt.hour if isinstance(dt, datetime) else 0
if builtin == "MINUTES":
dt = _to_datetime(evaluate_expression(node.arg, solution))
if dt is None:
return None
return dt.minute if isinstance(dt, datetime) else 0
if builtin == "SECONDS":
dt = _to_datetime(evaluate_expression(node.arg, solution))
if dt is None:
return None
return dt.second if isinstance(dt, datetime) else 0
if builtin == "FLOOR":
val = _to_numeric(evaluate_expression(node.arg, solution))
if val is None:
return None
return int(math.floor(val))
if builtin == "CEIL":
val = _to_numeric(evaluate_expression(node.arg, solution))
if val is None:
return None
return int(math.ceil(val))
if builtin == "ABS":
val = _to_numeric(evaluate_expression(node.arg, solution))
if val is None:
return None
return abs(val)
if builtin == "ROUND":
val = _to_numeric(evaluate_expression(node.arg, solution))
if val is None:
return None
return round(val)
if builtin == "STRBEFORE":
string = _to_string(evaluate_expression(node.arg1, solution))
sep = _to_string(evaluate_expression(node.arg2, solution))
idx = string.find(sep)
if idx < 0:
return Term(type=LITERAL, value="")
return Term(type=LITERAL, value=string[:idx])
if builtin == "STRAFTER":
string = _to_string(evaluate_expression(node.arg1, solution))
sep = _to_string(evaluate_expression(node.arg2, solution))
idx = string.find(sep)
if idx < 0:
return Term(type=LITERAL, value="")
return Term(type=LITERAL, value=string[idx + len(sep):])
if builtin == "ENCODE_FOR_URI":
val = _to_string(evaluate_expression(node.arg, solution))
return Term(type=LITERAL, value=quote(val, safe=""))
if builtin == "REPLACE":
string = _to_string(evaluate_expression(node.arg, solution))
pattern = _to_string(evaluate_expression(node.pattern, solution))
replacement = _to_string(
evaluate_expression(node.replacement, solution)
)
flags_str = ""
if hasattr(node, "flags") and node.flags is not None:
flags_str = _to_string(evaluate_expression(node.flags, solution))
re_flags = 0
if "i" in flags_str:
re_flags |= re.IGNORECASE
if "m" in flags_str:
re_flags |= re.MULTILINE
if "s" in flags_str:
re_flags |= re.DOTALL
try:
result = re.sub(pattern, replacement, string, flags=re_flags)
return Term(type=LITERAL, value=result)
except re.error:
return None
if builtin == "SUBSTR":
string = _to_string(evaluate_expression(node.arg, solution))
start = _to_numeric(evaluate_expression(node.start, solution))
if start is None:
return None
start_idx = max(int(start) - 1, 0)
if hasattr(node, "length") and node.length is not None:
length = _to_numeric(evaluate_expression(node.length, solution))
if length is None:
return None
return Term(
type=LITERAL, value=string[start_idx:start_idx + int(length)]
)
return Term(type=LITERAL, value=string[start_idx:])
if builtin == "EXISTS":
if _exists_callback is not None:
return _exists_callback(node.graph, solution)
logger.warning("EXISTS requires an exists_cb; not available")
return True
if builtin == "NOTEXISTS":
if _exists_callback is not None:
return not _exists_callback(node.graph, solution)
logger.warning("NOT EXISTS requires an exists_cb; not available")
return True
if builtin == "LANGMATCHES":
tag = _to_string(evaluate_expression(node.arg1, solution))
rng = _to_string(evaluate_expression(node.arg2, solution))
if rng == "*":
return len(tag) > 0
return tag.lower().startswith(rng.lower())
if builtin == "IRI" or builtin == "URI":
val = _to_string(evaluate_expression(node.arg, solution))
return Term(type=IRI, iri=val)
if builtin == "BNODE":
if hasattr(node, "arg") and node.arg is not None:
label = _to_string(evaluate_expression(node.arg, solution))
return Term(type=BLANK, id=label)
return Term(type=BLANK, id=str(uuid.uuid4()))
if builtin == "NOW":
now = datetime.now(timezone.utc)
return Term(
type=LITERAL,
value=now.strftime("%Y-%m-%dT%H:%M:%S%z"),
datatype="http://www.w3.org/2001/XMLSchema#dateTime",
)
if builtin == "TZ":
dt = _to_datetime(evaluate_expression(node.arg, solution))
if dt is None:
return Term(type=LITERAL, value="")
if dt.tzinfo is not None:
offset = dt.strftime("%z")
if offset:
return Term(type=LITERAL, value=offset[:3] + ":" + offset[3:])
return Term(type=LITERAL, value="")
if builtin == "RAND":
return random.random()
if builtin == "UUID":
return Term(type=IRI, iri="urn:uuid:" + str(uuid.uuid4()))
if builtin == "STRUUID":
return Term(type=LITERAL, value=str(uuid.uuid4()))
if builtin == "MD5":
val = _to_string(evaluate_expression(node.arg, solution))
return Term(
type=LITERAL, value=hashlib.md5(val.encode()).hexdigest()
)
if builtin == "SHA1":
val = _to_string(evaluate_expression(node.arg, solution))
return Term(
type=LITERAL, value=hashlib.sha1(val.encode()).hexdigest()
)
if builtin == "SHA256":
val = _to_string(evaluate_expression(node.arg, solution))
return Term(
type=LITERAL, value=hashlib.sha256(val.encode()).hexdigest()
)
if builtin == "SHA512":
val = _to_string(evaluate_expression(node.arg, solution))
return Term(
type=LITERAL, value=hashlib.sha512(val.encode()).hexdigest()
)
if builtin == "sameTerm":
left = evaluate_expression(node.arg1, solution)
right = evaluate_expression(node.arg2, solution)
@ -454,6 +665,27 @@ def _to_numeric(val):
return None
def _to_datetime(val):
    """Convert a value to a ``datetime``, or ``None`` if it cannot be parsed.

    The value is stringified with ``_to_string`` and tried against a set of
    common ISO-8601 layouts, then against ``datetime.fromisoformat`` (which
    also covers fractional seconds).  Date-only input produces a
    ``datetime`` at midnight; the result is timezone-aware only when the
    input carried an offset or ``Z``.
    """
    if val is None:
        return None
    s = _to_string(val)
    for fmt in (
        "%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%S%z",
        "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M",
        "%Y-%m-%d",
    ):
        try:
            return datetime.strptime(s, fmt)
        except ValueError:
            continue
    # fromisoformat() only accepts a trailing "Z" from Python 3.11 on, so
    # spell it as an explicit UTC offset to keep values such as
    # "2024-01-01T10:00:00.123Z" parseable on older interpreters.
    iso = s[:-1] + "+00:00" if s.endswith("Z") else s
    try:
        return datetime.fromisoformat(iso)
    except ValueError:
        return None
def _comparable_value(val):
"""
Convert a value to a form suitable for comparison.

View file

@ -12,7 +12,7 @@ from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import TriplesClientSpec
from . parser import parse_sparql, ParseError
from . algebra import evaluate, EvaluationError
from . algebra import evaluate, materialise, EvaluationError
logger = logging.getLogger(__name__)
@ -66,11 +66,10 @@ class Processor(FlowProcessor):
logger.debug(f"Handling SPARQL query request {id}...")
response = await self.execute_sparql(request, flow)
if request.streaming and response.query_type == "select":
await self.send_streaming(response, flow, id, request)
if request.streaming:
await self.execute_sparql_streaming(request, flow, id)
else:
response = await self.execute_sparql(request, flow)
await flow("response").send(
response, properties={"id": id}
)
@ -92,37 +91,77 @@ class Processor(FlowProcessor):
await flow("response").send(r, properties={"id": id})
async def send_streaming(self, response, flow, id, request):
"""Send SELECT results in batches."""
async def execute_sparql_streaming(self, request, flow, id):
"""Execute a SPARQL query and stream results as they arrive."""
bindings = response.bindings
try:
parsed = parse_sparql(request.query)
except ParseError as e:
await flow("response").send(
SparqlQueryResponse(
error=Error(
type="sparql-parse-error",
message=str(e),
),
),
properties={"id": id}
)
return
if parsed.query_type != "select":
response = await self._execute_non_select(parsed, request, flow)
await flow("response").send(response, properties={"id": id})
return
triples_client = flow("triples-request")
variables = parsed.variables
batch_size = request.batch_size if request.batch_size > 0 else 20
batch = []
for i in range(0, len(bindings), batch_size):
batch = bindings[i:i + batch_size]
is_final = (i + batch_size >= len(bindings))
r = SparqlQueryResponse(
query_type=response.query_type,
variables=response.variables,
bindings=batch,
is_final=is_final,
)
await flow("response").send(r, properties={"id": id})
try:
async for sol in evaluate(
parsed.algebra,
triples_client,
collection=request.collection or "default",
limit=request.limit or 10000,
):
values = [sol.get(v) for v in variables]
batch.append(SparqlBinding(values=values))
# Handle empty results
if len(bindings) == 0:
r = SparqlQueryResponse(
query_type=response.query_type,
variables=response.variables,
bindings=[],
is_final=True,
if len(batch) >= batch_size:
r = SparqlQueryResponse(
query_type="select",
variables=variables,
bindings=batch,
is_final=False,
)
await flow("response").send(r, properties={"id": id})
batch = []
except EvaluationError as e:
await flow("response").send(
SparqlQueryResponse(
error=Error(
type="sparql-evaluation-error",
message=str(e),
),
),
properties={"id": id}
)
await flow("response").send(r, properties={"id": id})
return
# Final batch (may be empty for zero results)
r = SparqlQueryResponse(
query_type="select",
variables=variables,
bindings=batch,
is_final=True,
)
await flow("response").send(r, properties={"id": id})
async def execute_sparql(self, request, flow):
"""Parse and evaluate a SPARQL query."""
"""Parse and evaluate a SPARQL query (non-streaming)."""
# Parse the SPARQL query
try:
parsed = parse_sparql(request.query)
except ParseError as e:
@ -133,12 +172,31 @@ class Processor(FlowProcessor):
),
)
# Get the triples client from the flow
triples_client = flow("triples-request")
if parsed.query_type == "select":
triples_client = flow("triples-request")
try:
solutions = await materialise(
parsed.algebra,
triples_client,
collection=request.collection or "default",
limit=request.limit or 10000,
)
except EvaluationError as e:
return SparqlQueryResponse(
error=Error(
type="sparql-evaluation-error",
message=str(e),
),
)
return self._build_select_response(parsed, solutions)
# Evaluate the algebra
return await self._execute_non_select(parsed, request, flow)
async def _execute_non_select(self, parsed, request, flow):
"""Execute ASK, CONSTRUCT, or DESCRIBE queries."""
triples_client = flow("triples-request")
try:
solutions = await evaluate(
solutions = await materialise(
parsed.algebra,
triples_client,
collection=request.collection or "default",
@ -152,10 +210,7 @@ class Processor(FlowProcessor):
),
)
# Build response based on query type
if parsed.query_type == "select":
return self._build_select_response(parsed, solutions)
elif parsed.query_type == "ask":
if parsed.query_type == "ask":
return self._build_ask_response(solutions)
elif parsed.query_type == "construct":
return self._build_construct_response(parsed, solutions)

View file

@ -150,6 +150,30 @@ def left_join(left, right, filter_fn=None):
return results
def minus(left, right):
    """
    MINUS operation: remove each left solution for which some right
    solution is compatible AND shares at least one variable with it.

    The shared-variable test is applied per pair of solutions, as the
    SPARQL algebra requires.  A right solution with no variable in common
    with a given left solution never removes it, even if other right
    solutions mention the left solution's variables.

    Args:
        left: iterable of solution dicts to filter
        right: sequence of solution dicts to subtract

    Returns:
        A new list with the surviving left solutions, in their original
        order.
    """
    if not right:
        return list(left)

    results = []
    for sol_l in left:
        left_vars = sol_l.keys()
        excluded = any(
            not left_vars.isdisjoint(sol_r) and _compatible(sol_l, sol_r)
            for sol_r in right
        )
        if not excluded:
            results.append(sol_l)
    return results
def union(left, right):
"""Union two solution sequences (concatenation)."""
return list(left) + list(right)
@ -177,6 +201,28 @@ def distinct(solutions):
return results
def _sort_comparable(val):
"""Convert a value to a form suitable for sort ordering."""
if val is None:
return (0, "")
if isinstance(val, (int, float)):
return (2, val)
if isinstance(val, Term):
if val.type == LITERAL:
try:
if "." in val.value:
return (2, float(val.value))
return (2, int(val.value))
except (ValueError, TypeError):
pass
return (3, val.value)
elif val.type == IRI:
return (4, val.iri)
elif val.type == BLANK:
return (5, val.id)
return (6, str(val))
def order_by(solutions, key_fns):
"""
Sort solutions by the given key functions.
@ -191,14 +237,7 @@ def order_by(solutions, key_fns):
keys = []
for fn, ascending in key_fns:
val = fn(sol)
# Convert to comparable form
if val is None:
comparable = ("", "")
elif isinstance(val, Term):
comparable = _term_key(val)
else:
comparable = ("v", str(val))
keys.append(comparable)
keys.append(_sort_comparable(val))
return keys
# Handle ascending/descending
@ -224,10 +263,8 @@ def _mixed_sort(solutions, key_fns):
def compare(a, b):
for fn, ascending in key_fns:
va = fn(a)
vb = fn(b)
ka = _term_key(va) if isinstance(va, Term) else ("v", str(va)) if va is not None else ("", "")
kb = _term_key(vb) if isinstance(vb, Term) else ("v", str(vb)) if vb is not None else ("", "")
ka = _sort_comparable(fn(a))
kb = _sort_comparable(fn(b))
if ka < kb:
return -1 if ascending else 1

View file

@ -1,130 +1,140 @@
import asyncio
import logging
import uuid
from typing import Dict, Any, Optional
from trustgraph.messaging import TranslatorRegistry
from typing import Dict, Any, Optional, Callable, Awaitable
from ..gateway.dispatch.manager import DispatcherManager
logger = logging.getLogger("dispatcher")
logger.setLevel(logging.INFO)
class WebSocketResponder:
    """Responder that remembers the most recent response it was given."""

    def __init__(self):
        self.response = None
        self.completed = False

    async def send(self, data):
        """Store *data* as the response and mark the responder completed."""
        self.response, self.completed = data, True

    async def __call__(self, data, final=False):
        """Callable form of ``send``; ``final`` also marks completion."""
        await self.send(data)
        if final:
            self.completed = True
class _TokenShim:
def __init__(self, token):
self.headers = (
{"Authorization": f"Bearer {token}"} if token else {}
)
class MessageDispatcher:
    """Authenticate incoming requests and route them to dispatchers.

    Each message is processed under a semaphore of ``max_workers``
    permits. Responses (including errors) are delivered through the
    ``sender`` coroutine supplied by the caller rather than returned.
    """

    def __init__(self, max_workers: int = 10, config_receiver=None,
                 backend=None, auth=None, timeout=120):
        """Set up concurrency limits and, when possible, the manager.

        Args:
            max_workers: Number of semaphore permits for message handling.
            config_receiver: Passed through to ``DispatcherManager``.
            backend: Passed through to ``DispatcherManager``.
            auth: Object with an async ``authenticate`` method; also
                passed through to ``DispatcherManager``.
            timeout: Passed through to ``DispatcherManager``.

        A ``DispatcherManager`` is only created when ``backend``,
        ``config_receiver`` and ``auth`` are all truthy; otherwise the
        dispatcher runs in fallback mode and every request is answered
        with an error.
        """
        self.max_workers = max_workers
        self.semaphore = asyncio.Semaphore(max_workers)
        self.active_tasks = set()
        self.backend = backend
        self.auth = auth

        if backend and config_receiver and auth:
            self.dispatcher_manager = DispatcherManager(
                backend, config_receiver,
                auth=auth,
                prefix="rev-gateway",
                timeout=timeout,
            )
        else:
            self.dispatcher_manager = None
            logger.warning(
                "Missing backend, config_receiver, or auth "
                "— using fallback mode"
            )

        # Maps the service name carried in the message to the name
        # used when invoking the dispatcher manager. Names not listed
        # here are passed through unchanged.
        self.service_mapping = {
            "text-completion": "text-completion",
            "graph-rag": "graph-rag",
            "agent": "agent",
            "embeddings": "embeddings",
            "graph-embeddings": "graph-embeddings",
            "triples": "triples",
            "document-load": "document",
            "text-load": "text-document",
            "flow": "flow",
            "knowledge": "knowledge",
            "config": "config",
            "librarian": "librarian",
            "document-rag": "document-rag",
        }

    async def handle_message(
            self, message: Dict[Any, Any],
            sender: Callable[[dict], Awaitable[None]],
    ):
        """Process one message while holding a semaphore permit.

        Args:
            message: Decoded request message.
            sender: Coroutine function used to deliver response and
                error frames for this message.
        """
        async with self.semaphore:
            task = asyncio.create_task(
                self._process_message(message, sender)
            )
            # Tracked so shutdown() can wait for in-flight work.
            self.active_tasks.add(task)
            try:
                await task
            finally:
                self.active_tasks.discard(task)

    async def _authenticate(self, token):
        """Resolve ``token`` to an identity via the configured auth.

        Raises:
            RuntimeError: If no auth object was configured.
        """
        if not self.auth:
            raise RuntimeError("Auth not configured")

        # The auth object is handed something exposing ``headers``,
        # so the bare token is wrapped in a shim.
        return await self.auth.authenticate(_TokenShim(token))

    async def _process_message(
            self, message: Dict[Any, Any],
            sender: Callable[[dict], Awaitable[None]],
    ):
        """Authenticate, dispatch, and stream responses for a message.

        Reads ``id``, ``service``, ``request``, ``token`` and ``flow``
        from ``message`` (a random UUID is used when ``id`` is absent
        and ``'default'`` when ``flow`` is absent). Any exception is
        logged and reported to ``sender`` as a final error frame.
        """
        request_id = message.get('id', str(uuid.uuid4()))
        service = message.get('service')
        request_data = message.get('request', {})
        token = message.get('token', '')
        flow_id = message.get('flow', 'default')

        logger.info(
            f"Processing message {request_id} for service "
            f"{service} on flow {flow_id}"
        )

        try:
            if not self.dispatcher_manager:
                raise RuntimeError(
                    "DispatcherManager not available"
                )

            identity = await self._authenticate(token)
            workspace = identity.workspace

            # ``fin`` defaults to False so the callback also works when
            # invoked with a single argument, as the earlier responder
            # object allowed.
            async def responder(resp, fin=False):
                await sender({
                    "id": request_id,
                    "response": resp,
                    "complete": fin,
                })

            dispatcher_service = self.service_mapping.get(service, service)

            # Imported here, as in the original code, rather than at
            # module level.
            from ..gateway.dispatch.manager import global_dispatchers

            if dispatcher_service in global_dispatchers:
                await self.dispatcher_manager.invoke_global_service(
                    request_data, responder, dispatcher_service,
                    workspace=workspace,
                )
            else:
                await self.dispatcher_manager.invoke_flow_service(
                    request_data, responder, workspace, flow_id,
                    dispatcher_service,
                )

        except Exception as e:
            logger.error(f"Error processing message {request_id}: {e}")
            await sender({
                "id": request_id,
                "error": {"message": str(e), "type": "error"},
                "complete": True,
            })

        logger.info(f"Completed processing message {request_id}")

    async def shutdown(self):
        """Wait for all in-flight message tasks to finish."""
        if self.active_tasks:
            logger.info(
                f"Waiting for {len(self.active_tasks)} active "
                f"tasks to complete"
            )
            # return_exceptions=True: one failed task must not abort
            # the wait for the others.
            await asyncio.gather(*self.active_tasks, return_exceptions=True)

        logger.info("Dispatcher shutdown complete")

View file

@ -1,3 +1,9 @@
"""
Reverse gateway. Initiates outbound WebSocket connections to a remote
relay and dispatches incoming requests through the same DispatcherManager
pipeline as api-gateway.
"""
import asyncio
import argparse
import logging
@ -9,82 +15,86 @@ from typing import Optional
from urllib.parse import urlparse, urlunparse
from .dispatcher import MessageDispatcher
from ..gateway.auth import IamAuth
from ..gateway.config.receiver import ConfigReceiver
from ..base import get_pubsub
from ..base.pubsub import get_pubsub, add_pubsub_args
from ..base.logging import setup_logging, add_logging_args
# Module-level logger for the reverse gateway, set to INFO level.
logger = logging.getLogger("rev_gateway")
logger.setLevel(logging.INFO)
# WebSocket URI used when neither a websocket_uri setting nor the
# WEBSOCKET_URI environment variable is provided.
default_websocket = "ws://localhost:7650/out"
# Default for the --timeout option (request timeout in seconds).
default_timeout = 600
class ReverseGateway:
def __init__(self, **config):
    """Validate the WebSocket URI and build the gateway's components.

    Recognised keys in ``config``:
        websocket_uri: Target URI; when missing or ``None`` the
            ``WEBSOCKET_URI`` environment variable is used, then
            ``default_websocket``.
        max_workers: Concurrent handler limit (default 10).
        timeout: Request timeout passed to the dispatcher
            (default ``default_timeout``).
        id: Identifier passed to ``IamAuth`` (default "rev-gateway").

    The complete ``config`` mapping is also forwarded to
    ``get_pubsub`` to create the backend.

    Raises:
        ValueError: If the URI scheme is not ws/wss or the URI has no
            network location.
    """
    websocket_uri = config.get("websocket_uri")
    if websocket_uri is None:
        websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket)

    parsed_uri = urlparse(websocket_uri)

    if parsed_uri.scheme not in ('ws', 'wss'):
        raise ValueError(
            f"WebSocket URI must use ws:// or wss:// scheme, "
            f"got: {parsed_uri.scheme}"
        )

    if not parsed_uri.netloc:
        raise ValueError(
            f"WebSocket URI must include hostname, "
            f"got: {websocket_uri}"
        )

    self.websocket_uri = websocket_uri
    self.host = parsed_uri.hostname
    self.port = parsed_uri.port
    self.scheme = parsed_uri.scheme
    self.path = parsed_uri.path or "/ws"

    if not parsed_uri.path:
        # Fill in the default path but keep any query string or
        # fragment; rebuilding from scheme + netloc alone would
        # silently drop them.
        self.url = urlunparse(parsed_uri._replace(path="/ws"))
    else:
        self.url = websocket_uri

    self.max_workers = int(config.get("max_workers", 10))
    self.timeout = int(config.get("timeout", default_timeout))

    self.ws: Optional[ClientWebSocketResponse] = None
    self.session: Optional[ClientSession] = None
    self.running = False
    self.reconnect_delay = 3.0

    self.backend = get_pubsub(**config)

    self.auth = IamAuth(
        backend=self.backend,
        id=config.get("id", "rev-gateway"),
    )

    self.config_receiver = ConfigReceiver(
        self.backend, auth=self.auth,
    )

    # Built last: it needs the config receiver, backend and auth.
    self.dispatcher = MessageDispatcher(
        self.max_workers, self.config_receiver, self.backend,
        auth=self.auth, timeout=self.timeout,
    )
async def connect(self) -> bool:
    """Open the WebSocket connection to ``self.url``.

    A client session is created on first use and kept on
    ``self.session``.

    Returns:
        True if the connection was established, False if any
        exception occurred (the failure is logged, not raised).
    """
    try:
        if self.session is None:
            self.session = ClientSession()

        logger.info(f"Connecting to {self.url}")
        self.ws = await self.session.ws_connect(self.url)
        logger.info(
            f"WebSocket connection established to "
            f"{self.host}:{self.port or 'default'}"
        )
        return True

    except Exception as e:
        logger.error(f"Failed to connect to {self.url}: {e}")
        return False
async def disconnect(self):
if self.ws and not self.ws.closed:
await self.ws.close()
@ -92,32 +102,31 @@ class ReverseGateway:
await self.session.close()
self.ws = None
self.session = None
async def send_message(self, message: dict):
    """Serialise ``message`` as JSON and send it over the WebSocket.

    Does nothing when there is no open connection. Serialisation or
    send failures are logged and not raised.
    """
    if not self.ws or self.ws.closed:
        return

    try:
        payload = json.dumps(message)
        await self.ws.send_str(payload)
    except Exception as e:
        logger.error(f"Failed to send message: {e}")
async def handle_message(self, message: str):
    """Decode one text frame and hand it to the dispatcher.

    The dispatcher receives ``self.send_message`` as its sender and
    delivers responses through it; nothing is returned here. Decoding
    or dispatch failures are logged and not raised.

    Args:
        message: Raw JSON text received from the WebSocket.
    """
    try:
        logger.debug(f"Received message: {message}")
        msg_data = json.loads(message)
        await self.dispatcher.handle_message(
            msg_data, self.send_message,
        )
    except Exception as e:
        logger.error(f"Error handling message: {e}")
async def listen(self):
while self.running and self.ws and not self.ws.closed:
try:
msg = await self.ws.receive()
if msg.type == WSMsgType.TEXT:
await self.handle_message(msg.data)
elif msg.type == WSMsgType.BINARY:
@ -125,31 +134,33 @@ class ReverseGateway:
elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
logger.warning("WebSocket closed or error occurred")
break
except Exception as e:
logger.error(f"Error in listen loop: {e}")
break
async def run(self):
self.running = True
logger.info("Starting reverse gateway")
# Start config receiver
logger.info("Starting config receiver")
await self.auth.start()
await self.config_receiver.start()
while self.running:
try:
if await self.connect():
await self.listen()
else:
logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds")
logger.warning(
f"Connection failed, retrying in "
f"{self.reconnect_delay} seconds"
)
await self.disconnect()
if self.running:
await asyncio.sleep(self.reconnect_delay)
except KeyboardInterrupt:
logger.info("Shutdown requested")
break
@ -157,77 +168,69 @@ class ReverseGateway:
logger.error(f"Unexpected error: {e}")
if self.running:
await asyncio.sleep(self.reconnect_delay)
await self.shutdown()
async def shutdown(self):
    """Stop the gateway and release its resources.

    Clears the running flag, waits for the dispatcher to shut down,
    disconnects, and finally closes the backend if one was set.
    """
    logger.info("Shutting down reverse gateway")
    self.running = False

    await self.dispatcher.shutdown()
    await self.disconnect()

    # The backend attribute may be absent, so it is checked for
    # before closing.
    has_backend = hasattr(self, 'backend')
    if has_backend:
        self.backend.close()
def stop(self):
    """Clear the running flag checked by the gateway's loops."""
    setattr(self, "running", False)
def parse_args():
def run():
parser = argparse.ArgumentParser(
prog="reverse-gateway",
description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge"
description=__doc__,
)
parser.add_argument(
'--id',
default='rev-gateway',
help='Service identifier (default: rev-gateway)',
)
parser.add_argument(
'--websocket-uri',
default=None,
help=f'WebSocket URI to connect to (default: {default_websocket} or WEBSOCKET_URI env var)'
help=f'WebSocket URI to connect to (default: {default_websocket})',
)
parser.add_argument(
'--max-workers',
type=int,
default=10,
help='Maximum concurrent message handlers (default: 10)'
help='Maximum concurrent message handlers (default: 10)',
)
parser.add_argument(
'-p', '--pulsar-host',
default=None,
help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)'
)
parser.add_argument(
'--pulsar-api-key',
default=None,
help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)'
)
parser.add_argument(
'--pulsar-listener',
default=None,
help='Pulsar listener name'
)
return parser.parse_args()
def run():
args = parse_args()
gateway = ReverseGateway(
websocket_uri=args.websocket_uri,
max_workers=args.max_workers,
pulsar_host=args.pulsar_host,
pulsar_api_key=args.pulsar_api_key,
pulsar_listener=args.pulsar_listener
parser.add_argument(
'--timeout',
type=int,
default=default_timeout,
help=f'Request timeout in seconds (default: {default_timeout})',
)
add_pubsub_args(parser)
add_logging_args(parser)
args = parser.parse_args()
args = vars(args)
setup_logging(args)
gateway = ReverseGateway(**args)
logger.info(f"Starting reverse gateway:")
logger.info(f" WebSocket URI: {gateway.url}")
logger.info(f" Max workers: {args.max_workers}")
logger.info(f" Pulsar host: {gateway.pulsar_host}")
logger.info(f" Max workers: {gateway.max_workers}")
try:
asyncio.run(gateway.run())
except KeyboardInterrupt: