feat: implement Redis heartbeat tracking for connector indexing tasks and update stale notification cleanup logic

This commit is contained in:
Anish Sarkar 2026-02-02 00:18:47 +05:30
parent 085653d3e3
commit 05d1d6ac04
2 changed files with 139 additions and 69 deletions

View file

@ -19,10 +19,12 @@ Non-OAuth connectors (BookStack, GitHub, etc.) are limited to one per search spa
""" """
import logging import logging
import os
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from typing import Any from typing import Any
import pytz import pytz
import redis
from dateutil.parser import isoparse from dateutil.parser import isoparse
from fastapi import APIRouter, Body, Depends, HTTPException, Query from fastapi import APIRouter, Body, Depends, HTTPException, Query
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, Field, ValidationError
@ -78,6 +80,27 @@ from app.utils.rbac import check_permission
# Set up logging # Set up logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Redis client for heartbeat tracking
_heartbeat_redis_client: redis.Redis | None = None
# Redis key TTL - notification is stale if no heartbeat in this time
HEARTBEAT_TTL_SECONDS = 120 # 2 minutes
# Upper bound (seconds) for connecting to Redis and for waiting on a reply.
# The client is synchronous and is called from async indexing code; without a
# bound, an unreachable Redis would stall the caller instead of raising an
# error that the heartbeat call sites catch and log.
_HEARTBEAT_REDIS_TIMEOUT_SECONDS = 5


def get_heartbeat_redis_client() -> redis.Redis:
    """Return the shared Redis client used for heartbeat tracking.

    The client is created lazily on the first call from the
    ``CELERY_BROKER_URL`` environment variable (falling back to
    ``redis://localhost:6379/0``) and cached in the module-level
    ``_heartbeat_redis_client`` so later calls reuse it. Responses are
    decoded to ``str``.

    Returns:
        The cached ``redis.Redis`` instance.
    """
    global _heartbeat_redis_client
    if _heartbeat_redis_client is None:
        redis_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
        _heartbeat_redis_client = redis.from_url(
            redis_url,
            decode_responses=True,
            # Fail fast rather than hang when Redis is down or unreachable.
            socket_connect_timeout=_HEARTBEAT_REDIS_TIMEOUT_SECONDS,
            socket_timeout=_HEARTBEAT_REDIS_TIMEOUT_SECONDS,
        )
    return _heartbeat_redis_client
def _get_heartbeat_key(notification_id: int) -> str:
"""Generate Redis key for notification heartbeat."""
return f"indexing:heartbeat:{notification_id}"
router = APIRouter() router = APIRouter()
@ -1200,6 +1223,16 @@ async def _run_indexing_with_notifications(
) )
) )
# Set initial Redis heartbeat for stale detection
if notification:
try:
heartbeat_key = _get_heartbeat_key(notification.id)
get_heartbeat_redis_client().setex(
heartbeat_key, HEARTBEAT_TTL_SECONDS, "0"
)
except Exception as e:
logger.warning(f"Failed to set initial Redis heartbeat: {e}")
# Update notification to fetching stage # Update notification to fetching stage
if notification: if notification:
await NotificationService.connector_indexing.notify_indexing_progress( await NotificationService.connector_indexing.notify_indexing_progress(
@ -1241,6 +1274,17 @@ async def _run_indexing_with_notifications(
current_indexed_count = indexed_count current_indexed_count = indexed_count
if notification: if notification:
try: try:
# Set Redis heartbeat key with TTL (fast, for stale detection)
heartbeat_key = _get_heartbeat_key(notification.id)
get_heartbeat_redis_client().setex(
heartbeat_key, HEARTBEAT_TTL_SECONDS, str(indexed_count)
)
except Exception as e:
# Don't let Redis errors break the indexing
logger.warning(f"Failed to set Redis heartbeat: {e}")
try:
# Still update DB notification for progress display
await session.refresh(notification) await session.refresh(notification)
await ( await (
NotificationService.connector_indexing.notify_indexing_progress( NotificationService.connector_indexing.notify_indexing_progress(
@ -1473,6 +1517,14 @@ async def _run_indexing_with_notifications(
) )
except Exception as notif_error: except Exception as notif_error:
logger.error(f"Failed to update notification: {notif_error!s}") logger.error(f"Failed to update notification: {notif_error!s}")
finally:
# Clean up Redis heartbeat key when task completes (success or failure)
if notification:
try:
heartbeat_key = _get_heartbeat_key(notification.id)
get_heartbeat_redis_client().delete(heartbeat_key)
except Exception:
pass # Ignore cleanup errors - key will expire anyway
async def run_notion_indexing_with_new_session( async def run_notion_indexing_with_new_session(

View file

@ -1,18 +1,25 @@
"""Celery task to detect and mark stale connector indexing notifications as failed. """Celery task to detect and mark stale connector indexing notifications as failed.
This task runs periodically (every 5 minutes by default) to find notifications This task runs periodically (every 5 minutes by default) to find notifications
that are stuck in "in_progress" status but haven't received a heartbeat update that are stuck in "in_progress" status but don't have an active Redis heartbeat key.
in the configured timeout period. These are marked as "failed" to prevent the These are marked as "failed" to prevent the frontend from showing a perpetual "syncing" state.
frontend from showing a perpetual "syncing" state.
Detection mechanism:
- Active indexing tasks set a Redis key with TTL (2 minutes) as a heartbeat
- If the task crashes, the Redis key expires automatically
- This cleanup task checks for in-progress notifications without a Redis heartbeat key
- Such notifications are marked as failed with a single batch UPDATE statement (one database round trip)
""" """
import json
import logging import logging
from datetime import UTC, datetime, timedelta import os
from datetime import UTC, datetime
from sqlalchemy import and_ import redis
from sqlalchemy import and_, text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy.pool import NullPool from sqlalchemy.pool import NullPool
from app.celery_app import celery_app from app.celery_app import celery_app
@ -21,10 +28,22 @@ from app.db import Notification
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Timeout in minutes - notifications without heartbeat for this long are marked as failed # Redis client for checking heartbeats
# Should be longer than HEARTBEAT_INTERVAL_SECONDS (30s) * a reasonable number of missed heartbeats _redis_client: redis.Redis | None = None
# 5 minutes = 10 missed heartbeats, which is a reasonable threshold
STALE_NOTIFICATION_TIMEOUT_MINUTES = 5
# Upper bound (seconds) for connecting to Redis and for waiting on a reply.
# Without it, an unreachable Redis would block the cleanup task instead of
# raising an error the task's exception handling can deal with.
_REDIS_TIMEOUT_SECONDS = 5


def get_redis_client() -> redis.Redis:
    """Return the shared Redis client used for heartbeat checking.

    The client is created lazily on the first call from the
    ``CELERY_BROKER_URL`` environment variable (falling back to
    ``redis://localhost:6379/0``) and cached in the module-level
    ``_redis_client`` so later calls reuse it. Responses are decoded to
    ``str``.

    Returns:
        The cached ``redis.Redis`` instance.
    """
    global _redis_client
    if _redis_client is None:
        redis_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
        _redis_client = redis.from_url(
            redis_url,
            decode_responses=True,
            # Fail fast rather than hang when Redis is down or unreachable.
            socket_connect_timeout=_REDIS_TIMEOUT_SECONDS,
            socket_timeout=_REDIS_TIMEOUT_SECONDS,
        )
    return _redis_client
def _get_heartbeat_key(notification_id: int) -> str:
"""Generate Redis key for notification heartbeat."""
return f"indexing:heartbeat:{notification_id}"
def get_celery_session_maker(): def get_celery_session_maker():
@ -45,9 +64,9 @@ def cleanup_stale_indexing_notifications_task():
This task finds notifications that: This task finds notifications that:
- Have type = 'connector_indexing' - Have type = 'connector_indexing'
- Have metadata.status = 'in_progress' - Have metadata.status = 'in_progress'
- Have updated_at older than STALE_NOTIFICATION_TIMEOUT_MINUTES - Do NOT have a corresponding Redis heartbeat key (meaning task crashed)
And marks them as failed with an appropriate error message. And marks them as failed with O(1) batch UPDATE.
""" """
import asyncio import asyncio
@ -61,84 +80,83 @@ def cleanup_stale_indexing_notifications_task():
async def _cleanup_stale_notifications(): async def _cleanup_stale_notifications():
"""Find and mark stale connector indexing notifications as failed.""" """Find and mark stale connector indexing notifications as failed.
Uses Redis TTL-based detection:
1. Find all in-progress notifications
2. Check which ones are missing their Redis heartbeat key
3. Mark those as failed with O(1) batch UPDATE using JSONB || operator
"""
async with get_celery_session_maker()() as session: async with get_celery_session_maker()() as session:
try: try:
# Calculate the cutoff time # Find all in-progress connector indexing notifications
cutoff_time = datetime.now(UTC) - timedelta(
minutes=STALE_NOTIFICATION_TIMEOUT_MINUTES
)
# Find stale notifications:
# - type = 'connector_indexing'
# - metadata->>'status' = 'in_progress'
# - updated_at < cutoff_time
result = await session.execute( result = await session.execute(
select(Notification).filter( select(Notification.id).where(
and_( and_(
Notification.type == "connector_indexing", Notification.type == "connector_indexing",
Notification.notification_metadata["status"].astext Notification.notification_metadata["status"].astext
== "in_progress", == "in_progress",
Notification.updated_at < cutoff_time,
) )
) )
) )
stale_notifications = result.scalars().all() in_progress_ids = [row[0] for row in result.fetchall()]
if not stale_notifications: if not in_progress_ids:
logger.debug("No stale connector indexing notifications found") logger.debug("No in-progress connector indexing notifications found")
return
# Check which ones are missing heartbeat keys in Redis
redis_client = get_redis_client()
stale_notification_ids = []
for notification_id in in_progress_ids:
heartbeat_key = _get_heartbeat_key(notification_id)
if not redis_client.exists(heartbeat_key):
stale_notification_ids.append(notification_id)
if not stale_notification_ids:
logger.debug(
f"All {len(in_progress_ids)} in-progress notifications have active Redis heartbeats"
)
return return
logger.warning( logger.warning(
f"Found {len(stale_notifications)} stale connector indexing notifications " f"Found {len(stale_notification_ids)} stale connector indexing notifications "
f"(no heartbeat for >{STALE_NOTIFICATION_TIMEOUT_MINUTES} minutes)" f"(no Redis heartbeat key): {stale_notification_ids}"
) )
# Mark each stale notification as failed # O(1) Batch UPDATE using JSONB || operator
for notification in stale_notifications: # This merges the update data into existing notification_metadata
try: # Also updates title and message for proper UI display
# Get current indexed count from metadata if available error_message = (
indexed_count = notification.notification_metadata.get( "Something went wrong while syncing your content. Please retry."
"indexed_count", 0 )
)
connector_name = notification.notification_metadata.get(
"connector_name", "Unknown"
)
# Calculate how long it's been stale update_data = {
stale_duration = datetime.now(UTC) - notification.updated_at "status": "failed",
stale_minutes = int(stale_duration.total_seconds() / 60) "completed_at": datetime.now(UTC).isoformat(),
"error_message": error_message,
"sync_stage": "failed",
}
# Update notification metadata await session.execute(
notification.notification_metadata["status"] = "failed" text("""
notification.notification_metadata["completed_at"] = datetime.now( UPDATE notifications
UTC SET metadata = metadata || CAST(:update_json AS jsonb),
).isoformat() title = 'Failed: ' || COALESCE(metadata->>'connector_name', 'Connector'),
notification.notification_metadata["error_message"] = ( message = :display_message
f"Indexing task appears to have crashed or timed out. " WHERE id = ANY(:ids)
f"No activity detected for {stale_minutes} minutes. " """),
f"Please try syncing again." {
) "update_json": json.dumps(update_data),
"display_message": f"{error_message}",
"ids": stale_notification_ids,
},
)
# Flag the JSONB column as modified for SQLAlchemy to detect the change
flag_modified(notification, "notification_metadata")
logger.info(
f"Marking notification {notification.id} for connector '{connector_name}' as failed "
f"(stale for {stale_minutes} minutes, indexed {indexed_count} items before failure)"
)
except Exception as e:
logger.error(
f"Error marking notification {notification.id} as failed: {e!s}",
exc_info=True,
)
continue
# Commit all changes
await session.commit() await session.commit()
logger.info( logger.info(
f"Successfully marked {len(stale_notifications)} stale notifications as failed" f"Successfully marked {len(stale_notification_ids)} stale notifications as failed (batch UPDATE)"
) )
except Exception as e: except Exception as e: