diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index c4facc84d..9aa4b0b34 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -9,6 +9,8 @@ CELERY_TASK_DEFAULT_QUEUE=surfsense # Redis for app-level features (heartbeats, podcast markers) # Defaults to CELERY_BROKER_URL when not set REDIS_APP_URL=redis://localhost:6379/0 +# Optional: TTL in seconds for connector indexing lock key +# CONNECTOR_INDEXING_LOCK_TTL_SECONDS=28800 #Electric(for migrations only) ELECTRIC_DB_USER=electric diff --git a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py index a11e4ac38..5f31b4f5c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py +++ b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py @@ -8,6 +8,7 @@ This module provides: - Tool factory for creating search_knowledge_base tools """ +import asyncio import json from datetime import datetime from typing import Any @@ -16,6 +17,7 @@ from langchain_core.tools import StructuredTool from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession +from app.db import async_session_maker from app.services.connector_service import ConnectorService # ============================================================================= @@ -333,7 +335,7 @@ async def search_knowledge_base_async( Returns: Formatted string with search results """ - all_documents = [] + all_documents: list[dict[str, Any]] = [] # Resolve date range (default last 2 years) from app.agents.new_chat.utils import resolve_date_range @@ -345,323 +347,132 @@ async def search_knowledge_base_async( connectors = _normalize_connectors(connectors_to_search, available_connectors) - for connector in connectors: + connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = { + "YOUTUBE_VIDEO": ("search_youtube", True, True, {}), + "EXTENSION": ("search_extension", True, True, {}), + 
"CRAWLED_URL": ("search_crawled_urls", True, True, {}), + "FILE": ("search_files", True, True, {}), + "SLACK_CONNECTOR": ("search_slack", True, True, {}), + "TEAMS_CONNECTOR": ("search_teams", True, True, {}), + "NOTION_CONNECTOR": ("search_notion", True, True, {}), + "GITHUB_CONNECTOR": ("search_github", True, True, {}), + "LINEAR_CONNECTOR": ("search_linear", True, True, {}), + "TAVILY_API": ("search_tavily", False, True, {}), + "SEARXNG_API": ("search_searxng", False, True, {}), + "LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}), + "BAIDU_SEARCH_API": ("search_baidu", False, True, {}), + "DISCORD_CONNECTOR": ("search_discord", True, True, {}), + "JIRA_CONNECTOR": ("search_jira", True, True, {}), + "GOOGLE_CALENDAR_CONNECTOR": ("search_google_calendar", True, True, {}), + "AIRTABLE_CONNECTOR": ("search_airtable", True, True, {}), + "GOOGLE_GMAIL_CONNECTOR": ("search_google_gmail", True, True, {}), + "GOOGLE_DRIVE_FILE": ("search_google_drive", True, True, {}), + "CONFLUENCE_CONNECTOR": ("search_confluence", True, True, {}), + "CLICKUP_CONNECTOR": ("search_clickup", True, True, {}), + "LUMA_CONNECTOR": ("search_luma", True, True, {}), + "ELASTICSEARCH_CONNECTOR": ("search_elasticsearch", True, True, {}), + "NOTE": ("search_notes", True, True, {}), + "BOOKSTACK_CONNECTOR": ("search_bookstack", True, True, {}), + "CIRCLEBACK": ("search_circleback", True, True, {}), + "OBSIDIAN_CONNECTOR": ("search_obsidian", True, True, {}), + # Composio connectors + "COMPOSIO_GOOGLE_DRIVE_CONNECTOR": ( + "search_composio_google_drive", + True, + True, + {}, + ), + "COMPOSIO_GMAIL_CONNECTOR": ("search_composio_gmail", True, True, {}), + "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": ( + "search_composio_google_calendar", + True, + True, + {}, + ), + } + + # Keep a conservative cap to avoid overloading DB/external services. 
+ max_parallel_searches = 4 + semaphore = asyncio.Semaphore(max_parallel_searches) + + async def _search_one_connector(connector: str) -> list[dict[str, Any]]: + spec = connector_specs.get(connector) + if spec is None: + return [] + + method_name, includes_date_range, includes_top_k, extra_kwargs = spec + kwargs: dict[str, Any] = { + "user_query": query, + "search_space_id": search_space_id, + **extra_kwargs, + } + if includes_top_k: + kwargs["top_k"] = top_k + if includes_date_range: + kwargs["start_date"] = resolved_start_date + kwargs["end_date"] = resolved_end_date + try: - if connector == "YOUTUBE_VIDEO": - _, chunks = await connector_service.search_youtube( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "EXTENSION": - _, chunks = await connector_service.search_extension( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "CRAWLED_URL": - _, chunks = await connector_service.search_crawled_urls( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "FILE": - _, chunks = await connector_service.search_files( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "SLACK_CONNECTOR": - _, chunks = await connector_service.search_slack( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "TEAMS_CONNECTOR": - _, chunks = await connector_service.search_teams( - user_query=query, - 
search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "NOTION_CONNECTOR": - _, chunks = await connector_service.search_notion( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "GITHUB_CONNECTOR": - _, chunks = await connector_service.search_github( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "LINEAR_CONNECTOR": - _, chunks = await connector_service.search_linear( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "TAVILY_API": - _, chunks = await connector_service.search_tavily( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - ) - all_documents.extend(chunks) - - elif connector == "SEARXNG_API": - _, chunks = await connector_service.search_searxng( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - ) - all_documents.extend(chunks) - - elif connector == "LINKUP_API": - # Keep behavior aligned with researcher: default "standard" - _, chunks = await connector_service.search_linkup( - user_query=query, - search_space_id=search_space_id, - mode="standard", - ) - all_documents.extend(chunks) - - elif connector == "BAIDU_SEARCH_API": - _, chunks = await connector_service.search_baidu( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - ) - all_documents.extend(chunks) - - elif connector == "DISCORD_CONNECTOR": - _, chunks = await connector_service.search_discord( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - 
end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "JIRA_CONNECTOR": - _, chunks = await connector_service.search_jira( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "GOOGLE_CALENDAR_CONNECTOR": - _, chunks = await connector_service.search_google_calendar( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "AIRTABLE_CONNECTOR": - _, chunks = await connector_service.search_airtable( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "GOOGLE_GMAIL_CONNECTOR": - _, chunks = await connector_service.search_google_gmail( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "GOOGLE_DRIVE_FILE": - _, chunks = await connector_service.search_google_drive( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "CONFLUENCE_CONNECTOR": - _, chunks = await connector_service.search_confluence( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "CLICKUP_CONNECTOR": - _, chunks = await connector_service.search_clickup( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "LUMA_CONNECTOR": - 
_, chunks = await connector_service.search_luma( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "ELASTICSEARCH_CONNECTOR": - _, chunks = await connector_service.search_elasticsearch( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "NOTE": - _, chunks = await connector_service.search_notes( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "BOOKSTACK_CONNECTOR": - _, chunks = await connector_service.search_bookstack( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "CIRCLEBACK": - _, chunks = await connector_service.search_circleback( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "OBSIDIAN_CONNECTOR": - _, chunks = await connector_service.search_obsidian( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - # ========================================================= - # Composio Connectors - # ========================================================= - elif connector == "COMPOSIO_GOOGLE_DRIVE_CONNECTOR": - _, chunks = await connector_service.search_composio_google_drive( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif 
connector == "COMPOSIO_GMAIL_CONNECTOR": - _, chunks = await connector_service.search_composio_gmail( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - - elif connector == "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": - _, chunks = await connector_service.search_composio_google_calendar( - user_query=query, - search_space_id=search_space_id, - top_k=top_k, - start_date=resolved_start_date, - end_date=resolved_end_date, - ) - all_documents.extend(chunks) - + async with semaphore: + # Use isolated session per connector. Shared AsyncSession cannot safely + # run concurrent DB operations. + async with async_session_maker() as isolated_session: + isolated_connector_service = ConnectorService( + isolated_session, search_space_id + ) + connector_method = getattr(isolated_connector_service, method_name) + _, chunks = await connector_method(**kwargs) + return chunks except Exception as e: print(f"Error searching connector {connector}: {e}") - continue + return [] - # Deduplicate by content hash + connector_results = await asyncio.gather( + *[_search_one_connector(connector) for connector in connectors] + ) + for chunks in connector_results: + all_documents.extend(chunks) + + # Deduplicate primarily by document ID. Only fall back to content hashing + # when a document has no ID. 
seen_doc_ids: set[Any] = set() - seen_hashes: set[int] = set() + seen_content_hashes: set[int] = set() deduplicated: list[dict[str, Any]] = [] + + def _content_fingerprint(document: dict[str, Any]) -> int | None: + chunks = document.get("chunks") + if isinstance(chunks, list): + chunk_texts = [] + for chunk in chunks: + if not isinstance(chunk, dict): + continue + chunk_content = (chunk.get("content") or "").strip() + if chunk_content: + chunk_texts.append(chunk_content) + if chunk_texts: + return hash("||".join(chunk_texts)) + + flat_content = (document.get("content") or "").strip() + if flat_content: + return hash(flat_content) + return None + for doc in all_documents: doc_id = (doc.get("document", {}) or {}).get("id") - content = (doc.get("content", "") or "").strip() - content_hash = hash(content) - if (doc_id and doc_id in seen_doc_ids) or content_hash in seen_hashes: + if doc_id is not None: + if doc_id in seen_doc_ids: + continue + seen_doc_ids.add(doc_id) + deduplicated.append(doc) continue - if doc_id: - seen_doc_ids.add(doc_id) - seen_hashes.add(content_hash) + content_hash = _content_fingerprint(doc) + if content_hash is not None: + if content_hash in seen_content_hashes: + continue + seen_content_hashes.add(content_hash) + deduplicated.append(doc) return format_documents_for_context(deduplicated) diff --git a/surfsense_backend/app/agents/new_chat/tools/podcast.py b/surfsense_backend/app/agents/new_chat/tools/podcast.py index e6412f4f2..8ac537f9a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/podcast.py +++ b/surfsense_backend/app/agents/new_chat/tools/podcast.py @@ -11,21 +11,18 @@ Duplicate request prevention: - Returns a friendly message if a podcast is already being generated """ -import os from typing import Any import redis from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.config import config from app.db import Podcast, PodcastStatus # Redis connection for tracking active podcast tasks # 
Defaults to the Celery broker when REDIS_APP_URL is not set -REDIS_URL = os.getenv( - "REDIS_APP_URL", - os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"), -) +REDIS_URL = config.REDIS_APP_URL _redis_client: redis.Redis | None = None diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index e102c414d..68c65a818 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -213,6 +213,17 @@ class Config: # Database DATABASE_URL = os.getenv("DATABASE_URL") + # Celery / Redis + CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0") + CELERY_RESULT_BACKEND = os.getenv( + "CELERY_RESULT_BACKEND", "redis://localhost:6379/0" + ) + CELERY_TASK_DEFAULT_QUEUE = os.getenv("CELERY_TASK_DEFAULT_QUEUE", "surfsense") + REDIS_APP_URL = os.getenv("REDIS_APP_URL", CELERY_BROKER_URL) + CONNECTOR_INDEXING_LOCK_TTL_SECONDS = int( + os.getenv("CONNECTOR_INDEXING_LOCK_TTL_SECONDS", str(8 * 60 * 60)) + ) + NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL") # Backend URL to override the http to https in the OAuth redirect URI BACKEND_URL = os.getenv("BACKEND_URL") diff --git a/surfsense_backend/app/connectors/notion_history.py b/surfsense_backend/app/connectors/notion_history.py index ff8478905..525b0b4c3 100644 --- a/surfsense_backend/app/connectors/notion_history.py +++ b/surfsense_backend/app/connectors/notion_history.py @@ -27,6 +27,12 @@ T = TypeVar("T") MAX_RETRIES = 5 BASE_RETRY_DELAY = 1.0 # seconds MAX_RETRY_DELAY = 60.0 # seconds (Notion's max request timeout) +MAX_RATE_LIMIT_WAIT_SECONDS = float( + getattr(config, "NOTION_MAX_RETRY_AFTER_SECONDS", 30.0) +) +MAX_TOTAL_RETRY_WAIT_SECONDS = float( + getattr(config, "NOTION_MAX_TOTAL_RETRY_WAIT_SECONDS", 120.0) +) # Type alias for retry callback function # Signature: async callback(retry_reason, attempt, max_attempts, wait_seconds) -> None @@ -292,6 +298,7 @@ class NotionHistoryConnector: """ last_exception: 
APIResponseError | None = None retry_delay = BASE_RETRY_DELAY + total_wait_time = 0.0 for attempt in range(MAX_RETRIES): try: @@ -325,6 +332,15 @@ class NotionHistoryConnector: wait_time = retry_delay else: wait_time = retry_delay + + # Avoid very long worker sleeps from external Retry-After values. + if wait_time > MAX_RATE_LIMIT_WAIT_SECONDS: + logger.warning( + f"Notion Retry-After ({wait_time}s) exceeds cap " + f"({MAX_RATE_LIMIT_WAIT_SECONDS}s). Clamping wait time." + ) + wait_time = MAX_RATE_LIMIT_WAIT_SECONDS + logger.warning( f"Notion API rate limited (429). " f"Waiting {wait_time}s. Attempt {attempt + 1}/{MAX_RETRIES}" @@ -348,6 +364,14 @@ class NotionHistoryConnector: # Notify about retry via callback (for user notifications) # Call before sleeping so user sees the message while we wait + if total_wait_time + wait_time > MAX_TOTAL_RETRY_WAIT_SECONDS: + logger.error( + "Notion API retry budget exceeded " + f"({total_wait_time + wait_time:.1f}s > " + f"{MAX_TOTAL_RETRY_WAIT_SECONDS:.1f}s). Failing fast." + ) + raise + if on_retry: try: await on_retry( @@ -362,6 +386,7 @@ class NotionHistoryConnector: # Wait before retrying await asyncio.sleep(wait_time) + total_wait_time += wait_time # Exponential backoff for next attempt retry_delay = min(retry_delay * 2, MAX_RETRY_DELAY) diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 747e02834..02737c146 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -19,7 +19,6 @@ Non-OAuth connectors (BookStack, GitHub, etc.) 
are limited to one per search spa """ import logging -import os from datetime import UTC, datetime, timedelta from typing import Any @@ -76,6 +75,10 @@ from app.utils.periodic_scheduler import ( update_periodic_schedule, ) from app.utils.rbac import check_permission +from app.utils.indexing_locks import ( + acquire_connector_indexing_lock, + release_connector_indexing_lock, +) # Set up logging logger = logging.getLogger(__name__) @@ -91,11 +94,9 @@ def get_heartbeat_redis_client() -> redis.Redis: """Get or create Redis client for heartbeat tracking.""" global _heartbeat_redis_client if _heartbeat_redis_client is None: - redis_url = os.getenv( - "REDIS_APP_URL", - os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"), + _heartbeat_redis_client = redis.from_url( + config.REDIS_APP_URL, decode_responses=True ) - _heartbeat_redis_client = redis.from_url(redis_url, decode_responses=True) return _heartbeat_redis_client @@ -1229,10 +1230,19 @@ async def _run_indexing_with_notifications( from celery.exceptions import SoftTimeLimitExceeded notification = None + connector_lock_acquired = False # Track indexed count for retry notifications and heartbeat current_indexed_count = 0 try: + connector_lock_acquired = acquire_connector_indexing_lock(connector_id) + if not connector_lock_acquired: + logger.info( + f"Skipping indexing for connector {connector_id} " + "(another worker already holds Redis connector lock)" + ) + return + # Get connector info for notification connector_result = await session.execute( select(SearchSourceConnector).where( @@ -1558,6 +1568,11 @@ async def _run_indexing_with_notifications( get_heartbeat_redis_client().delete(heartbeat_key) except Exception: pass # Ignore cleanup errors - key will expire anyway + if connector_lock_acquired: + try: + release_connector_indexing_lock(connector_id) + except Exception: + pass # Lock has TTL; safe to ignore cleanup failures async def run_notion_indexing_with_new_session( diff --git 
a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py index 81c5dbba2..859b15018 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py @@ -36,11 +36,9 @@ def _get_doc_heartbeat_redis(): global _doc_heartbeat_redis if _doc_heartbeat_redis is None: - redis_url = os.getenv( - "REDIS_APP_URL", - os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"), + _doc_heartbeat_redis = redis.from_url( + config.REDIS_APP_URL, decode_responses=True ) - _doc_heartbeat_redis = redis.from_url(redis_url, decode_responses=True) return _doc_heartbeat_redis @@ -1104,4 +1102,4 @@ async def _process_circleback_meeting( if heartbeat_task: heartbeat_task.cancel() if notification: _stop_heartbeat(notification.id) diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py index 14df83508..973e7e750 100644 --- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py @@ -46,16 +46,10 @@ def _clear_generating_podcast(search_space_id: int) -> None: """Clear the generating podcast marker from Redis when task completes.""" - import os - import redis try: - redis_url = os.getenv( - "REDIS_APP_URL", - os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"), - ) - client = redis.from_url(redis_url, decode_responses=True) + client = redis.from_url(config.REDIS_APP_URL, decode_responses=True) key = f"podcast:generating:{search_space_id}" client.delete(key) logger.info( diff --git a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py index b33e25170..80d271aaa 100644 --- a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py
+++ b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py @@ -9,7 +9,8 @@ from sqlalchemy.pool import NullPool from app.celery_app import celery_app from app.config import config -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType +from app.utils.indexing_locks import is_connector_indexing_locked logger = logging.getLogger(__name__) @@ -107,6 +108,32 @@ async def _check_and_trigger_schedules(): # Trigger indexing for each due connector for connector in due_connectors: + # Primary guard: Redis lock indicates a task is currently running. + if is_connector_indexing_locked(connector.id): + logger.info( + f"Skipping periodic indexing for connector {connector.id} " + "(Redis lock indicates indexing is already in progress)" + ) + continue + + # Skip scheduling if a sync for this connector is already in progress. + # This prevents duplicate tasks from piling up under slow/rate-limited providers. 
+ in_progress_result = await session.execute( + select(Notification.id).where( + Notification.type == "connector_indexing", + Notification.notification_metadata["connector_id"].astext + == str(connector.id), + Notification.notification_metadata["status"].astext + == "in_progress", + ) + ) + if in_progress_result.first(): + logger.info( + f"Skipping periodic indexing for connector {connector.id} " + "(already has in-progress indexing notification)" + ) + continue + task = task_map.get(connector.connector_type) if task: logger.info( diff --git a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py index f3bbddee0..c2c82dd2c 100644 --- a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py @@ -25,7 +25,6 @@ Detection mechanism: import contextlib import json import logging -import os from datetime import UTC, datetime import redis @@ -52,11 +51,7 @@ def get_redis_client() -> redis.Redis: """Get or create Redis client for heartbeat checking.""" global _redis_client if _redis_client is None: - redis_url = os.getenv( - "REDIS_APP_URL", - os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"), - ) - _redis_client = redis.from_url(redis_url, decode_responses=True) + _redis_client = redis.from_url(config.REDIS_APP_URL, decode_responses=True) return _redis_client diff --git a/surfsense_backend/app/tasks/connector_indexers/base.py b/surfsense_backend/app/tasks/connector_indexers/base.py index 6a1226230..139aed1d3 100644 --- a/surfsense_backend/app/tasks/connector_indexers/base.py +++ b/surfsense_backend/app/tasks/connector_indexers/base.py @@ -52,10 +52,22 @@ def safe_set_chunks(document: Document, chunks: list) -> None: # Instead of: document.chunks = chunks (DANGEROUS!) 
safe_set_chunks(document, chunks) # Always safe """ + from sqlalchemy.orm import object_session from sqlalchemy.orm.attributes import set_committed_value + # Keep relationship assignment lazy-load-safe. set_committed_value(document, "chunks", chunks) + # Ensure chunk rows are actually persisted. + # set_committed_value bypasses normal unit-of-work tracking, so we need to + # explicitly attach chunk objects to the current session. + session = object_session(document) + if session is not None: + if document.id is not None: + for chunk in chunks: + chunk.document_id = document.id + session.add_all(chunks) + def parse_date_flexible(date_str: str) -> datetime: """ diff --git a/surfsense_backend/app/tasks/document_processors/base.py b/surfsense_backend/app/tasks/document_processors/base.py index 2047ec63d..2edc48e91 100644 --- a/surfsense_backend/app/tasks/document_processors/base.py +++ b/surfsense_backend/app/tasks/document_processors/base.py @@ -38,10 +38,22 @@ def safe_set_chunks(document: Document, chunks: list) -> None: # Instead of: document.chunks = chunks (DANGEROUS!) safe_set_chunks(document, chunks) # Always safe """ + from sqlalchemy.orm import object_session from sqlalchemy.orm.attributes import set_committed_value + # Keep relationship assignment lazy-load-safe. set_committed_value(document, "chunks", chunks) + # Ensure chunk rows are actually persisted. + # set_committed_value bypasses normal unit-of-work tracking, so we need to + # explicitly attach chunk objects to the current session. 
+ session = object_session(document) + if session is not None: + if document.id is not None: + for chunk in chunks: + chunk.document_id = document.id + session.add_all(chunks) + def get_current_timestamp() -> datetime: """ diff --git a/surfsense_backend/app/utils/indexing_locks.py b/surfsense_backend/app/utils/indexing_locks.py new file mode 100644 index 000000000..7790bcc11 --- /dev/null +++ b/surfsense_backend/app/utils/indexing_locks.py @@ -0,0 +1,46 @@ +"""Redis-based connector indexing locks to prevent duplicate sync tasks.""" + +import redis + +from app.config import config + +_redis_client: redis.Redis | None = None +LOCK_TTL_SECONDS = config.CONNECTOR_INDEXING_LOCK_TTL_SECONDS + + +def get_indexing_lock_redis_client() -> redis.Redis: + """Get or create Redis client for connector indexing locks.""" + global _redis_client + if _redis_client is None: + _redis_client = redis.from_url(config.REDIS_APP_URL, decode_responses=True) + return _redis_client + + +def _get_connector_lock_key(connector_id: int) -> str: + """Generate Redis key for a connector indexing lock.""" + return f"indexing:connector_lock:{connector_id}" + + +def acquire_connector_indexing_lock(connector_id: int) -> bool: + """Acquire lock for connector indexing. Returns True if acquired.""" + key = _get_connector_lock_key(connector_id) + return bool( + get_indexing_lock_redis_client().set( + key, + "1", + nx=True, + ex=LOCK_TTL_SECONDS, + ) + ) + + +def release_connector_indexing_lock(connector_id: int) -> None: + """Release lock for connector indexing.""" + key = _get_connector_lock_key(connector_id) + get_indexing_lock_redis_client().delete(key) + + +def is_connector_indexing_locked(connector_id: int) -> bool: + """Check if connector indexing lock exists.""" + key = _get_connector_lock_key(connector_id) + return bool(get_indexing_lock_redis_client().exists(key))