feat: add worker sync events

Add a worker sync event so that runtime updates made on one worker propagate to the other workers via Redis pub/sub in multi-worker deployments.
This commit is contained in:
Abhishek Kumar 2026-04-04 14:26:47 +05:30
parent 56763a4527
commit 03df5595c3
18 changed files with 446 additions and 113 deletions

View file

@ -223,6 +223,36 @@ async def load_all_org_langfuse_credentials():
logger.info(f"Loaded Langfuse credentials for {len(configs)} org(s)")
async def handle_langfuse_sync(event):
    """Worker sync handler: bring one org's Langfuse exporter in line with the DB.

    A ``"delete"`` action unregisters the org's credentials straight away.
    Any other action re-reads the org's Langfuse configuration from the
    database and registers it; if nothing usable is stored, the credentials
    are unregistered instead.

    Args:
        event: The sync event; its ``org_id`` and ``action`` attributes are read.
    """
    from api.db import db_client
    from api.enums import OrganizationConfigurationKey

    org_id = event.org_id
    logger.info(
        f"handle_langfuse_sync for org_id: {event.org_id} action: {event.action}"
    )

    if event.action != "delete":
        config_key = OrganizationConfigurationKey.LANGFUSE_CREDENTIALS.value
        stored = await db_client.get_configuration(org_id, config_key)
        if stored and stored.value:
            register_org_langfuse_credentials(
                org_id=org_id,
                host=stored.value.get("host"),
                public_key=stored.value.get("public_key"),
                secret_key=stored.value.get("secret_key"),
            )
            return
        # Falling through: credentials were saved then deleted before we got
        # the event, so there is nothing left to register.

    unregister_org_langfuse_credentials(org_id)
def get_trace_url(trace_id: str, org_id=None) -> str | None:
"""Build a Langfuse trace URL, using org-specific host when available."""
if org_id is None:

View file

View file

@ -0,0 +1,114 @@
"""Worker sync manager for cross-worker state propagation.
Each FastAPI worker both publishes and listens on a single Redis pub/sub
channel. When shared state changes (e.g. Langfuse credentials), the worker
that handled the mutation broadcasts a lightweight event. Every worker
(including the sender) receives it and runs the registered handler, which
re-reads authoritative state from the DB.
"""
import asyncio
from typing import Awaitable, Callable, Dict
import redis.asyncio as aioredis
from loguru import logger
from api.enums import RedisChannel
from api.services.worker_sync.protocol import WorkerSyncEvent
# Shape of a sync handler: an async callable that receives the decoded
# WorkerSyncEvent and returns nothing.
SyncHandler = Callable[[WorkerSyncEvent], Awaitable[None]]
class WorkerSyncManager:
    """Propagates state changes across FastAPI workers via Redis pub/sub.

    Typical lifecycle: ``register()`` the handlers, ``await start()``, call
    ``broadcast()`` whenever shared state changes, and ``await stop()`` on
    shutdown. Events are dispatched one at a time, in arrival order, to the
    single handler registered for their ``event_type``.
    """

    def __init__(self, redis_url: str):
        """Remember the Redis URL; no connection is opened until ``start()``."""
        self._redis_url = redis_url
        # event_type -> handler; at most one handler per event type.
        self._handlers: Dict[str, SyncHandler] = {}
        self._redis: aioredis.Redis | None = None
        self._pubsub: aioredis.client.PubSub | None = None
        self._listener_task: asyncio.Task | None = None

    def register(self, event_type: str, handler: SyncHandler):
        """Register a handler for an event type. Call before start().

        Registering the same ``event_type`` again replaces the earlier handler.
        """
        self._handlers[event_type] = handler
        logger.info(f"Worker sync handler registered: {event_type}")

    async def broadcast(self, event_type: str, action: str, org_id: str = ""):
        """Publish an event to all workers (including self).

        If the manager is not running (never started, or already stopped) a
        warning is logged and nothing is published.
        """
        if not self._redis:
            logger.warning("WorkerSyncManager not started, skipping broadcast")
            return
        event = WorkerSyncEvent(event_type=event_type, action=action, org_id=org_id)
        await self._redis.publish(RedisChannel.WORKER_SYNC.value, event.to_json())
        logger.debug(f"Broadcast worker sync: {event_type}/{action} org={org_id}")

    async def start(self):
        """Open a dedicated Redis connection and start the background listener.

        Calling ``start()`` on an already-started manager is a no-op, so the
        existing connection and listener task are never orphaned. If the
        subscription fails, the freshly opened connection is closed before the
        error propagates and the manager stays in the "not started" state.
        """
        if self._redis is not None:
            logger.warning("WorkerSyncManager already started, ignoring start()")
            return

        client = await aioredis.from_url(self._redis_url, decode_responses=True)
        pubsub = client.pubsub()
        try:
            await pubsub.subscribe(RedisChannel.WORKER_SYNC.value)
        except Exception:
            # Do not leak the connection when we cannot subscribe.
            try:
                await pubsub.close()
            finally:
                await client.close()
            raise

        # Publish the state only once everything is usable.
        self._redis = client
        self._pubsub = pubsub
        self._listener_task = asyncio.create_task(self._listen())
        logger.info("WorkerSyncManager started")

    async def stop(self):
        """Cancel the listener and close the Redis connection.

        The manager's references are cleared up front, so after ``stop()``
        (even a failing one) ``broadcast()`` takes its "not started" path
        instead of publishing on a closed client, and ``start()`` may be
        called again. The Redis client is closed even when tearing down the
        pub/sub subscription raises; that error still propagates.
        """
        listener, self._listener_task = self._listener_task, None
        pubsub, self._pubsub = self._pubsub, None
        client, self._redis = self._redis, None

        if listener:
            listener.cancel()
            try:
                await listener
            except asyncio.CancelledError:
                pass
        try:
            if pubsub:
                await pubsub.unsubscribe(RedisChannel.WORKER_SYNC.value)
                await pubsub.close()
        finally:
            if client:
                await client.close()
        logger.info("WorkerSyncManager stopped")

    async def _listen(self):
        """Background loop: receive events and dispatch to handlers.

        A failing handler is logged and does not stop the loop. An error from
        the pub/sub stream itself is logged and ends the loop; nothing in this
        class restarts it afterwards.
        """
        pubsub = self._pubsub
        if pubsub is None:
            # Nothing to listen on (start() has not completed).
            return
        try:
            async for message in pubsub.listen():
                # Skip subscribe/unsubscribe confirmations and the like.
                if message["type"] != "message":
                    continue
                event = WorkerSyncEvent.from_json(message["data"])
                if not event:
                    # Malformed payload; from_json already logged it.
                    continue
                handler = self._handlers.get(event.event_type)
                if handler is None:
                    logger.warning(
                        f"No handler for worker sync event: {event.event_type}"
                    )
                    continue
                try:
                    await handler(event)
                except Exception:
                    logger.exception(
                        f"Worker sync handler error: {event.event_type}"
                    )
        except asyncio.CancelledError:
            # Cancellation comes from stop(); let it propagate.
            raise
        except Exception:
            logger.exception("Worker sync listener crashed")
# Module-level singleton, initialized in app lifespan.
# Stays None until set_worker_sync_manager() is called; read it through
# get_worker_sync_manager(), which raises instead of returning None.
_manager: WorkerSyncManager | None = None
def get_worker_sync_manager() -> WorkerSyncManager:
    """Return the process-wide WorkerSyncManager.

    Raises:
        RuntimeError: If no manager has been installed yet via
            ``set_worker_sync_manager()`` (i.e. outside the FastAPI lifespan).
    """
    current = _manager
    if current is not None:
        return current
    raise RuntimeError("WorkerSyncManager not initialized")
def set_worker_sync_manager(manager: WorkerSyncManager):
    """Install ``manager`` as the module-level singleton.

    Intended to be called from the app lifespan; any previously installed
    manager is simply replaced.
    """
    global _manager
    _manager = manager

View file

@ -0,0 +1,48 @@
"""Worker sync event protocol.
Defines the message format for cross-worker state synchronization via Redis pub/sub.
"""
import json
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Optional
from loguru import logger
class WorkerSyncEventType(str, Enum):
    """Known kinds of worker sync event.

    Members are also plain strings, so each compares equal to its value.
    """

    # An org's Langfuse credentials changed.
    LANGFUSE_CREDENTIALS = "langfuse_credentials"
@dataclass
class WorkerSyncEvent:
    """A notification that some shared state has changed.

    Handlers should re-read authoritative state from the DB rather than
    relying on fields in the event; the event is just a trigger.
    """

    event_type: str  # handler key, e.g. "langfuse_credentials"
    action: str  # "update" or "delete"
    org_id: str = ""
    timestamp: Optional[str] = None  # ISO-8601 UTC; filled in when omitted

    def __post_init__(self):
        """Stamp the event with the current UTC time if none was supplied."""
        if self.timestamp is None:
            # timezone.utc is the same object as datetime.UTC but also exists
            # on Python versions older than 3.11.
            from datetime import datetime, timezone

            self.timestamp = datetime.now(timezone.utc).isoformat()

    def to_json(self) -> str:
        """Serialize all fields of the event to a JSON object string."""
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, data: str) -> Optional["WorkerSyncEvent"]:
        """Parse an event from its JSON form.

        Keys that are not fields of this class are ignored, so a worker on an
        older version keeps processing events published by a newer one that
        has added fields.

        Returns:
            The parsed event, or ``None`` (after logging an error) when the
            payload is not valid JSON, is not a JSON object, or lacks a
            required field.
        """
        from dataclasses import fields

        try:
            payload = json.loads(data)
            if not isinstance(payload, dict):
                raise TypeError(
                    f"expected a JSON object, got {type(payload).__name__}"
                )
            known = {f.name for f in fields(cls)}
            return cls(**{k: v for k, v in payload.items() if k in known})
        except Exception as e:
            # Best effort: a bad message must never take the listener down.
            logger.error(f"Failed to parse worker sync event: {e}, data: {data}")
            return None