mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
feat: add worker sync events
Add a worker sync event so that runtime updates on one worker can propagate across other workers using pubsub for multi worker deployments
This commit is contained in:
parent
56763a4527
commit
03df5595c3
18 changed files with 446 additions and 113 deletions
0
api/services/worker_sync/__init__.py
Normal file
0
api/services/worker_sync/__init__.py
Normal file
114
api/services/worker_sync/manager.py
Normal file
114
api/services/worker_sync/manager.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
"""Worker sync manager for cross-worker state propagation.
|
||||
|
||||
Each FastAPI worker both publishes and listens on a single Redis pub/sub
|
||||
channel. When shared state changes (e.g. Langfuse credentials), the worker
|
||||
that handled the mutation broadcasts a lightweight event. Every worker
|
||||
(including the sender) receives it and runs the registered handler, which
|
||||
re-reads authoritative state from the DB.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Awaitable, Callable, Dict
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from loguru import logger
|
||||
|
||||
from api.enums import RedisChannel
|
||||
from api.services.worker_sync.protocol import WorkerSyncEvent
|
||||
|
||||
SyncHandler = Callable[[WorkerSyncEvent], Awaitable[None]]
|
||||
|
||||
|
||||
class WorkerSyncManager:
|
||||
"""Propagates state changes across FastAPI workers via Redis pub/sub."""
|
||||
|
||||
def __init__(self, redis_url: str):
|
||||
self._redis_url = redis_url
|
||||
self._handlers: Dict[str, SyncHandler] = {}
|
||||
self._redis: aioredis.Redis | None = None
|
||||
self._pubsub: aioredis.client.PubSub | None = None
|
||||
self._listener_task: asyncio.Task | None = None
|
||||
|
||||
def register(self, event_type: str, handler: SyncHandler):
|
||||
"""Register a handler for an event type. Call before start()."""
|
||||
self._handlers[event_type] = handler
|
||||
logger.info(f"Worker sync handler registered: {event_type}")
|
||||
|
||||
async def broadcast(self, event_type: str, action: str, org_id: str = ""):
|
||||
"""Publish an event to all workers (including self)."""
|
||||
if not self._redis:
|
||||
logger.warning("WorkerSyncManager not started, skipping broadcast")
|
||||
return
|
||||
event = WorkerSyncEvent(event_type=event_type, action=action, org_id=org_id)
|
||||
await self._redis.publish(RedisChannel.WORKER_SYNC.value, event.to_json())
|
||||
logger.debug(f"Broadcast worker sync: {event_type}/{action} org={org_id}")
|
||||
|
||||
async def start(self):
|
||||
"""Open a dedicated Redis connection and start the background listener."""
|
||||
self._redis = await aioredis.from_url(self._redis_url, decode_responses=True)
|
||||
self._pubsub = self._redis.pubsub()
|
||||
await self._pubsub.subscribe(RedisChannel.WORKER_SYNC.value)
|
||||
self._listener_task = asyncio.create_task(self._listen())
|
||||
logger.info("WorkerSyncManager started")
|
||||
|
||||
async def stop(self):
|
||||
"""Cancel the listener and close the Redis connection."""
|
||||
if self._listener_task:
|
||||
self._listener_task.cancel()
|
||||
try:
|
||||
await self._listener_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self._pubsub:
|
||||
await self._pubsub.unsubscribe(RedisChannel.WORKER_SYNC.value)
|
||||
await self._pubsub.close()
|
||||
if self._redis:
|
||||
await self._redis.close()
|
||||
logger.info("WorkerSyncManager stopped")
|
||||
|
||||
async def _listen(self):
|
||||
"""Background loop: receive events and dispatch to handlers."""
|
||||
try:
|
||||
async for message in self._pubsub.listen():
|
||||
if message["type"] != "message":
|
||||
continue
|
||||
event = WorkerSyncEvent.from_json(message["data"])
|
||||
if not event:
|
||||
continue
|
||||
handler = self._handlers.get(event.event_type)
|
||||
if handler:
|
||||
try:
|
||||
await handler(event)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Worker sync handler error: {event.event_type}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"No handler for worker sync event: {event.event_type}"
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Worker sync listener crashed")
|
||||
|
||||
|
||||
# Module-level singleton, initialized in app lifespan
|
||||
_manager: WorkerSyncManager | None = None
|
||||
|
||||
|
||||
def get_worker_sync_manager() -> WorkerSyncManager:
|
||||
"""Get the active WorkerSyncManager instance.
|
||||
|
||||
Raises RuntimeError if called before the manager is started (i.e. outside
|
||||
the FastAPI lifespan).
|
||||
"""
|
||||
if _manager is None:
|
||||
raise RuntimeError("WorkerSyncManager not initialized")
|
||||
return _manager
|
||||
|
||||
|
||||
def set_worker_sync_manager(manager: WorkerSyncManager):
|
||||
"""Set the module-level singleton. Called from the app lifespan."""
|
||||
global _manager
|
||||
_manager = manager
|
||||
48
api/services/worker_sync/protocol.py
Normal file
48
api/services/worker_sync/protocol.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""Worker sync event protocol.
|
||||
|
||||
Defines the message format for cross-worker state synchronization via Redis pub/sub.
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class WorkerSyncEventType(str, Enum):
|
||||
"""Types of worker sync events."""
|
||||
|
||||
LANGFUSE_CREDENTIALS = "langfuse_credentials"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkerSyncEvent:
|
||||
"""A notification that some shared state has changed.
|
||||
|
||||
Handlers should re-read authoritative state from the DB rather than
|
||||
relying on fields in the event — the event is just a trigger.
|
||||
"""
|
||||
|
||||
event_type: str # handler key, e.g. "langfuse_credentials"
|
||||
action: str # "update" or "delete"
|
||||
org_id: str = ""
|
||||
timestamp: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.timestamp is None:
|
||||
from datetime import UTC, datetime
|
||||
|
||||
self.timestamp = datetime.now(UTC).isoformat()
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(asdict(self))
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, data: str) -> Optional["WorkerSyncEvent"]:
|
||||
try:
|
||||
return cls(**json.loads(data))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse worker sync event: {e}, data: {data}")
|
||||
return None
|
||||
Loading…
Add table
Add a link
Reference in a new issue