diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py index 6dfcbff46..7fd866f1c 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py @@ -1,6 +1,8 @@ """Celery tasks for document processing.""" +import asyncio import logging +import os from uuid import UUID from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine @@ -17,6 +19,79 @@ from app.tasks.document_processors import ( logger = logging.getLogger(__name__) +# ===== Redis heartbeat for document processing tasks ===== +# Same mechanism as connector indexing heartbeats (search_source_connectors_routes.py). +# A background coroutine refreshes a Redis key every 60s with a 2-min TTL. +# If the Celery worker crashes, the coroutine dies, the key expires, and the +# stale_notification_cleanup_task detects the missing key and marks the +# notification + document as failed. +_doc_heartbeat_redis = None +HEARTBEAT_TTL_SECONDS = 120 # 2 minutes — same as connector indexing +HEARTBEAT_REFRESH_INTERVAL = 60 # Refresh every 60 seconds + + +def _get_doc_heartbeat_redis(): + """Get Redis client for document processing heartbeat.""" + import redis + + global _doc_heartbeat_redis + if _doc_heartbeat_redis is None: + redis_url = os.getenv( + "REDIS_APP_URL", + os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"), + ) + _doc_heartbeat_redis = redis.from_url(redis_url, decode_responses=True) + return _doc_heartbeat_redis + + +def _get_heartbeat_key(notification_id: int) -> str: + """Generate Redis key for document processing heartbeat. + + Uses same key pattern as connector indexing: indexing:heartbeat:{notification_id} + """ + return f"indexing:heartbeat:{notification_id}" + + +def _start_heartbeat(notification_id: int) -> None: + """Set initial Redis heartbeat key for a document processing task.""" + try: + key = _get_heartbeat_key(notification_id) + _get_doc_heartbeat_redis().setex(key, HEARTBEAT_TTL_SECONDS, "started") + except Exception as e: + logger.warning( + f"Failed to set initial heartbeat for notification {notification_id}: {e}" + ) + + +def _stop_heartbeat(notification_id: int) -> None: + """Delete Redis heartbeat key when task completes (success or failure).""" + try: + key = _get_heartbeat_key(notification_id) + _get_doc_heartbeat_redis().delete(key) + except Exception: + pass # Key will expire on its own + + +async def _run_heartbeat_loop(notification_id: int): + """Background coroutine that refreshes Redis heartbeat every 60 seconds. + + This keeps the heartbeat alive while the task is running. + When the task finishes, this coroutine is cancelled via heartbeat_task.cancel(). + When the worker crashes, this coroutine dies with it and the key expires. + """ + key = _get_heartbeat_key(notification_id) + try: + while True: + await asyncio.sleep(HEARTBEAT_REFRESH_INTERVAL) + try: + _get_doc_heartbeat_redis().setex(key, HEARTBEAT_TTL_SECONDS, "alive") + except Exception as e: + logger.warning( + f"Failed to refresh heartbeat for notification {notification_id}: {e}" + ) + except asyncio.CancelledError: + pass # Normal cancellation when task completes + def get_celery_session_maker(): """ @@ -44,8 +119,6 @@ def process_extension_document_task( search_space_id: ID of the search space user_id: ID of the user """ - import asyncio - # Create a new event loop for this task loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -196,8 +269,6 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st search_space_id: ID of the search space user_id: ID of the user """ - import asyncio - loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -226,6 +297,10 @@ async def _process_youtube_video(url: str, search_space_id: int, user_id: str): ) ) + # Start Redis heartbeat for stale task detection + _start_heartbeat(notification.id) + heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id)) + log_entry = await task_logger.log_task_start( task_name="process_youtube_video", source="document_processor", @@ -243,7 +318,7 @@ async def _process_youtube_video(url: str, search_space_id: int, user_id: str): ) result = await add_youtube_video_document( - session, url, search_space_id, user_id + session, url, search_space_id, user_id, notification=notification ) if result: @@ -307,6 +382,10 @@ async def _process_youtube_video(url: str, search_space_id: int, user_id: str): logger.error(f"Error processing YouTube video: {e!s}") raise + finally: + # Stop heartbeat — key deleted on success, expires on crash + heartbeat_task.cancel() + _stop_heartbeat(notification.id) @celery_app.task(name="process_file_upload", bind=True) @@ -322,8 +401,6 @@ def process_file_upload_task( search_space_id: ID of the search space user_id: ID of the user """ - import asyncio - import os import traceback logger.info( @@ -370,8 +447,6 @@ async def _process_file_upload( file_path: str, filename: str, search_space_id: int, user_id: str ): """Process file upload with new session.""" - import os - from app.tasks.document_processors.file_processors import process_file_in_background logger.info(f"[_process_file_upload] Starting async processing for: {filename}") @@ -404,6 +479,10 @@ async def _process_file_upload( f"[_process_file_upload] Notification created with ID: {notification.id if notification else 'None'}" ) + # Start Redis heartbeat for stale task detection + _start_heartbeat(notification.id) + heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id)) + log_entry = await task_logger.log_task_start( task_name="process_file_upload", source="document_processor", @@ -535,6 +614,10 @@ async def _process_file_upload( ) logger.error(error_message) raise + finally: + # Stop heartbeat — key deleted on success, expires on crash + heartbeat_task.cancel() + _stop_heartbeat(notification.id) @celery_app.task(name="process_file_upload_with_document", bind=True) @@ -560,8 +643,6 @@ def process_file_upload_with_document_task( search_space_id: ID of the search space user_id: ID of the user """ - import asyncio - import os import traceback logger.info( @@ -640,8 +721,6 @@ async def _process_file_with_document( - Processes the file (parsing, embedding, chunking) - Updates document to 'ready' on success or 'failed' on error """ - import os - from app.db import Document, DocumentStatus from app.tasks.document_processors.base import get_current_timestamp from app.tasks.document_processors.file_processors import ( @@ -689,6 +768,19 @@ async def _process_file_with_document( ) ) + # Store document_id in notification metadata so cleanup task can find the document + if notification and notification.notification_metadata is not None: + notification.notification_metadata["document_id"] = document_id + from sqlalchemy.orm.attributes import flag_modified + + flag_modified(notification, "notification_metadata") + await session.commit() + await session.refresh(notification) + + # Start Redis heartbeat for stale task detection + _start_heartbeat(notification.id) + heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id)) + log_entry = await task_logger.log_task_start( task_name="process_file_upload_with_document", source="document_processor", @@ -822,6 +914,10 @@ async def _process_file_with_document( raise finally: + # Stop heartbeat — key deleted on success, expires on crash + heartbeat_task.cancel() + _stop_heartbeat(notification.id) + # Clean up temp file if os.path.exists(temp_path): try: @@ -856,8 +952,6 @@ def process_circleback_meeting_task( search_space_id: ID of the search space connector_id: ID of the Circleback connector (for deletion support) """ - import asyncio - loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -897,6 +991,7 @@ async def _process_circleback_meeting( # Create notification if user_id is available notification = None + heartbeat_task = None if user_id: notification = ( await NotificationService.document_processing.notify_processing_started( @@ -908,6 +1003,12 @@ async def _process_circleback_meeting( ) ) + # Start Redis heartbeat for stale task detection + _start_heartbeat(notification.id) + heartbeat_task = asyncio.create_task( + _run_heartbeat_loop(notification.id) + ) + log_entry = await task_logger.log_task_start( task_name="process_circleback_meeting", source="circleback_webhook", @@ -1000,3 +1101,9 @@ async def _process_circleback_meeting( logger.error(f"Error processing Circleback meeting: {e!s}") raise + finally: + # Stop heartbeat — key deleted on success, expires on crash + if heartbeat_task: + heartbeat_task.cancel() + if notification: + _stop_heartbeat(notification.id) diff --git a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py index ef3a30e43..aebe40b88 100644 --- a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py @@ -1,18 +1,25 @@ -"""Celery task to detect and mark stale connector indexing notifications as failed. +"""Celery task to detect and mark stale notifications as failed. This task runs periodically (every 5 minutes by default) to find notifications that are stuck in "in_progress" status but don't have an active Redis heartbeat key. -These are marked as "failed" to prevent the frontend from showing a perpetual "syncing" state. +These are marked as "failed" to prevent the frontend from showing a perpetual +"syncing" or "processing" state. -Additionally, it cleans up documents stuck in pending/processing state that belong -to connectors with stale notifications. +It handles two notification types: +1. **connector_indexing** — connector sync tasks (Google Calendar, etc.) +2. **document_processing** — manual file uploads, YouTube videos, etc. + +Additionally, it cleans up documents stuck in pending/processing state: +- For connectors: by connector_id +- For non-connector documents (FILE uploads, YouTube): by document_id from notification metadata Detection mechanism: -- Active indexing tasks set a Redis key with TTL (2 minutes) as a heartbeat -- If the task crashes, the Redis key expires automatically +- Active tasks set a Redis key with TTL (2 minutes) as a heartbeat +- A background coroutine refreshes the key every 60 seconds +- If the task/worker crashes, the Redis key expires automatically - This cleanup task checks for in-progress notifications without a Redis heartbeat key - Such notifications are marked as failed with O(1) batch UPDATE -- Documents with pending/processing status for those connectors are also marked as failed +- Associated documents are also marked as failed """ import contextlib @@ -36,8 +43,11 @@ logger = logging.getLogger(__name__) # Redis client for checking heartbeats _redis_client: redis.Redis | None = None -# Error message shown to users when sync is interrupted +# Error messages shown to users when tasks are interrupted STALE_SYNC_ERROR_MESSAGE = "Sync was interrupted unexpectedly. Please retry." +STALE_PROCESSING_ERROR_MESSAGE = ( + "Processing was interrupted unexpectedly. Please retry." +) def get_redis_client() -> redis.Redis: @@ -70,14 +80,13 @@ def get_celery_session_maker(): @celery_app.task(name="cleanup_stale_indexing_notifications") def cleanup_stale_indexing_notifications_task(): """ - Check for stale connector indexing notifications and mark them as failed. + Check for stale notifications and mark them as failed. - This task finds notifications that: - - Have type = 'connector_indexing' - - Have metadata.status = 'in_progress' - - Do NOT have a corresponding Redis heartbeat key (meaning task crashed) + Handles two notification types: + 1. connector_indexing — connector sync tasks + 2. document_processing — manual file uploads, YouTube videos, etc. - And marks them as failed with O(1) batch UPDATE. + Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task. Also marks associated pending/processing documents as failed. """ import asyncio @@ -87,6 +96,7 @@ def cleanup_stale_indexing_notifications_task(): try: loop.run_until_complete(_cleanup_stale_notifications()) + loop.run_until_complete(_cleanup_stale_document_processing_notifications()) finally: loop.close() @@ -269,3 +279,190 @@ async def _cleanup_stuck_documents(session, connector_ids: list[int]): exc_info=True, ) # Don't raise - let the notification cleanup continue even if document cleanup fails + + +# ===== Document Processing Cleanup (FILE uploads, YouTube, etc.) ===== + + +async def _cleanup_stale_document_processing_notifications(): + """Find and mark stale document processing notifications as failed. + + Same Redis heartbeat mechanism as connector indexing cleanup, but for + document_processing type notifications (FILE uploads, YouTube videos, etc.). + + For each stale notification that contains a document_id in its metadata, + the associated document is also marked as failed. + """ + async with get_celery_session_maker()() as session: + try: + # Find all in-progress document processing notifications + result = await session.execute( + select( + Notification.id, + Notification.notification_metadata, + ).where( + and_( + Notification.type == "document_processing", + Notification.notification_metadata["status"].astext + == "in_progress", + ) + ) + ) + in_progress_rows = result.fetchall() + + if not in_progress_rows: + logger.debug( + "No in-progress document processing notifications found" + ) + return + + # Check which ones are missing heartbeat keys in Redis + redis_client = get_redis_client() + stale_notification_ids = [] + stale_document_ids = [] + + for row in in_progress_rows: + notification_id = row[0] + metadata = row[1] # Full metadata dict + heartbeat_key = _get_heartbeat_key(notification_id) + if not redis_client.exists(heartbeat_key): + stale_notification_ids.append(notification_id) + # Extract document_id from metadata for document cleanup + if metadata and isinstance(metadata, dict): + doc_id = metadata.get("document_id") + if doc_id is not None: + with contextlib.suppress(ValueError, TypeError): + stale_document_ids.append(int(doc_id)) + + if not stale_notification_ids: + logger.debug( + f"All {len(in_progress_rows)} in-progress document processing " + "notifications have active Redis heartbeats" + ) + return + + logger.warning( + f"Found {len(stale_notification_ids)} stale document processing " + f"notifications (no Redis heartbeat): {stale_notification_ids}" + ) + + # O(1) Batch UPDATE: Mark stale notifications as failed + update_data = { + "status": "failed", + "completed_at": datetime.now(UTC).isoformat(), + "error_message": STALE_PROCESSING_ERROR_MESSAGE, + "processing_stage": "failed", + } + + await session.execute( + text(""" + UPDATE notifications + SET metadata = metadata || CAST(:update_json AS jsonb), + title = 'Failed: ' || COALESCE(metadata->>'document_name', 'Document'), + message = :display_message + WHERE id = ANY(:ids) + """), + { + "update_json": json.dumps(update_data), + "display_message": STALE_PROCESSING_ERROR_MESSAGE, + "ids": stale_notification_ids, + }, + ) + + logger.info( + f"Successfully marked {len(stale_notification_ids)} stale document " + "processing notifications as failed" + ) + + # Clean up stuck documents by document_id from notification metadata + if stale_document_ids: + await _cleanup_stuck_non_connector_documents( + session, stale_document_ids + ) + + await session.commit() + + except Exception as e: + logger.error( + f"Error cleaning up stale document processing notifications: {e!s}", + exc_info=True, + ) + await session.rollback() + + +async def _cleanup_stuck_non_connector_documents( + session, document_ids: list[int] +): + """ + Mark specific non-connector documents stuck in pending/processing as failed. + + These are documents (FILE uploads, YouTube, etc.) identified from stale + notification metadata. Only documents that are still in pending/processing + state are updated — already-completed documents are left untouched. + + Args: + session: Database session + document_ids: List of document IDs to check and potentially mark as failed + """ + if not document_ids: + return + + try: + # Find which of these documents are actually stuck + count_result = await session.execute( + select(Document.id).where( + and_( + Document.id.in_(document_ids), + or_( + Document.status["state"].astext == DocumentStatus.PENDING, + Document.status["state"].astext == DocumentStatus.PROCESSING, + ), + ) + ) + ) + stuck_doc_ids = [row[0] for row in count_result.fetchall()] + + if not stuck_doc_ids: + logger.debug( + f"No stuck non-connector documents found for IDs: {document_ids}" + ) + return + + logger.warning( + f"Found {len(stuck_doc_ids)} stuck non-connector documents " + f"(pending/processing): {stuck_doc_ids}" + ) + + failed_status = DocumentStatus.failed(STALE_PROCESSING_ERROR_MESSAGE) + + await session.execute( + text(""" + UPDATE documents + SET status = CAST(:failed_status AS jsonb), + updated_at = :now + WHERE id = ANY(:doc_ids) + AND ( + status->>'state' = :pending_state + OR status->>'state' = :processing_state + ) + """), + { + "failed_status": json.dumps(failed_status), + "now": datetime.now(UTC), + "doc_ids": stuck_doc_ids, + "pending_state": DocumentStatus.PENDING, + "processing_state": DocumentStatus.PROCESSING, + }, + ) + + logger.info( + f"Successfully marked {len(stuck_doc_ids)} stuck non-connector " + "documents as failed" + ) + + except Exception as e: + logger.error( + f"Error cleaning up stuck non-connector documents {document_ids}: {e!s}", + exc_info=True, + ) + # Don't raise — let the rest of the cleanup continue diff --git a/surfsense_backend/app/tasks/document_processors/youtube_processor.py b/surfsense_backend/app/tasks/document_processors/youtube_processor.py index 9dac6d554..427236fd3 100644 --- a/surfsense_backend/app/tasks/document_processors/youtube_processor.py +++ b/surfsense_backend/app/tasks/document_processors/youtube_processor.py @@ -61,7 +61,11 @@ def get_youtube_video_id(url: str) -> str | None: async def add_youtube_video_document( - session: AsyncSession, url: str, search_space_id: int, user_id: str + session: AsyncSession, + url: str, + search_space_id: int, + user_id: str, + notification=None, ) -> Document: """ Process a YouTube video URL, extract transcripts, and store as a document. @@ -75,6 +79,9 @@ async def add_youtube_video_document( url: YouTube video URL (supports standard, shortened, and embed formats) search_space_id: ID of the search space to add the document to user_id: ID of the user + notification: Optional notification object — if provided, the document_id + is stored in its metadata right after document creation so the stale + cleanup task can identify stuck documents. Returns: Document: The created document object @@ -182,6 +189,15 @@ async def add_youtube_video_document( await session.commit() # Document visible in UI now with pending status! is_new_document = True + # Store document_id in notification metadata so stale cleanup task + # can identify this document if the worker crashes. + if notification and notification.notification_metadata is not None: + from sqlalchemy.orm.attributes import flag_modified + + notification.notification_metadata["document_id"] = document.id + flag_modified(notification, "notification_metadata") + await session.commit() + logging.info(f"Created pending document for YouTube video {video_id}") # =======================================================================