mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
Add a worker sync event so that runtime updates on one worker can propagate across other workers using pubsub for multi worker deployments
114 lines
4.3 KiB
Python
114 lines
4.3 KiB
Python
"""Worker sync manager for cross-worker state propagation.
|
|
|
|
Each FastAPI worker both publishes and listens on a single Redis pub/sub
channel. When shared state changes (e.g. Langfuse credentials), the worker
that handled the mutation broadcasts a lightweight event. Every worker
(including the sender) receives it and runs the registered handler, which
re-reads authoritative state from the DB.
|
|
"""
|
|
|
|
import asyncio
|
|
from typing import Awaitable, Callable, Dict
|
|
|
|
import redis.asyncio as aioredis
|
|
from loguru import logger
|
|
|
|
from api.enums import RedisChannel
|
|
from api.services.worker_sync.protocol import WorkerSyncEvent
|
|
|
|
# Signature every sync handler must have: an async callable that receives the
# decoded WorkerSyncEvent and returns nothing.
SyncHandler = Callable[[WorkerSyncEvent], Awaitable[None]]
|
|
|
|
|
|
class WorkerSyncManager:
    """Propagates state changes across FastAPI workers via Redis pub/sub.

    Lifecycle: ``register()`` handlers, ``await start()`` to connect and spawn
    the background listener, ``await broadcast()`` to publish events, and
    ``await stop()`` to tear everything down. At most one handler is kept per
    event type; registering again for the same type replaces the previous one.

    The listener is not restarted automatically: if it exits because of an
    unexpected error the failure is logged and no further events are
    dispatched until ``start()`` is called again.
    """

    def __init__(self, redis_url: str):
        """Store configuration only; no connection is opened until start()."""
        self._redis_url = redis_url
        # event_type -> coroutine function invoked for each received event
        self._handlers: Dict[str, SyncHandler] = {}
        # All three stay None while the manager is not running.
        self._redis: aioredis.Redis | None = None
        self._pubsub: aioredis.client.PubSub | None = None
        self._listener_task: asyncio.Task | None = None

    def register(self, event_type: str, handler: SyncHandler):
        """Register a handler for an event type. Call before start()."""
        self._handlers[event_type] = handler
        logger.info(f"Worker sync handler registered: {event_type}")

    async def broadcast(self, event_type: str, action: str, org_id: str = ""):
        """Publish an event to all workers (including self).

        If the manager is not running (never started, or already stopped) the
        call logs a warning and returns without publishing.
        """
        if not self._redis:
            logger.warning("WorkerSyncManager not started, skipping broadcast")
            return
        event = WorkerSyncEvent(event_type=event_type, action=action, org_id=org_id)
        await self._redis.publish(RedisChannel.WORKER_SYNC.value, event.to_json())
        logger.debug(f"Broadcast worker sync: {event_type}/{action} org={org_id}")

    async def start(self):
        """Open a dedicated Redis connection and start the background listener.

        Calling start() while the listener is still running is a no-op, so a
        second call cannot orphan the first task or connection. If subscribing
        fails, the half-opened connection is closed again and the original
        error is re-raised, leaving the manager in the "not started" state.
        """
        running = self._listener_task
        if running is not None and not running.done():
            logger.warning("WorkerSyncManager already started, skipping start")
            return

        # Release anything left over from an earlier run whose listener exited.
        await self._close_connections()
        self._listener_task = None

        self._redis = await aioredis.from_url(self._redis_url, decode_responses=True)
        self._pubsub = self._redis.pubsub()
        try:
            await self._pubsub.subscribe(RedisChannel.WORKER_SYNC.value)
        except BaseException:
            # Don't leave _redis set without a listener behind it.
            try:
                await self._close_connections()
            except Exception:
                logger.exception("Worker sync cleanup after failed start failed")
            raise
        self._listener_task = asyncio.create_task(self._listen())
        logger.info("WorkerSyncManager started")

    async def stop(self):
        """Cancel the listener and close the Redis connection.

        Safe to call when the manager was never started. Internal state is
        reset so that a later broadcast() warns instead of publishing on a
        closed client, and the Redis client is closed even when unsubscribing
        or closing the pub/sub object raises (that error still propagates).
        """
        task, self._listener_task = self._listener_task, None
        try:
            if task:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
        finally:
            await self._close_connections()
        logger.info("WorkerSyncManager stopped")

    async def _close_connections(self):
        """Detach and close the pub/sub object and the Redis client, if any."""
        # Clear the attributes first so the manager reads as "not started"
        # even if one of the close calls below raises.
        pubsub, self._pubsub = self._pubsub, None
        client, self._redis = self._redis, None
        try:
            if pubsub:
                try:
                    await pubsub.unsubscribe(RedisChannel.WORKER_SYNC.value)
                finally:
                    await pubsub.close()
        finally:
            if client:
                await client.close()

    async def _listen(self):
        """Background loop: receive events and dispatch to handlers.

        Handlers are awaited one at a time in arrival order. An exception
        raised while decoding or handling one message is logged and the loop
        continues with the next message; cancellation is propagated.
        """
        try:
            async for message in self._pubsub.listen():
                # Skip subscribe/unsubscribe confirmations and other non-data frames.
                if message["type"] != "message":
                    continue
                try:
                    event = WorkerSyncEvent.from_json(message["data"])
                except Exception:
                    # One malformed payload must not end the listener.
                    logger.exception("Worker sync event could not be decoded")
                    continue
                if not event:
                    continue
                handler = self._handlers.get(event.event_type)
                if handler is None:
                    logger.warning(
                        f"No handler for worker sync event: {event.event_type}"
                    )
                    continue
                try:
                    await handler(event)
                except Exception:
                    logger.exception(
                        f"Worker sync handler error: {event.event_type}"
                    )
        except asyncio.CancelledError:
            raise
        except Exception:
            logger.exception("Worker sync listener crashed")
|
|
|
|
|
|
# Module-level singleton, initialized in app lifespan
|
|
# Stays None until set_worker_sync_manager() installs an instance;
# get_worker_sync_manager() raises RuntimeError while it is still None.
_manager: WorkerSyncManager | None = None
|
|
|
|
|
|
def get_worker_sync_manager() -> WorkerSyncManager:
    """Return the module-level WorkerSyncManager singleton.

    Raises:
        RuntimeError: if no manager has been installed yet via
            set_worker_sync_manager().
    """
    active = _manager
    if active is not None:
        return active
    raise RuntimeError("WorkerSyncManager not initialized")
|
|
|
|
|
|
def set_worker_sync_manager(manager: WorkerSyncManager):
    """Install *manager* as the module-level singleton.

    Replaces whatever instance was stored before; the value is what
    get_worker_sync_manager() hands out afterwards.
    """
    global _manager
    _manager = manager
|