mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-28 00:35:13 +02:00
Update rev-gateway for IAM integration (#940)
service.py: - Constructor takes **config (same pattern as api-gateway) instead of individual args - Creates IamAuth and calls await self.auth.start() before the message loop - Passes auth to both ConfigReceiver and MessageDispatcher - Uses add_pubsub_args / add_logging_args instead of hand-rolled Pulsar args - Passes timeout through dispatcher.py: - Accepts auth and timeout parameters - Passes both to DispatcherManager — fixes the missing auth argument that would have crashed on startup The remote end's requests now go through the same IAM authentication path as api-gateway. Token validation, workspace resolution, and permissions all work identically regardless of which direction initiated the connection. Fixed tests — the test now passes auth and timeout to MessageDispatcher and verifies they're forwarded to DispatcherManager. Update rev gateway dispatcher to align with IAM. A "token" parameter must be passed with each message. Fix websocket relay to align with rev-gateway changes, conforms to the api-gateway protocol.
This commit is contained in:
parent
4e3bd85abc
commit
e57f4669e1
7 changed files with 914 additions and 865 deletions
|
|
@ -25,16 +25,17 @@ class TestSemaphoreEnforcement:
|
|||
max_concurrent = 0
|
||||
processing_event = asyncio.Event()
|
||||
|
||||
async def slow_process(message):
|
||||
async def slow_process(message, sender):
|
||||
nonlocal concurrent_count, max_concurrent
|
||||
concurrent_count += 1
|
||||
max_concurrent = max(max_concurrent, concurrent_count)
|
||||
await asyncio.sleep(0.05)
|
||||
concurrent_count -= 1
|
||||
return {"id": message.get("id"), "response": {"ok": True}}
|
||||
|
||||
dispatcher._process_message = slow_process
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
# Launch more tasks than max_workers
|
||||
messages = [
|
||||
{"id": f"msg-{i}", "service": "test", "request": {}}
|
||||
|
|
@ -42,7 +43,7 @@ class TestSemaphoreEnforcement:
|
|||
]
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(dispatcher.handle_message(m))
|
||||
asyncio.create_task(dispatcher.handle_message(m, sender))
|
||||
for m in messages
|
||||
]
|
||||
|
||||
|
|
@ -66,17 +67,17 @@ class TestSemaphoreEnforcement:
|
|||
|
||||
original_process = dispatcher._process_message
|
||||
|
||||
async def tracking_process(message):
|
||||
async def tracking_process(message, sender):
|
||||
nonlocal task_was_tracked
|
||||
# During processing, our task should be in active_tasks
|
||||
if len(dispatcher.active_tasks) > 0:
|
||||
task_was_tracked = True
|
||||
return {"id": message.get("id"), "response": {"ok": True}}
|
||||
|
||||
dispatcher._process_message = tracking_process
|
||||
|
||||
await dispatcher.handle_message(
|
||||
{"id": "test", "service": "test", "request": {}}
|
||||
{"id": "test", "service": "test", "request": {}},
|
||||
AsyncMock(),
|
||||
)
|
||||
|
||||
assert task_was_tracked
|
||||
|
|
@ -88,7 +89,7 @@ class TestSemaphoreEnforcement:
|
|||
"""Semaphore should be released even if processing raises."""
|
||||
dispatcher = MessageDispatcher(max_workers=2)
|
||||
|
||||
async def failing_process(message):
|
||||
async def failing_process(message, sender):
|
||||
raise RuntimeError("process failed")
|
||||
|
||||
dispatcher._process_message = failing_process
|
||||
|
|
@ -96,7 +97,8 @@ class TestSemaphoreEnforcement:
|
|||
# Should not deadlock — semaphore must be released on error
|
||||
with pytest.raises(RuntimeError):
|
||||
await dispatcher.handle_message(
|
||||
{"id": "test", "service": "test", "request": {}}
|
||||
{"id": "test", "service": "test", "request": {}},
|
||||
AsyncMock(),
|
||||
)
|
||||
|
||||
# Semaphore should be back at max
|
||||
|
|
@ -109,17 +111,18 @@ class TestSemaphoreEnforcement:
|
|||
|
||||
order = []
|
||||
|
||||
async def ordered_process(message):
|
||||
async def ordered_process(message, sender):
|
||||
msg_id = message["id"]
|
||||
order.append(f"start-{msg_id}")
|
||||
await asyncio.sleep(0.02)
|
||||
order.append(f"end-{msg_id}")
|
||||
return {"id": msg_id, "response": {"ok": True}}
|
||||
|
||||
dispatcher._process_message = ordered_process
|
||||
|
||||
sender = AsyncMock()
|
||||
|
||||
messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)]
|
||||
tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages]
|
||||
tasks = [asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# With semaphore=1, each message should complete before next starts
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue