feat: add worker sync events

Add a worker sync event so that runtime updates made on one worker propagate to all other workers via Redis pub/sub in multi-worker deployments.
This commit is contained in:
Abhishek Kumar 2026-04-04 14:26:47 +05:30
parent 56763a4527
commit 03df5595c3
18 changed files with 446 additions and 113 deletions

View file

@ -20,15 +20,15 @@ api/
## Where to Find Things
| Looking for... | Go to... |
|----------------|----------|
| API endpoints | `routes/` - each file is a router module, aggregated in `routes/main.py` |
| Business logic | `services/` - organized by domain (telephony, workflow, campaign, etc.) |
| Database models | `db/models.py` |
| Database queries | `db/*_client.py` files (repository pattern) |
| Request/response types | `schemas/` |
| Background tasks | `tasks/` - uses ARQ for async job processing |
| Environment config | `constants.py` |
| Looking for... | Go to... |
| ---------------------- | ------------------------------------------------------------------------ |
| API endpoints | `routes/` - each file is a router module, aggregated in `routes/main.py` |
| Business logic | `services/` - organized by domain (telephony, workflow, campaign, etc.) |
| Database models | `db/models.py` |
| Database queries | `db/*_client.py` files (repository pattern) |
| Request/response types | `schemas/` |
| Background tasks | `tasks/` - uses ARQ for async job processing |
| Environment config | `constants.py` |
## API Structure
@ -43,6 +43,10 @@ api/
./scripts/migrate.sh # Run migrations
```
## Cross-Worker State Sync
When an API endpoint updates in-memory state (e.g. cached credentials, config objects), that change only affects the worker process that handled the request. With multiple FastAPI workers, **use `WorkerSyncManager`** (`services/worker_sync/`) to propagate changes to all workers via Redis pub/sub instead of updating local state directly.
## Development
```bash

View file

@ -26,8 +26,17 @@ from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from api.constants import REDIS_URL
from api.routes.main import router as main_router
from api.services.pipecat.tracing_config import load_all_org_langfuse_credentials
from api.services.pipecat.tracing_config import (
handle_langfuse_sync,
load_all_org_langfuse_credentials,
)
from api.services.worker_sync.manager import (
WorkerSyncManager,
set_worker_sync_manager,
)
from api.services.worker_sync.protocol import WorkerSyncEventType
from api.tasks.arq import get_arq_redis
API_PREFIX = "/api/v1"
@ -42,10 +51,19 @@ async def lifespan(app: FastAPI):
# before any pipeline runs, without per-call DB lookups.
await load_all_org_langfuse_credentials()
# Start cross-worker sync manager so config changes propagate to all workers
sync_manager = WorkerSyncManager(REDIS_URL)
sync_manager.register(
WorkerSyncEventType.LANGFUSE_CREDENTIALS, handle_langfuse_sync
)
await sync_manager.start()
set_worker_sync_manager(sync_manager)
yield # Run app
# Shutdown sequence - this runs when FastAPI is shutting down
logger.info("Starting graceful shutdown...")
await sync_manager.stop()
app = FastAPI(

View file

@ -493,6 +493,7 @@ class KnowledgeBaseClient(BaseDBClient):
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
".doc": "application/msword",
".txt": "text/plain",
".json": "application/json",
".html": "text/html",
".md": "text/markdown",
}

View file

@ -103,6 +103,7 @@ class RedisChannel(Enum):
"""Redis pub/sub channel names"""
CAMPAIGN_EVENTS = "campaign_events"
WORKER_SYNC = "worker_sync"
class TriggerState(Enum):

View file

@ -24,7 +24,8 @@ from api.schemas.telephony_config import (
)
from api.services.auth.depends import get_user
from api.services.configuration.masking import is_mask_of, mask_key
from api.services.pipecat.tracing_config import unregister_org_langfuse_credentials
from api.services.worker_sync.manager import get_worker_sync_manager
from api.services.worker_sync.protocol import WorkerSyncEventType
router = APIRouter(prefix="/organizations", tags=["organizations"])
@ -341,14 +342,11 @@ async def save_langfuse_credentials(
config_value,
)
# Update the in-memory OTEL exporter so new traces route immediately
from api.services.pipecat.tracing_config import register_org_langfuse_credentials
register_org_langfuse_credentials(
# Broadcast to all workers so every process updates its in-memory exporter
await get_worker_sync_manager().broadcast(
WorkerSyncEventType.LANGFUSE_CREDENTIALS,
action="update",
org_id=user.selected_organization_id,
host=config_value["host"],
public_key=config_value["public_key"],
secret_key=config_value["secret_key"],
)
return {"message": "Langfuse credentials saved successfully"}
@ -368,8 +366,12 @@ async def delete_langfuse_credentials(user: UserModel = Depends(get_user)):
if not deleted:
raise HTTPException(status_code=404, detail="No Langfuse credentials found")
# Remove the in-memory OTEL exporter so traces fall back to default
unregister_org_langfuse_credentials(user.selected_organization_id)
# Broadcast to all workers so every process removes its in-memory exporter
await get_worker_sync_manager().broadcast(
WorkerSyncEventType.LANGFUSE_CREDENTIALS,
action="delete",
org_id=user.selected_organization_id,
)
return {"message": "Langfuse credentials deleted successfully"}

View file

@ -223,6 +223,36 @@ async def load_all_org_langfuse_credentials():
logger.info(f"Loaded Langfuse credentials for {len(configs)} org(s)")
async def handle_langfuse_sync(event):
    """Worker-sync handler: align one org's Langfuse exporter with the DB.

    The event only identifies the organization and the action; the
    credentials themselves are always looked up in the database.

    Args:
        event: Sync event exposing ``org_id`` and ``action``. An action of
            ``"delete"`` unregisters the org without a DB lookup; any other
            action triggers a lookup of the stored configuration.
    """
    from api.db import db_client
    from api.enums import OrganizationConfigurationKey

    org_id = event.org_id
    logger.info(
        f"handle_langfuse_sync for org_id: {event.org_id} action: {event.action}"
    )

    stored_credentials = None
    if event.action != "delete":
        config = await db_client.get_configuration(
            org_id, OrganizationConfigurationKey.LANGFUSE_CREDENTIALS.value
        )
        stored_credentials = config.value if config else None

    if not stored_credentials:
        # Either an explicit delete, or the credentials were saved and then
        # removed again before this event was processed.
        unregister_org_langfuse_credentials(org_id)
        return

    register_org_langfuse_credentials(
        org_id=org_id,
        host=stored_credentials.get("host"),
        public_key=stored_credentials.get("public_key"),
        secret_key=stored_credentials.get("secret_key"),
    )
def get_trace_url(trace_id: str, org_id=None) -> str | None:
"""Build a Langfuse trace URL, using org-specific host when available."""
if org_id is None:

View file

View file

@ -0,0 +1,114 @@
"""Worker sync manager for cross-worker state propagation.
Each FastAPI worker both publishes and listens on a single Redis pub/sub
channel. When shared state changes (e.g. Langfuse credentials), the worker
that handled the mutation broadcasts a lightweight event. Every worker
(including the sender) receives it and runs the registered handler, which
re-reads authoritative state from the DB.
"""
import asyncio
from typing import Awaitable, Callable, Dict
import redis.asyncio as aioredis
from loguru import logger
from api.enums import RedisChannel
from api.services.worker_sync.protocol import WorkerSyncEvent
SyncHandler = Callable[[WorkerSyncEvent], Awaitable[None]]
class WorkerSyncManager:
    """Propagates state changes across FastAPI workers via Redis pub/sub.

    Lifecycle: ``register()`` handlers, ``await start()``, then
    ``await broadcast(...)`` as needed, and finally ``await stop()``.
    Handlers run one at a time inside the background listener task, in the
    order events are received.
    """

    # Bounds (in seconds) for the delay between listener recovery attempts.
    _RETRY_DELAY_INITIAL = 1.0
    _RETRY_DELAY_MAX = 30.0

    def __init__(self, redis_url: str):
        """Store configuration; no connection is opened until start()."""
        self._redis_url = redis_url
        self._handlers: Dict[str, SyncHandler] = {}
        self._redis: aioredis.Redis | None = None
        self._pubsub: aioredis.client.PubSub | None = None
        self._listener_task: asyncio.Task | None = None

    def register(self, event_type: str, handler: SyncHandler):
        """Register a handler for an event type. Call before start().

        A later registration for the same ``event_type`` replaces the
        earlier one.
        """
        self._handlers[event_type] = handler
        logger.info(f"Worker sync handler registered: {event_type}")

    async def broadcast(self, event_type: str, action: str, org_id: str = ""):
        """Publish an event to all workers (including self).

        If the manager has not been started (or has been stopped) the event
        is dropped with a warning instead of raising.
        """
        if not self._redis:
            logger.warning("WorkerSyncManager not started, skipping broadcast")
            return
        event = WorkerSyncEvent(event_type=event_type, action=action, org_id=org_id)
        await self._redis.publish(RedisChannel.WORKER_SYNC.value, event.to_json())
        logger.debug(f"Broadcast worker sync: {event_type}/{action} org={org_id}")

    async def start(self):
        """Open a dedicated Redis connection and start the background listener."""
        self._redis = await aioredis.from_url(self._redis_url, decode_responses=True)
        self._pubsub = self._redis.pubsub()
        await self._pubsub.subscribe(RedisChannel.WORKER_SYNC.value)
        self._listener_task = asyncio.create_task(self._listen())
        logger.info("WorkerSyncManager started")

    async def stop(self):
        """Cancel the listener and close the Redis connection.

        Safe to call more than once. Every resource is released even if an
        earlier cleanup step raises; the first such error still propagates
        to the caller once cleanup has been attempted.
        """
        task, pubsub, client = self._listener_task, self._pubsub, self._redis
        # Drop the references up front so a late broadcast() hits the
        # "not started" guard rather than a closed client, and so a second
        # stop() is a no-op.
        self._listener_task = None
        self._pubsub = None
        self._redis = None

        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        try:
            if pubsub:
                try:
                    await pubsub.unsubscribe(RedisChannel.WORKER_SYNC.value)
                finally:
                    await pubsub.close()
        finally:
            if client:
                await client.close()
        logger.info("WorkerSyncManager stopped")

    async def _listen(self):
        """Background loop: receive events and dispatch to handlers.

        A Redis error no longer ends the loop for good: the failure is
        logged, the loop waits with capped exponential backoff, then
        re-issues SUBSCRIBE and resumes listening. Events published while
        the listener was disconnected are not replayed.
        """
        pubsub = self._pubsub
        channel = RedisChannel.WORKER_SYNC.value
        delay = self._RETRY_DELAY_INITIAL
        resubscribe = False

        while True:
            try:
                if resubscribe:
                    # NOTE(review): relies on redis-py reconnecting a dropped
                    # connection when a command is sent — confirm against the
                    # pinned redis-py version.
                    await pubsub.subscribe(channel)
                    resubscribe = False
                    logger.info("Worker sync listener re-subscribed")
                async for message in pubsub.listen():
                    delay = self._RETRY_DELAY_INITIAL
                    await self._dispatch(message)
                # listen() finishing normally means nothing is left to read.
                return
            except asyncio.CancelledError:
                raise
            except Exception:
                logger.exception("Worker sync listener crashed")

            resubscribe = True
            await asyncio.sleep(delay)
            delay = min(delay * 2, self._RETRY_DELAY_MAX)

    async def _dispatch(self, message):
        """Decode one pub/sub message and run the handler registered for it."""
        # Subscribe/unsubscribe confirmations arrive on the same stream.
        if message["type"] != "message":
            return
        event = WorkerSyncEvent.from_json(message["data"])
        if not event:
            return
        handler = self._handlers.get(event.event_type)
        if not handler:
            logger.warning(f"No handler for worker sync event: {event.event_type}")
            return
        try:
            await handler(event)
        except Exception:
            # One failing handler must not take the listener down.
            logger.exception(f"Worker sync handler error: {event.event_type}")
# Module-level singleton, initialized in app lifespan
_manager: WorkerSyncManager | None = None
def get_worker_sync_manager() -> WorkerSyncManager:
    """Return the process-wide WorkerSyncManager.

    Raises:
        RuntimeError: If no manager has been installed yet through
            set_worker_sync_manager().
    """
    manager = _manager
    if manager is not None:
        return manager
    raise RuntimeError("WorkerSyncManager not initialized")
def set_worker_sync_manager(manager: WorkerSyncManager) -> None:
    """Set the module-level singleton. Called from the app lifespan.

    Replaces whatever manager was previously installed; the previous
    instance is not stopped here.
    """
    # Rebind the module global that get_worker_sync_manager() reads.
    global _manager
    _manager = manager

View file

@ -0,0 +1,48 @@
"""Worker sync event protocol.
Defines the message format for cross-worker state synchronization via Redis pub/sub.
"""
import json
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Optional
from loguru import logger
class WorkerSyncEventType(str, Enum):
    """Types of worker sync events.

    Members are also ``str`` instances, so each one compares equal to its
    string value.
    """

    # An organization's Langfuse credentials changed.
    LANGFUSE_CREDENTIALS = "langfuse_credentials"
@dataclass
class WorkerSyncEvent:
    """A notification that some shared state has changed.

    Handlers should re-read authoritative state from the DB rather than
    relying on fields in the event — the event is just a trigger.
    """

    event_type: str  # handler key, e.g. "langfuse_credentials"
    action: str  # "update" or "delete"
    org_id: str = ""
    timestamp: Optional[str] = None  # ISO-8601 UTC; filled in when omitted

    def __post_init__(self):
        """Stamp the event with the current UTC time if none was given."""
        if self.timestamp is None:
            from datetime import UTC, datetime

            self.timestamp = datetime.now(UTC).isoformat()

    def to_json(self) -> str:
        """Serialize all fields to a JSON object string."""
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, data: str) -> Optional["WorkerSyncEvent"]:
        """Parse an event from its JSON form.

        Keys that are not fields of this dataclass are ignored, so a payload
        carrying additional fields is still accepted instead of being
        dropped.

        Returns:
            The parsed event, or ``None`` (after logging an error) when the
            payload is not valid JSON, is not a JSON object, or lacks a
            required field. This method never raises.
        """
        from dataclasses import fields

        try:
            payload = json.loads(data)
            if not isinstance(payload, dict):
                raise TypeError(
                    f"expected a JSON object, got {type(payload).__name__}"
                )
            known = {f.name for f in fields(cls)}
            return cls(**{k: v for k, v in payload.items() if k in known})
        except Exception as e:
            # Deliberately broad: this is the parse boundary, and callers
            # rely on a None result rather than an exception.
            logger.error(f"Failed to parse worker sync event: {e}, data: {data}")
            return None

View file

@ -1,5 +1,6 @@
"""ARQ background task for processing knowledge base documents."""
import json
import os
import tempfile
@ -163,84 +164,148 @@ async def process_knowledge_base_document(
base_url=embeddings_base_url,
)
# Step 1: Convert document with docling
logger.info("Converting document with docling")
converter = DocumentConverter()
conversion_result = converter.convert(temp_file_path)
doc = conversion_result.document
# Store docling metadata
docling_metadata = {
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
"document_type": type(doc).__name__,
}
# Step 2: Initialize tokenizer for chunking
# Step 1: Initialize tokenizer for chunking
logger.info(
f"Loading tokenizer: {TOKENIZER_MODEL} with max_tokens={max_tokens}"
)
hf_tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL)
tokenizer = HuggingFaceTokenizer(
tokenizer=AutoTokenizer.from_pretrained(TOKENIZER_MODEL),
tokenizer=hf_tokenizer,
max_tokens=max_tokens,
)
# Step 3: Initialize chunker
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
chunker = HybridChunker(tokenizer=tokenizer)
# Step 4: Chunk the document
logger.info(f"Chunking document with max_tokens={max_tokens}")
chunks = list(chunker.chunk(dl_doc=doc))
total_chunks = len(chunks)
logger.info(f"Generated {total_chunks} chunks")
# Step 5: Process each chunk
chunk_texts = []
chunk_records = []
token_counts = []
for i, chunk in enumerate(chunks):
chunk_text = chunk.text
contextualized_text = chunker.contextualize(chunk=chunk)
# Check if file is a plain text format that docling doesn't support
plain_text_extensions = {".txt", ".json"}
if file_extension.lower() in plain_text_extensions:
# Read text content directly
logger.info(f"Reading {file_extension} file directly (bypassing docling)")
with open(temp_file_path, "r", encoding="utf-8") as f:
raw_content = f.read()
# Calculate token count
text_to_tokenize = (
contextualized_text if contextualized_text else chunk_text
)
token_count = len(
tokenizer.tokenizer.encode(text_to_tokenize, add_special_tokens=False)
)
token_counts.append(token_count)
# For JSON files, pretty-print for better readability
if file_extension.lower() == ".json":
try:
parsed = json.loads(raw_content)
raw_content = json.dumps(parsed, indent=2, ensure_ascii=False)
except json.JSONDecodeError:
logger.warning(
"JSON file is not valid JSON, treating as plain text"
)
# Prepare chunk metadata
chunk_metadata = {}
if hasattr(chunk, "meta") and chunk.meta:
chunk_metadata = {
"doc_items": (
[str(item) for item in chunk.meta.doc_items]
if hasattr(chunk.meta, "doc_items")
else []
),
"headings": (
chunk.meta.headings if hasattr(chunk.meta, "headings") else []
),
}
docling_metadata = {
"num_pages": None,
"document_type": "PlainText",
}
# Create chunk record (without embedding yet)
chunk_record = KnowledgeBaseChunkModel(
document_id=document_id,
organization_id=organization_id,
chunk_text=chunk_text,
contextualized_text=contextualized_text,
chunk_index=i,
chunk_metadata=chunk_metadata,
embedding_model=service.get_model_id(),
embedding_dimension=service.get_embedding_dimension(),
token_count=token_count,
# Token-based chunking for plain text
tokens = hf_tokenizer.encode(raw_content, add_special_tokens=False)
total_tokens = len(tokens)
logger.info(
f"Total tokens in file: {total_tokens}, chunking with max_tokens={max_tokens}"
)
chunk_records.append(chunk_record)
chunk_texts.append(text_to_tokenize)
start = 0
chunk_index = 0
while start < total_tokens:
end = min(start + max_tokens, total_tokens)
chunk_token_ids = tokens[start:end]
chunk_text = hf_tokenizer.decode(
chunk_token_ids, skip_special_tokens=True
)
token_count = len(chunk_token_ids)
token_counts.append(token_count)
chunk_record = KnowledgeBaseChunkModel(
document_id=document_id,
organization_id=organization_id,
chunk_text=chunk_text,
contextualized_text=chunk_text,
chunk_index=chunk_index,
chunk_metadata={},
embedding_model=service.get_model_id(),
embedding_dimension=service.get_embedding_dimension(),
token_count=token_count,
)
chunk_records.append(chunk_record)
chunk_texts.append(chunk_text)
chunk_index += 1
start = end
total_chunks = len(chunk_records)
logger.info(f"Generated {total_chunks} chunks from plain text")
else:
# Use docling for structured formats (PDF, DOCX, etc.)
logger.info("Converting document with docling")
converter = DocumentConverter()
conversion_result = converter.convert(temp_file_path)
doc = conversion_result.document
docling_metadata = {
"num_pages": len(doc.pages) if hasattr(doc, "pages") else None,
"document_type": type(doc).__name__,
}
# Initialize chunker
logger.info(f"Initializing HybridChunker with max_tokens={max_tokens}")
chunker = HybridChunker(tokenizer=tokenizer)
# Chunk the document
logger.info(f"Chunking document with max_tokens={max_tokens}")
chunks = list(chunker.chunk(dl_doc=doc))
total_chunks = len(chunks)
logger.info(f"Generated {total_chunks} chunks")
# Process each chunk
for i, chunk in enumerate(chunks):
chunk_text = chunk.text
contextualized_text = chunker.contextualize(chunk=chunk)
text_to_tokenize = (
contextualized_text if contextualized_text else chunk_text
)
token_count = len(
tokenizer.tokenizer.encode(
text_to_tokenize, add_special_tokens=False
)
)
token_counts.append(token_count)
chunk_metadata = {}
if hasattr(chunk, "meta") and chunk.meta:
chunk_metadata = {
"doc_items": (
[str(item) for item in chunk.meta.doc_items]
if hasattr(chunk.meta, "doc_items")
else []
),
"headings": (
chunk.meta.headings
if hasattr(chunk.meta, "headings")
else []
),
}
chunk_record = KnowledgeBaseChunkModel(
document_id=document_id,
organization_id=organization_id,
chunk_text=chunk_text,
contextualized_text=contextualized_text,
chunk_index=i,
chunk_metadata=chunk_metadata,
embedding_model=service.get_model_id(),
embedding_dimension=service.get_embedding_dimension(),
token_count=token_count,
)
chunk_records.append(chunk_record)
chunk_texts.append(text_to_tokenize)
# Log chunk statistics
if token_counts: