mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-28 18:36:23 +02:00
feat: implement Redis heartbeat mechanism for document processing tasks and enhance stale notification cleanup
This commit is contained in:
parent
652eb6ece8
commit
72205ce11b
3 changed files with 350 additions and 30 deletions
|
|
@ -1,6 +1,8 @@
|
|||
"""Celery tasks for document processing."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
|
@ -17,6 +19,79 @@ from app.tasks.document_processors import (
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ===== Redis heartbeat for document processing tasks =====
|
||||
# Same mechanism as connector indexing heartbeats (search_source_connectors_routes.py).
|
||||
# A background coroutine refreshes a Redis key every 60s with a 2-min TTL.
|
||||
# If the Celery worker crashes, the coroutine dies, the key expires, and the
|
||||
# stale_notification_cleanup_task detects the missing key and marks the
|
||||
# notification + document as failed.
|
||||
_doc_heartbeat_redis = None
|
||||
HEARTBEAT_TTL_SECONDS = 120 # 2 minutes — same as connector indexing
|
||||
HEARTBEAT_REFRESH_INTERVAL = 60 # Refresh every 60 seconds
|
||||
|
||||
|
||||
def _get_doc_heartbeat_redis():
|
||||
"""Get Redis client for document processing heartbeat."""
|
||||
import redis
|
||||
|
||||
global _doc_heartbeat_redis
|
||||
if _doc_heartbeat_redis is None:
|
||||
redis_url = os.getenv(
|
||||
"REDIS_APP_URL",
|
||||
os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
|
||||
)
|
||||
_doc_heartbeat_redis = redis.from_url(redis_url, decode_responses=True)
|
||||
return _doc_heartbeat_redis
|
||||
|
||||
|
||||
def _get_heartbeat_key(notification_id: int) -> str:
|
||||
"""Generate Redis key for document processing heartbeat.
|
||||
|
||||
Uses same key pattern as connector indexing: indexing:heartbeat:{notification_id}
|
||||
"""
|
||||
return f"indexing:heartbeat:{notification_id}"
|
||||
|
||||
|
||||
def _start_heartbeat(notification_id: int) -> None:
|
||||
"""Set initial Redis heartbeat key for a document processing task."""
|
||||
try:
|
||||
key = _get_heartbeat_key(notification_id)
|
||||
_get_doc_heartbeat_redis().setex(key, HEARTBEAT_TTL_SECONDS, "started")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to set initial heartbeat for notification {notification_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def _stop_heartbeat(notification_id: int) -> None:
|
||||
"""Delete Redis heartbeat key when task completes (success or failure)."""
|
||||
try:
|
||||
key = _get_heartbeat_key(notification_id)
|
||||
_get_doc_heartbeat_redis().delete(key)
|
||||
except Exception:
|
||||
pass # Key will expire on its own
|
||||
|
||||
|
||||
async def _run_heartbeat_loop(notification_id: int):
|
||||
"""Background coroutine that refreshes Redis heartbeat every 60 seconds.
|
||||
|
||||
This keeps the heartbeat alive while the task is running.
|
||||
When the task finishes, this coroutine is cancelled via heartbeat_task.cancel().
|
||||
When the worker crashes, this coroutine dies with it and the key expires.
|
||||
"""
|
||||
key = _get_heartbeat_key(notification_id)
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(HEARTBEAT_REFRESH_INTERVAL)
|
||||
try:
|
||||
_get_doc_heartbeat_redis().setex(key, HEARTBEAT_TTL_SECONDS, "alive")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to refresh heartbeat for notification {notification_id}: {e}"
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
pass # Normal cancellation when task completes
|
||||
|
||||
|
||||
def get_celery_session_maker():
|
||||
"""
|
||||
|
|
@ -44,8 +119,6 @@ def process_extension_document_task(
|
|||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# Create a new event loop for this task
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
|
@ -196,8 +269,6 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st
|
|||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
|
|
@ -226,6 +297,10 @@ async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
|
|||
)
|
||||
)
|
||||
|
||||
# Start Redis heartbeat for stale task detection
|
||||
_start_heartbeat(notification.id)
|
||||
heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id))
|
||||
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_youtube_video",
|
||||
source="document_processor",
|
||||
|
|
@ -243,7 +318,7 @@ async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
|
|||
)
|
||||
|
||||
result = await add_youtube_video_document(
|
||||
session, url, search_space_id, user_id
|
||||
session, url, search_space_id, user_id, notification=notification
|
||||
)
|
||||
|
||||
if result:
|
||||
|
|
@ -307,6 +382,10 @@ async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
|
|||
|
||||
logger.error(f"Error processing YouTube video: {e!s}")
|
||||
raise
|
||||
finally:
|
||||
# Stop heartbeat — key deleted on success, expires on crash
|
||||
heartbeat_task.cancel()
|
||||
_stop_heartbeat(notification.id)
|
||||
|
||||
|
||||
@celery_app.task(name="process_file_upload", bind=True)
|
||||
|
|
@ -322,8 +401,6 @@ def process_file_upload_task(
|
|||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
"""
|
||||
import asyncio
|
||||
import os
|
||||
import traceback
|
||||
|
||||
logger.info(
|
||||
|
|
@ -370,8 +447,6 @@ async def _process_file_upload(
|
|||
file_path: str, filename: str, search_space_id: int, user_id: str
|
||||
):
|
||||
"""Process file upload with new session."""
|
||||
import os
|
||||
|
||||
from app.tasks.document_processors.file_processors import process_file_in_background
|
||||
|
||||
logger.info(f"[_process_file_upload] Starting async processing for: {filename}")
|
||||
|
|
@ -404,6 +479,10 @@ async def _process_file_upload(
|
|||
f"[_process_file_upload] Notification created with ID: {notification.id if notification else 'None'}"
|
||||
)
|
||||
|
||||
# Start Redis heartbeat for stale task detection
|
||||
_start_heartbeat(notification.id)
|
||||
heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id))
|
||||
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_file_upload",
|
||||
source="document_processor",
|
||||
|
|
@ -535,6 +614,10 @@ async def _process_file_upload(
|
|||
)
|
||||
logger.error(error_message)
|
||||
raise
|
||||
finally:
|
||||
# Stop heartbeat — key deleted on success, expires on crash
|
||||
heartbeat_task.cancel()
|
||||
_stop_heartbeat(notification.id)
|
||||
|
||||
|
||||
@celery_app.task(name="process_file_upload_with_document", bind=True)
|
||||
|
|
@ -560,8 +643,6 @@ def process_file_upload_with_document_task(
|
|||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
"""
|
||||
import asyncio
|
||||
import os
|
||||
import traceback
|
||||
|
||||
logger.info(
|
||||
|
|
@ -640,8 +721,6 @@ async def _process_file_with_document(
|
|||
- Processes the file (parsing, embedding, chunking)
|
||||
- Updates document to 'ready' on success or 'failed' on error
|
||||
"""
|
||||
import os
|
||||
|
||||
from app.db import Document, DocumentStatus
|
||||
from app.tasks.document_processors.base import get_current_timestamp
|
||||
from app.tasks.document_processors.file_processors import (
|
||||
|
|
@ -689,6 +768,19 @@ async def _process_file_with_document(
|
|||
)
|
||||
)
|
||||
|
||||
# Store document_id in notification metadata so cleanup task can find the document
|
||||
if notification and notification.notification_metadata is not None:
|
||||
notification.notification_metadata["document_id"] = document_id
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
flag_modified(notification, "notification_metadata")
|
||||
await session.commit()
|
||||
await session.refresh(notification)
|
||||
|
||||
# Start Redis heartbeat for stale task detection
|
||||
_start_heartbeat(notification.id)
|
||||
heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification.id))
|
||||
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_file_upload_with_document",
|
||||
source="document_processor",
|
||||
|
|
@ -822,6 +914,10 @@ async def _process_file_with_document(
|
|||
raise
|
||||
|
||||
finally:
|
||||
# Stop heartbeat — key deleted on success, expires on crash
|
||||
heartbeat_task.cancel()
|
||||
_stop_heartbeat(notification.id)
|
||||
|
||||
# Clean up temp file
|
||||
if os.path.exists(temp_path):
|
||||
try:
|
||||
|
|
@ -856,8 +952,6 @@ def process_circleback_meeting_task(
|
|||
search_space_id: ID of the search space
|
||||
connector_id: ID of the Circleback connector (for deletion support)
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
|
|
@ -897,6 +991,7 @@ async def _process_circleback_meeting(
|
|||
|
||||
# Create notification if user_id is available
|
||||
notification = None
|
||||
heartbeat_task = None
|
||||
if user_id:
|
||||
notification = (
|
||||
await NotificationService.document_processing.notify_processing_started(
|
||||
|
|
@ -908,6 +1003,12 @@ async def _process_circleback_meeting(
|
|||
)
|
||||
)
|
||||
|
||||
# Start Redis heartbeat for stale task detection
|
||||
_start_heartbeat(notification.id)
|
||||
heartbeat_task = asyncio.create_task(
|
||||
_run_heartbeat_loop(notification.id)
|
||||
)
|
||||
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="process_circleback_meeting",
|
||||
source="circleback_webhook",
|
||||
|
|
@ -1000,3 +1101,9 @@ async def _process_circleback_meeting(
|
|||
|
||||
logger.error(f"Error processing Circleback meeting: {e!s}")
|
||||
raise
|
||||
finally:
|
||||
# Stop heartbeat — key deleted on success, expires on crash
|
||||
if heartbeat_task:
|
||||
heartbeat_task.cancel()
|
||||
if notification:
|
||||
_stop_heartbeat(notification.id)
|
||||
|
|
|
|||
|
|
@ -1,18 +1,25 @@
|
|||
"""Celery task to detect and mark stale connector indexing notifications as failed.
|
||||
"""Celery task to detect and mark stale notifications as failed.
|
||||
|
||||
This task runs periodically (every 5 minutes by default) to find notifications
|
||||
that are stuck in "in_progress" status but don't have an active Redis heartbeat key.
|
||||
These are marked as "failed" to prevent the frontend from showing a perpetual "syncing" state.
|
||||
These are marked as "failed" to prevent the frontend from showing a perpetual
|
||||
"syncing" or "processing" state.
|
||||
|
||||
Additionally, it cleans up documents stuck in pending/processing state that belong
|
||||
to connectors with stale notifications.
|
||||
It handles two notification types:
|
||||
1. **connector_indexing** — connector sync tasks (Google Calendar, etc.)
|
||||
2. **document_processing** — manual file uploads, YouTube videos, etc.
|
||||
|
||||
Additionally, it cleans up documents stuck in pending/processing state:
|
||||
- For connectors: by connector_id
|
||||
- For non-connector documents (FILE uploads, YouTube): by document_id from notification metadata
|
||||
|
||||
Detection mechanism:
|
||||
- Active indexing tasks set a Redis key with TTL (2 minutes) as a heartbeat
|
||||
- If the task crashes, the Redis key expires automatically
|
||||
- Active tasks set a Redis key with TTL (2 minutes) as a heartbeat
|
||||
- A background coroutine refreshes the key every 60 seconds
|
||||
- If the task/worker crashes, the Redis key expires automatically
|
||||
- This cleanup task checks for in-progress notifications without a Redis heartbeat key
|
||||
- Such notifications are marked as failed with O(1) batch UPDATE
|
||||
- Documents with pending/processing status for those connectors are also marked as failed
|
||||
- Associated documents are also marked as failed
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
|
|
@ -36,8 +43,11 @@ logger = logging.getLogger(__name__)
|
|||
# Redis client for checking heartbeats
|
||||
_redis_client: redis.Redis | None = None
|
||||
|
||||
# Error message shown to users when sync is interrupted
|
||||
# Error messages shown to users when tasks are interrupted
|
||||
STALE_SYNC_ERROR_MESSAGE = "Sync was interrupted unexpectedly. Please retry."
|
||||
STALE_PROCESSING_ERROR_MESSAGE = (
|
||||
"Processing was interrupted unexpectedly. Please retry."
|
||||
)
|
||||
|
||||
|
||||
def get_redis_client() -> redis.Redis:
|
||||
|
|
@ -70,14 +80,13 @@ def get_celery_session_maker():
|
|||
@celery_app.task(name="cleanup_stale_indexing_notifications")
|
||||
def cleanup_stale_indexing_notifications_task():
|
||||
"""
|
||||
Check for stale connector indexing notifications and mark them as failed.
|
||||
Check for stale notifications and mark them as failed.
|
||||
|
||||
This task finds notifications that:
|
||||
- Have type = 'connector_indexing'
|
||||
- Have metadata.status = 'in_progress'
|
||||
- Do NOT have a corresponding Redis heartbeat key (meaning task crashed)
|
||||
Handles two notification types:
|
||||
1. connector_indexing — connector sync tasks
|
||||
2. document_processing — manual file uploads, YouTube videos, etc.
|
||||
|
||||
And marks them as failed with O(1) batch UPDATE.
|
||||
Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task.
|
||||
Also marks associated pending/processing documents as failed.
|
||||
"""
|
||||
import asyncio
|
||||
|
|
@ -87,6 +96,7 @@ def cleanup_stale_indexing_notifications_task():
|
|||
|
||||
try:
|
||||
loop.run_until_complete(_cleanup_stale_notifications())
|
||||
loop.run_until_complete(_cleanup_stale_document_processing_notifications())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
|
@ -269,3 +279,190 @@ async def _cleanup_stuck_documents(session, connector_ids: list[int]):
|
|||
exc_info=True,
|
||||
)
|
||||
# Don't raise - let the notification cleanup continue even if document cleanup fails
|
||||
|
||||
|
||||
# ===== Document Processing Cleanup (FILE uploads, YouTube, etc.) =====
|
||||
|
||||
|
||||
async def _cleanup_stale_document_processing_notifications():
|
||||
"""Find and mark stale document processing notifications as failed.
|
||||
|
||||
Same Redis heartbeat mechanism as connector indexing cleanup, but for
|
||||
document_processing type notifications (FILE uploads, YouTube videos, etc.).
|
||||
|
||||
For each stale notification that contains a document_id in its metadata,
|
||||
the associated document is also marked as failed.
|
||||
"""
|
||||
async with get_celery_session_maker()() as session:
|
||||
try:
|
||||
# Find all in-progress document processing notifications
|
||||
result = await session.execute(
|
||||
select(
|
||||
Notification.id,
|
||||
Notification.notification_metadata,
|
||||
).where(
|
||||
and_(
|
||||
Notification.type == "document_processing",
|
||||
Notification.notification_metadata["status"].astext
|
||||
== "in_progress",
|
||||
)
|
||||
)
|
||||
)
|
||||
in_progress_rows = result.fetchall()
|
||||
|
||||
if not in_progress_rows:
|
||||
logger.debug(
|
||||
"No in-progress document processing notifications found"
|
||||
)
|
||||
return
|
||||
|
||||
# Check which ones are missing heartbeat keys in Redis
|
||||
redis_client = get_redis_client()
|
||||
stale_notification_ids = []
|
||||
stale_document_ids = []
|
||||
|
||||
for row in in_progress_rows:
|
||||
notification_id = row[0]
|
||||
metadata = row[1] # Full metadata dict
|
||||
heartbeat_key = _get_heartbeat_key(notification_id)
|
||||
if not redis_client.exists(heartbeat_key):
|
||||
stale_notification_ids.append(notification_id)
|
||||
# Extract document_id from metadata for document cleanup
|
||||
if metadata and isinstance(metadata, dict):
|
||||
doc_id = metadata.get("document_id")
|
||||
if doc_id is not None:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
stale_document_ids.append(int(doc_id))
|
||||
|
||||
if not stale_notification_ids:
|
||||
logger.debug(
|
||||
f"All {len(in_progress_rows)} in-progress document processing "
|
||||
"notifications have active Redis heartbeats"
|
||||
)
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
f"Found {len(stale_notification_ids)} stale document processing "
|
||||
f"notifications (no Redis heartbeat): {stale_notification_ids}"
|
||||
)
|
||||
|
||||
# O(1) Batch UPDATE: Mark stale notifications as failed
|
||||
update_data = {
|
||||
"status": "failed",
|
||||
"completed_at": datetime.now(UTC).isoformat(),
|
||||
"error_message": STALE_PROCESSING_ERROR_MESSAGE,
|
||||
"processing_stage": "failed",
|
||||
}
|
||||
|
||||
await session.execute(
|
||||
text("""
|
||||
UPDATE notifications
|
||||
SET metadata = metadata || CAST(:update_json AS jsonb),
|
||||
title = 'Failed: ' || COALESCE(metadata->>'document_name', 'Document'),
|
||||
message = :display_message
|
||||
WHERE id = ANY(:ids)
|
||||
"""),
|
||||
{
|
||||
"update_json": json.dumps(update_data),
|
||||
"display_message": STALE_PROCESSING_ERROR_MESSAGE,
|
||||
"ids": stale_notification_ids,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Successfully marked {len(stale_notification_ids)} stale document "
|
||||
"processing notifications as failed"
|
||||
)
|
||||
|
||||
# Clean up stuck documents by document_id from notification metadata
|
||||
if stale_document_ids:
|
||||
await _cleanup_stuck_non_connector_documents(
|
||||
session, stale_document_ids
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error cleaning up stale document processing notifications: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
await session.rollback()
|
||||
|
||||
|
||||
async def _cleanup_stuck_non_connector_documents(
|
||||
session, document_ids: list[int]
|
||||
):
|
||||
"""
|
||||
Mark specific non-connector documents stuck in pending/processing as failed.
|
||||
|
||||
These are documents (FILE uploads, YouTube, etc.) identified from stale
|
||||
notification metadata. Only documents that are still in pending/processing
|
||||
state are updated — already-completed documents are left untouched.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
document_ids: List of document IDs to check and potentially mark as failed
|
||||
"""
|
||||
if not document_ids:
|
||||
return
|
||||
|
||||
try:
|
||||
# Find which of these documents are actually stuck
|
||||
count_result = await session.execute(
|
||||
select(Document.id).where(
|
||||
and_(
|
||||
Document.id.in_(document_ids),
|
||||
or_(
|
||||
Document.status["state"].astext == DocumentStatus.PENDING,
|
||||
Document.status["state"].astext == DocumentStatus.PROCESSING,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
stuck_doc_ids = [row[0] for row in count_result.fetchall()]
|
||||
|
||||
if not stuck_doc_ids:
|
||||
logger.debug(
|
||||
f"No stuck non-connector documents found for IDs: {document_ids}"
|
||||
)
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
f"Found {len(stuck_doc_ids)} stuck non-connector documents "
|
||||
f"(pending/processing): {stuck_doc_ids}"
|
||||
)
|
||||
|
||||
failed_status = DocumentStatus.failed(STALE_PROCESSING_ERROR_MESSAGE)
|
||||
|
||||
await session.execute(
|
||||
text("""
|
||||
UPDATE documents
|
||||
SET status = CAST(:failed_status AS jsonb),
|
||||
updated_at = :now
|
||||
WHERE id = ANY(:doc_ids)
|
||||
AND (
|
||||
status->>'state' = :pending_state
|
||||
OR status->>'state' = :processing_state
|
||||
)
|
||||
"""),
|
||||
{
|
||||
"failed_status": json.dumps(failed_status),
|
||||
"now": datetime.now(UTC),
|
||||
"doc_ids": stuck_doc_ids,
|
||||
"pending_state": DocumentStatus.PENDING,
|
||||
"processing_state": DocumentStatus.PROCESSING,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Successfully marked {len(stuck_doc_ids)} stuck non-connector "
|
||||
"documents as failed"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error cleaning up stuck non-connector documents {document_ids}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Don't raise — let the rest of the cleanup continue
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue