From f7b52470eb4d1adcf77d8fd36e7416f0f0b5e850 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 25 Mar 2026 18:33:44 +0530 Subject: [PATCH 01/31] feat: enhance Google connectors indexing with content extraction and document migration - Added `download_and_extract_content` function to extract content from Google Drive files as markdown. - Updated Google Drive indexer to utilize the new content extraction method. - Implemented document migration logic to update legacy Composio document types to their native Google types. - Introduced identifier hashing for stable document identification. - Improved file pre-filtering to handle unchanged and rename-only files efficiently. --- .../app/connectors/google_drive/__init__.py | 3 +- .../google_drive/content_extractor.py | 154 ++ .../app/indexing_pipeline/document_hashing.py | 11 +- .../indexing_pipeline_service.py | 59 +- .../google_calendar_indexer.py | 373 ++-- .../google_drive_indexer.py | 1539 +++++------------ .../google_gmail_indexer.py | 379 ++-- .../test_document_hashing.py | 21 + 8 files changed, 951 insertions(+), 1588 deletions(-) diff --git a/surfsense_backend/app/connectors/google_drive/__init__.py b/surfsense_backend/app/connectors/google_drive/__init__.py index 47cc8598e..a0e9c4484 100644 --- a/surfsense_backend/app/connectors/google_drive/__init__.py +++ b/surfsense_backend/app/connectors/google_drive/__init__.py @@ -2,13 +2,14 @@ from .change_tracker import categorize_change, fetch_all_changes, get_start_page_token from .client import GoogleDriveClient -from .content_extractor import download_and_process_file +from .content_extractor import download_and_extract_content, download_and_process_file from .credentials import get_valid_credentials, validate_credentials from .folder_manager import get_file_by_id, get_files_in_folder, list_folder_contents __all__ = [ "GoogleDriveClient", "categorize_change", + "download_and_extract_content", "download_and_process_file", "fetch_all_changes", "get_file_by_id", diff --git a/surfsense_backend/app/connectors/google_drive/content_extractor.py b/surfsense_backend/app/connectors/google_drive/content_extractor.py index 1d08d38f7..6fa20bf8e 100644 --- a/surfsense_backend/app/connectors/google_drive/content_extractor.py +++ b/surfsense_backend/app/connectors/google_drive/content_extractor.py @@ -17,6 +17,160 @@ from .file_types import get_export_mime_type, is_google_workspace_file, should_s logger = logging.getLogger(__name__) +async def download_and_extract_content( + client: GoogleDriveClient, + file: dict[str, Any], +) -> tuple[str | None, dict[str, Any], str | None]: + """Download a Google Drive file and extract its content as markdown. + + ETL only -- no DB writes, no indexing, no summarization. + + Returns: + (markdown_content, drive_metadata, error_message) + On success error_message is None. + """ + file_id = file.get("id") + file_name = file.get("name", "Unknown") + mime_type = file.get("mimeType", "") + + if should_skip_file(mime_type): + return None, {}, f"Skipping {mime_type}" + + logger.info(f"Downloading file for content extraction: {file_name} ({mime_type})") + + drive_metadata: dict[str, Any] = { + "google_drive_file_id": file_id, + "google_drive_file_name": file_name, + "google_drive_mime_type": mime_type, + "source_connector": "google_drive", + } + if "modifiedTime" in file: + drive_metadata["modified_time"] = file["modifiedTime"] + if "createdTime" in file: + drive_metadata["created_time"] = file["createdTime"] + if "size" in file: + drive_metadata["file_size"] = file["size"] + if "webViewLink" in file: + drive_metadata["web_view_link"] = file["webViewLink"] + if "md5Checksum" in file: + drive_metadata["md5_checksum"] = file["md5Checksum"] + if is_google_workspace_file(mime_type): + drive_metadata["exported_as"] = "pdf" + drive_metadata["original_workspace_type"] = mime_type.split(".")[-1] + + temp_file_path = None + try: + # Download / export + if is_google_workspace_file(mime_type): + export_mime = get_export_mime_type(mime_type) + if not export_mime: + return None, drive_metadata, f"Cannot export Google Workspace type: {mime_type}" + content_bytes, error = await client.export_google_file(file_id, export_mime) + if error: + return None, drive_metadata, error + extension = ".pdf" if export_mime == "application/pdf" else ".txt" + else: + content_bytes, error = await client.download_file(file_id) + if error: + return None, drive_metadata, error + extension = Path(file_name).suffix or ".bin" + + with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp: + tmp.write(content_bytes) + temp_file_path = tmp.name + + # Parse to markdown + markdown = await _parse_file_to_markdown(temp_file_path, file_name) + return markdown, drive_metadata, None + + except Exception as e: + logger.warning(f"Failed to extract content from {file_name}: {e!s}") + return None, drive_metadata, str(e) + finally: + if temp_file_path and os.path.exists(temp_file_path): + try: + os.unlink(temp_file_path) + except Exception: + pass + + +async def _parse_file_to_markdown(file_path: str, filename: str) -> str: + """Parse a local file to markdown using the configured ETL service.""" + lower = filename.lower() + + if lower.endswith((".md", ".markdown", ".txt")): + with open(file_path, encoding="utf-8") as f: + return f.read() + + if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")): + from app.config import config as app_config + from litellm import atranscription + + stt_service_type = ( + "local" + if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/") + else "external" + ) + if stt_service_type == "local": + from app.services.stt_service import stt_service + result = stt_service.transcribe_file(file_path) + text = result.get("text", "") + else: + with open(file_path, "rb") as audio_file: + kwargs: dict[str, Any] = { + "model": app_config.STT_SERVICE, + "file": audio_file, + "api_key": app_config.STT_SERVICE_API_KEY, + } + if app_config.STT_SERVICE_API_BASE: + kwargs["api_base"] = app_config.STT_SERVICE_API_BASE + resp = await atranscription(**kwargs) + text = resp.get("text", "") + + if not text: + raise ValueError("Transcription returned empty text") + return f"# Transcription of {filename}\n\n{text}" + + # Document files -- use configured ETL service + from app.config import config as app_config + + if app_config.ETL_SERVICE == "UNSTRUCTURED": + from langchain_unstructured import UnstructuredLoader + from app.utils.document_converters import convert_document_to_markdown + + loader = UnstructuredLoader( + file_path, + mode="elements", + post_processors=[], + languages=["eng"], + include_orig_elements=False, + include_metadata=False, + strategy="auto", + ) + docs = await loader.aload() + return await convert_document_to_markdown(docs) + + if app_config.ETL_SERVICE == "LLAMACLOUD": + from app.tasks.document_processors.file_processors import ( + parse_with_llamacloud_retry, + ) + + result = await parse_with_llamacloud_retry(file_path=file_path, estimated_pages=50) + markdown_documents = await result.aget_markdown_documents(split_by_page=False) + if not markdown_documents: + raise RuntimeError(f"LlamaCloud returned no documents for {filename}") + return markdown_documents[0].text + + if app_config.ETL_SERVICE == "DOCLING": + from docling.document_converter import DocumentConverter + + converter = DocumentConverter() + result = converter.convert(file_path) + return result.document.export_to_markdown() + + raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}") + + async def download_and_process_file( client: GoogleDriveClient, file: dict[str, Any], diff --git a/surfsense_backend/app/indexing_pipeline/document_hashing.py b/surfsense_backend/app/indexing_pipeline/document_hashing.py index 5dd7767a4..9edebd140 100644 --- a/surfsense_backend/app/indexing_pipeline/document_hashing.py +++ b/surfsense_backend/app/indexing_pipeline/document_hashing.py @@ -3,10 +3,17 @@ import hashlib from app.indexing_pipeline.connector_document import ConnectorDocument +def compute_identifier_hash( + document_type_value: str, unique_id: str, search_space_id: int +) -> str: + """Return a stable SHA-256 hash from raw identity components.""" + combined = f"{document_type_value}:{unique_id}:{search_space_id}" + return hashlib.sha256(combined.encode("utf-8")).hexdigest() + + def compute_unique_identifier_hash(doc: ConnectorDocument) -> str: """Return a stable SHA-256 hash identifying a document by its source identity.""" - combined = f"{doc.document_type.value}:{doc.unique_id}:{doc.search_space_id}" - return hashlib.sha256(combined.encode("utf-8")).hexdigest() + return compute_identifier_hash(doc.document_type.value, doc.unique_id, doc.search_space_id) def compute_content_hash(doc: ConnectorDocument) -> str: diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index 490aac782..c6a29f204 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -6,12 +6,13 @@ from sqlalchemy import delete, select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession -from app.db import Chunk, Document, DocumentStatus +from app.db import NATIVE_TO_LEGACY_DOCTYPE, Chunk, Document, DocumentStatus from app.indexing_pipeline.connector_document import ConnectorDocument from app.indexing_pipeline.document_chunker import chunk_text from app.indexing_pipeline.document_embedder import embed_texts from app.indexing_pipeline.document_hashing import ( compute_content_hash, + compute_identifier_hash, compute_unique_identifier_hash, ) from app.indexing_pipeline.document_persistence import ( @@ -54,6 +55,62 @@ class IndexingPipelineService: def __init__(self, session: AsyncSession) -> None: self.session = session + async def migrate_legacy_docs( + self, connector_docs: list[ConnectorDocument] + ) -> None: + """Migrate legacy Composio documents to their native Google type. + + For each ConnectorDocument whose document_type has a Composio equivalent + in NATIVE_TO_LEGACY_DOCTYPE, look up the old document by legacy hash and + update its unique_identifier_hash and document_type so that + prepare_for_indexing() can find it under the native hash. + """ + for doc in connector_docs: + legacy_type = NATIVE_TO_LEGACY_DOCTYPE.get(doc.document_type.value) + if not legacy_type: + continue + + legacy_hash = compute_identifier_hash( + legacy_type, doc.unique_id, doc.search_space_id + ) + result = await self.session.execute( + select(Document).filter( + Document.unique_identifier_hash == legacy_hash + ) + ) + existing = result.scalars().first() + if existing is None: + continue + + native_hash = compute_identifier_hash( + doc.document_type.value, doc.unique_id, doc.search_space_id + ) + existing.unique_identifier_hash = native_hash + existing.document_type = doc.document_type + + await self.session.commit() + + async def index_batch( + self, connector_docs: list[ConnectorDocument], llm + ) -> list[Document]: + """Convenience method: prepare_for_indexing then index each document. + + Indexers that need heartbeat callbacks or custom per-document logic + should call prepare_for_indexing() + index() directly instead. + """ + doc_map = { + compute_unique_identifier_hash(cd): cd for cd in connector_docs + } + documents = await self.prepare_for_indexing(connector_docs) + results: list[Document] = [] + for document in documents: + connector_doc = doc_map.get(document.unique_identifier_hash) + if connector_doc is None: + continue + result = await self.index(document, connector_doc, llm) + results.append(result) + return results + async def prepare_for_indexing( self, connector_docs: list[ConnectorDocument] ) -> list[Document]: diff --git a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py index 233bc66e4..a69b33bdc 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py @@ -1,9 +1,8 @@ """ Google Calendar connector indexer. -Implements 2-phase document status updates for real-time UI feedback: -- Phase 1: Create all documents with 'pending' status (visible in UI immediately) -- Phase 2: Process each document: pending → processing → ready/failed +Uses the shared IndexingPipelineService for document deduplication, +summarization, chunking, and embedding. """ import time @@ -15,29 +14,25 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.google_calendar_connector import GoogleCalendarConnector -from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType +from app.db import DocumentType, SearchSourceConnectorType +from app.indexing_pipeline.connector_document import ConnectorDocument +from app.indexing_pipeline.document_hashing import ( + compute_content_hash, + compute_unique_identifier_hash, +) +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.utils.document_converters import ( - create_document_chunks, - embed_text, - generate_content_hash, - generate_document_summary, - generate_unique_identifier_hash, -) from app.utils.google_credentials import ( COMPOSIO_GOOGLE_CONNECTOR_TYPES, build_composio_credentials, ) from .base import ( - check_document_by_unique_identifier, check_duplicate_document_by_hash, get_connector_by_id, - get_current_timestamp, logger, parse_date_flexible, - safe_set_chunks, update_connector_last_indexed, ) @@ -46,13 +41,60 @@ ACCEPTED_CALENDAR_CONNECTOR_TYPES = { SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, } -# Type hint for heartbeat callback HeartbeatCallbackType = Callable[[int], Awaitable[None]] - -# Heartbeat interval in seconds HEARTBEAT_INTERVAL_SECONDS = 30 +def _build_connector_doc( + event: dict, + event_markdown: str, + *, + connector_id: int, + search_space_id: int, + user_id: str, + enable_summary: bool, +) -> ConnectorDocument: + """Map a raw Google Calendar API event dict to a ConnectorDocument.""" + event_id = event.get("id", "") + event_summary = event.get("summary", "No Title") + calendar_id = event.get("calendarId", "") + + start = event.get("start", {}) + end = event.get("end", {}) + start_time = start.get("dateTime") or start.get("date", "") + end_time = end.get("dateTime") or end.get("date", "") + location = event.get("location", "") + + metadata = { + "event_id": event_id, + "event_summary": event_summary, + "calendar_id": calendar_id, + "start_time": start_time, + "end_time": end_time, + "location": location, + "connector_id": connector_id, + "document_type": "Google Calendar Event", + "connector_type": "Google Calendar", + } + + fallback_summary = ( + f"Google Calendar Event: {event_summary}\n\n{event_markdown}" + ) + + return ConnectorDocument( + title=event_summary, + source_markdown=event_markdown, + unique_id=event_id, + document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR, + search_space_id=search_space_id, + connector_id=connector_id, + created_by_id=user_id, + should_summarize=enable_summary, + fallback_summary=fallback_summary, + metadata=metadata, + ) + + async def index_google_calendar_events( session: AsyncSession, connector_id: int, @@ -82,7 +124,6 @@ async def index_google_calendar_events( """ task_logger = TaskLoggingService(session, search_space_id) - # Log task start log_entry = await task_logger.log_task_start( task_name="google_calendar_events_indexing", source="connector_indexing_task", @@ -96,7 +137,7 @@ async def index_google_calendar_events( ) try: - # Accept both native and Composio Calendar connectors + # ── Connector lookup ────────────────────────────────────────── connector = None for ct in ACCEPTED_CALENDAR_CONNECTOR_TYPES: connector = await get_connector_by_id(session, connector_id, ct) @@ -112,7 +153,7 @@ async def index_google_calendar_events( ) return 0, 0, f"Connector with ID {connector_id} not found" - # Build credentials based on connector type + # ── Credential building ─────────────────────────────────────── if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: connected_account_id = connector.config.get("composio_connected_account_id") if not connected_account_id: @@ -184,6 +225,7 @@ async def index_google_calendar_events( ) return 0, 0, "Google Calendar credentials not found in connector config" + # ── Calendar client init ────────────────────────────────────── await task_logger.log_task_progress( log_entry, f"Initializing Google Calendar client for connector {connector_id}", @@ -203,36 +245,26 @@ async def index_google_calendar_events( if end_date == "undefined" or end_date == "": end_date = None - # Calculate date range - # For calendar connectors, allow future dates to index upcoming events + # ── Date range calculation ──────────────────────────────────── if start_date is None or end_date is None: - # Fall back to calculating dates based on last_indexed_at - # Default to today (users can manually select future dates if needed) calculated_end_date = datetime.now() - # Use last_indexed_at as start date if available, otherwise use 30 days ago if connector.last_indexed_at: - # Convert dates to be comparable (both timezone-naive) last_indexed_naive = ( connector.last_indexed_at.replace(tzinfo=None) if connector.last_indexed_at.tzinfo else connector.last_indexed_at ) - - # Allow future dates - use last_indexed_at as start date calculated_start_date = last_indexed_naive logger.info( f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date" ) else: - calculated_start_date = datetime.now() - timedelta( - days=365 - ) # Use 365 days as default for calendar events (matches frontend) + calculated_start_date = datetime.now() - timedelta(days=365) logger.info( f"No last_indexed_at found, using {calculated_start_date.strftime('%Y-%m-%d')} (365 days ago) as start date" ) - # Use calculated dates if not provided start_date_str = ( start_date if start_date else calculated_start_date.strftime("%Y-%m-%d") ) @@ -240,19 +272,14 @@ async def index_google_calendar_events( end_date if end_date else calculated_end_date.strftime("%Y-%m-%d") ) else: - # Use provided dates (including future dates) start_date_str = start_date end_date_str = end_date - # FIX: Ensure end_date is at least 1 day after start_date to avoid - # "start_date must be strictly before end_date" errors when dates are the same - # (e.g., when last_indexed_at is today) if start_date_str == end_date_str: logger.info( f"Start date ({start_date_str}) equals end date ({end_date_str}), " "adjusting end date to next day to ensure valid date range" ) - # Parse end_date and add 1 day try: end_dt = parse_date_flexible(end_date_str) except ValueError: @@ -264,6 +291,7 @@ async def index_google_calendar_events( end_date_str = end_dt.strftime("%Y-%m-%d") logger.info(f"Adjusted end date to {end_date_str}") + # ── Fetch events ────────────────────────────────────────────── await task_logger.log_task_progress( log_entry, f"Fetching Google Calendar events from {start_date_str} to {end_date_str}", @@ -274,27 +302,19 @@ async def index_google_calendar_events( }, ) - # Get events within date range from primary calendar try: events, error = await calendar_client.get_all_primary_calendar_events( start_date=start_date_str, end_date=end_date_str ) if error: - # Don't treat "No events found" as an error that should stop indexing if "No events found" in error: logger.info(f"No Google Calendar events found: {error}") - logger.info( - "No events found is not a critical error, continuing with update" - ) if update_last_indexed: await update_connector_last_indexed( session, connector, update_last_indexed ) await session.commit() - logger.info( - f"Updated last_indexed_at to {connector.last_indexed_at} despite no events found" - ) await task_logger.log_task_success( log_entry, @@ -304,7 +324,6 @@ async def index_google_calendar_events( return 0, 0, None else: logger.error(f"Failed to get Google Calendar events: {error}") - # Check if this is an authentication error that requires re-authentication error_message = error error_type = "APIError" if ( @@ -329,28 +348,15 @@ async def index_google_calendar_events( logger.error(f"Error fetching Google Calendar events: {e!s}", exc_info=True) return 0, 0, f"Error fetching Google Calendar events: {e!s}" - documents_indexed = 0 + # ── Build ConnectorDocuments ────────────────────────────────── + connector_docs: list[ConnectorDocument] = [] documents_skipped = 0 - documents_failed = 0 # Track events that failed processing - duplicate_content_count = ( - 0 # Track events skipped due to duplicate content_hash - ) - - # Heartbeat tracking - update notification periodically to prevent appearing stuck - last_heartbeat_time = time.time() - - # ======================================================================= - # PHASE 1: Analyze all events, create pending documents - # This makes ALL documents visible in the UI immediately with pending status - # ======================================================================= - events_to_process = [] # List of dicts with document and event data - new_documents_created = False + duplicate_content_count = 0 for event in events: try: event_id = event.get("id") event_summary = event.get("summary", "No Title") - calendar_id = event.get("calendarId", "") if not event_id: logger.warning(f"Skipping event with missing ID: {event_summary}") @@ -363,223 +369,73 @@ async def index_google_calendar_events( documents_skipped += 1 continue - start = event.get("start", {}) - end = event.get("end", {}) - start_time = start.get("dateTime") or start.get("date", "") - end_time = end.get("dateTime") or end.get("date", "") - location = event.get("location", "") - description = event.get("description", "") - - # Generate unique identifier hash for this Google Calendar event - unique_identifier_hash = generate_unique_identifier_hash( - DocumentType.GOOGLE_CALENDAR_CONNECTOR, event_id, search_space_id + doc = _build_connector_doc( + event, + event_markdown, + connector_id=connector_id, + search_space_id=search_space_id, + user_id=user_id, + enable_summary=connector.enable_summary, ) - # Generate content hash - content_hash = generate_content_hash(event_markdown, search_space_id) - - # Check if document with this unique identifier already exists - existing_document = await check_document_by_unique_identifier( - session, unique_identifier_hash - ) - - # Fallback: legacy Composio hash - if not existing_document: - legacy_hash = generate_unique_identifier_hash( - DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, - event_id, - search_space_id, - ) - existing_document = await check_document_by_unique_identifier( - session, legacy_hash - ) - if existing_document: - existing_document.unique_identifier_hash = ( - unique_identifier_hash - ) - if ( - existing_document.document_type - == DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR - ): - existing_document.document_type = ( - DocumentType.GOOGLE_CALENDAR_CONNECTOR - ) - logger.info( - f"Migrated legacy Composio Calendar document: {event_id}" - ) - - if existing_document: - # Document exists - check if content has changed - if existing_document.content_hash == content_hash: - # Ensure status is ready (might have been stuck in processing/pending) - if not DocumentStatus.is_state( - existing_document.status, DocumentStatus.READY - ): - existing_document.status = DocumentStatus.ready() - documents_skipped += 1 - continue - - # Queue existing document for update (will be set to processing in Phase 2) - events_to_process.append( - { - "document": existing_document, - "is_new": False, - "event_markdown": event_markdown, - "content_hash": content_hash, - "event_id": event_id, - "event_summary": event_summary, - "calendar_id": calendar_id, - "start_time": start_time, - "end_time": end_time, - "location": location, - "description": description, - } - ) - continue - - # Document doesn't exist by unique_identifier_hash - # Check if a document with the same content_hash exists (from another connector) with session.no_autoflush: - duplicate_by_content = await check_duplicate_document_by_hash( - session, content_hash + duplicate = await check_duplicate_document_by_hash( + session, compute_content_hash(doc) ) - - if duplicate_by_content: - # A document with the same content already exists (likely from Composio connector) + if duplicate: logger.info( - f"Event {event_summary} already indexed by another connector " - f"(existing document ID: {duplicate_by_content.id}, " - f"type: {duplicate_by_content.document_type}). Skipping to avoid duplicate content." + f"Event {doc.title} already indexed by another connector " + f"(existing document ID: {duplicate.id}, " + f"type: {duplicate.document_type}). Skipping." ) duplicate_content_count += 1 documents_skipped += 1 continue - # Create new document with PENDING status (visible in UI immediately) - document = Document( - search_space_id=search_space_id, - title=event_summary, - document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR, - document_metadata={ - "event_id": event_id, - "event_summary": event_summary, - "calendar_id": calendar_id, - "start_time": start_time, - "end_time": end_time, - "location": location, - "connector_id": connector_id, - }, - content="Pending...", # Placeholder until processed - content_hash=unique_identifier_hash, # Temporary unique value - updated when ready - unique_identifier_hash=unique_identifier_hash, - embedding=None, - chunks=[], # Empty at creation - safe for async - status=DocumentStatus.pending(), # Pending until processing starts - updated_at=get_current_timestamp(), - created_by_id=user_id, - connector_id=connector_id, - ) - session.add(document) - new_documents_created = True - - events_to_process.append( - { - "document": document, - "is_new": True, - "event_markdown": event_markdown, - "content_hash": content_hash, - "event_id": event_id, - "event_summary": event_summary, - "calendar_id": calendar_id, - "start_time": start_time, - "end_time": end_time, - "location": location, - "description": description, - } - ) + connector_docs.append(doc) except Exception as e: - logger.error(f"Error in Phase 1 for event: {e!s}", exc_info=True) - documents_failed += 1 + logger.error(f"Error building ConnectorDocument for event: {e!s}", exc_info=True) + documents_skipped += 1 continue - # Commit all pending documents - they all appear in UI now - if new_documents_created: - logger.info( - f"Phase 1: Committing {len([e for e in events_to_process if e['is_new']])} pending documents" - ) - await session.commit() + # ── Pipeline: migrate legacy docs + prepare + index ─────────── + pipeline = IndexingPipelineService(session) - # ======================================================================= - # PHASE 2: Process each document one by one - # Each document transitions: pending → processing → ready/failed - # ======================================================================= - logger.info(f"Phase 2: Processing {len(events_to_process)} documents") + await pipeline.migrate_legacy_docs(connector_docs) - for item in events_to_process: - # Send heartbeat periodically + documents = await pipeline.prepare_for_indexing(connector_docs) + + doc_map = { + compute_unique_identifier_hash(cd): cd for cd in connector_docs + } + + documents_indexed = 0 + documents_failed = 0 + last_heartbeat_time = time.time() + + for document in documents: if on_heartbeat_callback: current_time = time.time() if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS: await on_heartbeat_callback(documents_indexed) last_heartbeat_time = current_time - document = item["document"] - try: - # Set to PROCESSING and commit - shows "processing" in UI for THIS document only - document.status = DocumentStatus.processing() - await session.commit() + connector_doc = doc_map.get(document.unique_identifier_hash) + if connector_doc is None: + logger.warning( + f"No matching ConnectorDocument for document {document.id}, skipping" + ) + documents_failed += 1 + continue - # Heavy processing (LLM, embeddings, chunks) + try: user_llm = await get_user_long_context_llm( session, user_id, search_space_id ) - - if user_llm and connector.enable_summary: - document_metadata_for_summary = { - "event_id": item["event_id"], - "event_summary": item["event_summary"], - "calendar_id": item["calendar_id"], - "start_time": item["start_time"], - "end_time": item["end_time"], - "location": item["location"] or "No location", - "document_type": "Google Calendar Event", - "connector_type": "Google Calendar", - } - ( - summary_content, - summary_embedding, - ) = await generate_document_summary( - item["event_markdown"], user_llm, document_metadata_for_summary - ) - else: - summary_content = f"Google Calendar Event: {item['event_summary']}\n\n{item['event_markdown']}" - summary_embedding = embed_text(summary_content) - - chunks = await create_document_chunks(item["event_markdown"]) - - # Update document to READY with actual content - document.title = item["event_summary"] - document.content = summary_content - document.content_hash = item["content_hash"] - document.embedding = summary_embedding - document.document_metadata = { - "event_id": item["event_id"], - "event_summary": item["event_summary"], - "calendar_id": item["calendar_id"], - "start_time": item["start_time"], - "end_time": item["end_time"], - "location": item["location"], - "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "connector_id": connector_id, - } - await safe_set_chunks(session, document, chunks) - document.updated_at = get_current_timestamp() - document.status = DocumentStatus.ready() - + await pipeline.index(document, connector_doc, user_llm) documents_indexed += 1 - # Batch commit every 10 documents (for ready status updates) if documents_indexed % 10 == 0: logger.info( f"Committing batch: {documents_indexed} Google Calendar events processed so far" @@ -588,21 +444,12 @@ async def index_google_calendar_events( except Exception as e: logger.error(f"Error processing Calendar event: {e!s}", exc_info=True) - # Mark document as failed with reason (visible in UI) - try: - document.status = DocumentStatus.failed(str(e)) - document.updated_at = get_current_timestamp() - except Exception as status_error: - logger.error( - f"Failed to update document status to failed: {status_error}" - ) documents_failed += 1 continue - # CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs + # ── Finalize ────────────────────────────────────────────────── await update_connector_last_indexed(session, connector, update_last_indexed) - # Final commit for any remaining documents not yet committed in batches logger.info( f"Final commit: Total {documents_indexed} Google Calendar events processed" ) @@ -612,22 +459,18 @@ async def index_google_calendar_events( "Successfully committed all Google Calendar document changes to database" ) except Exception as e: - # Handle any remaining integrity errors gracefully (race conditions, etc.) if ( "duplicate key value violates unique constraint" in str(e).lower() or "uniqueviolationerror" in str(e).lower() ): logger.warning( f"Duplicate content_hash detected during final commit. " - f"This may occur if the same event was indexed by multiple connectors. " f"Rolling back and continuing. Error: {e!s}" ) await session.rollback() - # Don't fail the entire task - some documents may have been successfully indexed else: raise - # Build warning message if there were issues warning_parts = [] if duplicate_content_count > 0: warning_parts.append(f"{duplicate_content_count} duplicate") diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py index 260db0ce6..92c074812 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py @@ -1,36 +1,41 @@ -"""Google Drive indexer using Surfsense file processors. +"""Google Drive indexer using the shared IndexingPipelineService. -Implements 2-phase document status updates for real-time UI feedback: -- Phase 1: Create all documents with 'pending' status (visible in UI immediately) -- Phase 2: Process each document: pending → processing → ready/failed +File-level pre-filter (_should_skip_file) handles md5/modifiedTime +checks and rename-only detection. download_and_extract_content() +returns markdown which is fed into ConnectorDocument -> pipeline. """ import logging import time from collections.abc import Awaitable, Callable +from sqlalchemy import String, cast, select from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm.attributes import flag_modified from app.config import config from app.connectors.google_drive import ( GoogleDriveClient, categorize_change, - download_and_process_file, + download_and_extract_content, fetch_all_changes, get_file_by_id, get_files_in_folder, get_start_page_token, ) +from app.connectors.google_drive.file_types import should_skip_file as skip_mime from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType +from app.indexing_pipeline.connector_document import ConnectorDocument +from app.indexing_pipeline.document_hashing import compute_identifier_hash +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService +from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService from app.tasks.connector_indexers.base import ( check_document_by_unique_identifier, get_connector_by_id, - get_current_timestamp, update_connector_last_indexed, ) -from app.utils.document_converters import generate_unique_identifier_hash from app.utils.google_credentials import ( COMPOSIO_GOOGLE_CONNECTOR_TYPES, build_composio_credentials, @@ -41,15 +46,423 @@ ACCEPTED_DRIVE_CONNECTOR_TYPES = { SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, } -# Type hint for heartbeat callback HeartbeatCallbackType = Callable[[int], Awaitable[None]] - -# Heartbeat interval in seconds HEARTBEAT_INTERVAL_SECONDS = 30 logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +async def _should_skip_file( + session: AsyncSession, + file: dict, + search_space_id: int, +) -> tuple[bool, str | None]: + """Pre-filter: detect unchanged / rename-only files. + + Returns (should_skip, message). + Side-effects: migrates legacy Composio hashes, updates renames in-place. + """ + file_id = file.get("id") + file_name = file.get("name", "Unknown") + mime_type = file.get("mimeType", "") + + if skip_mime(mime_type): + return True, "folder/shortcut" + if not file_id: + return True, "missing file_id" + + # --- locate existing document --- + primary_hash = compute_identifier_hash( + DocumentType.GOOGLE_DRIVE_FILE.value, file_id, search_space_id + ) + existing = await check_document_by_unique_identifier(session, primary_hash) + + if not existing: + legacy_hash = compute_identifier_hash( + DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR.value, file_id, search_space_id + ) + existing = await check_document_by_unique_identifier(session, legacy_hash) + if existing: + existing.unique_identifier_hash = primary_hash + if existing.document_type == DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: + existing.document_type = DocumentType.GOOGLE_DRIVE_FILE + logger.info(f"Migrated legacy Composio Drive document: {file_id}") + + if not existing: + result = await session.execute( + select(Document).where( + Document.search_space_id == search_space_id, + Document.document_type.in_([ + DocumentType.GOOGLE_DRIVE_FILE, + DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, + ]), + cast(Document.document_metadata["google_drive_file_id"], String) == file_id, + ) + ) + existing = result.scalar_one_or_none() + if existing: + existing.unique_identifier_hash = primary_hash + if existing.document_type == DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: + existing.document_type = DocumentType.GOOGLE_DRIVE_FILE + logger.debug(f"Found legacy doc by metadata for file_id: {file_id}") + + if not existing: + return False, None + + # --- content-change check via md5 / modifiedTime --- + incoming_md5 = file.get("md5Checksum") + incoming_mtime = file.get("modifiedTime") + meta = existing.document_metadata or {} + stored_md5 = meta.get("md5_checksum") + stored_mtime = meta.get("modified_time") + + content_unchanged = False + if incoming_md5 and stored_md5: + content_unchanged = incoming_md5 == stored_md5 + elif incoming_md5 and not stored_md5: + return False, None + elif not incoming_md5 and incoming_mtime and stored_mtime: + content_unchanged = incoming_mtime == stored_mtime + elif not incoming_md5: + return False, None + + if not content_unchanged: + return False, None + + # --- rename-only detection --- + old_name = meta.get("FILE_NAME") or meta.get("google_drive_file_name") + if old_name and old_name != file_name: + existing.title = file_name + if not existing.document_metadata: + existing.document_metadata = {} + existing.document_metadata["FILE_NAME"] = file_name + existing.document_metadata["google_drive_file_name"] = file_name + if incoming_mtime: + existing.document_metadata["modified_time"] = incoming_mtime + flag_modified(existing, "document_metadata") + await session.commit() + logger.info(f"Rename-only update: '{old_name}' → '{file_name}'") + return True, f"File renamed: '{old_name}' → '{file_name}'" + + if not DocumentStatus.is_state(existing.status, DocumentStatus.READY): + existing.status = DocumentStatus.ready() + return True, "unchanged" + + +def _build_connector_doc( + file: dict, + markdown: str, + drive_metadata: dict, + *, + connector_id: int, + search_space_id: int, + user_id: str, + enable_summary: bool, +) -> ConnectorDocument: + """Build a ConnectorDocument from Drive file metadata + extracted markdown.""" + file_id = file.get("id", "") + file_name = file.get("name", "Unknown") + + metadata = { + **drive_metadata, + "connector_id": connector_id, + "document_type": "Google Drive File", + "connector_type": "Google Drive", + } + + fallback_summary = f"File: {file_name}\n\n{markdown[:4000]}" + + return ConnectorDocument( + title=file_name, + source_markdown=markdown, + unique_id=file_id, + document_type=DocumentType.GOOGLE_DRIVE_FILE, + search_space_id=search_space_id, + connector_id=connector_id, + created_by_id=user_id, + should_summarize=enable_summary, + fallback_summary=fallback_summary, + metadata=metadata, + ) + + +async def _process_single_file( + drive_client: GoogleDriveClient, + session: AsyncSession, + file: dict, + connector_id: int, + search_space_id: int, + user_id: str, + enable_summary: bool = True, +) -> tuple[int, int, int]: + """Download, extract, and index a single Drive file via the pipeline. + + Returns (indexed, skipped, failed). + """ + file_name = file.get("name", "Unknown") + + try: + skip, msg = await _should_skip_file(session, file, search_space_id) + if skip: + if msg and "renamed" in msg.lower(): + return 1, 0, 0 + return 0, 1, 0 + + markdown, drive_metadata, error = await download_and_extract_content( + drive_client, file + ) + if error or not markdown: + logger.warning(f"ETL failed for {file_name}: {error}") + return 0, 1, 0 + + doc = _build_connector_doc( + file, + markdown, + drive_metadata, + connector_id=connector_id, + search_space_id=search_space_id, + user_id=user_id, + enable_summary=enable_summary, + ) + + pipeline = IndexingPipelineService(session) + documents = await pipeline.prepare_for_indexing([doc]) + if not documents: + return 0, 1, 0 + + from app.indexing_pipeline.document_hashing import compute_unique_identifier_hash + + doc_map = {compute_unique_identifier_hash(doc): doc} + for document in documents: + connector_doc = doc_map.get(document.unique_identifier_hash) + if not connector_doc: + continue + user_llm = await get_user_long_context_llm(session, user_id, search_space_id) + await pipeline.index(document, connector_doc, user_llm) + + logger.info(f"Successfully indexed Google Drive file: {file_name}") + return 1, 0, 0 + + except Exception as e: + logger.error(f"Error processing file {file_name}: {e!s}", exc_info=True) + return 0, 0, 1 + + +async def _remove_document(session: AsyncSession, file_id: str, search_space_id: int): + """Remove a document that was deleted in Drive.""" + primary_hash = compute_identifier_hash( + DocumentType.GOOGLE_DRIVE_FILE.value, file_id, search_space_id + ) + existing = await check_document_by_unique_identifier(session, primary_hash) + + if not existing: + legacy_hash = compute_identifier_hash( + DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR.value, file_id, search_space_id + ) + existing = await check_document_by_unique_identifier(session, legacy_hash) + + if not existing: + result = await session.execute( + select(Document).where( + Document.search_space_id == search_space_id, + Document.document_type.in_([ + DocumentType.GOOGLE_DRIVE_FILE, + DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, + ]), + cast(Document.document_metadata["google_drive_file_id"], String) == file_id, + ) + ) + existing = result.scalar_one_or_none() + + if existing: + await session.delete(existing) + logger.info(f"Removed deleted file document: {file_id}") + + +# --------------------------------------------------------------------------- +# Scan strategies +# --------------------------------------------------------------------------- + +async def _index_full_scan( + drive_client: GoogleDriveClient, + session: AsyncSession, + connector: object, + connector_id: int, + search_space_id: int, + user_id: str, + folder_id: str | None, + folder_name: str, + task_logger: TaskLoggingService, + log_entry: object, + max_files: int, + include_subfolders: bool = False, + on_heartbeat_callback: HeartbeatCallbackType | None = None, + enable_summary: bool = True, +) -> tuple[int, int]: + """Full scan indexing of a folder.""" + await task_logger.log_task_progress( + log_entry, + f"Starting full scan of folder: {folder_name} (include_subfolders={include_subfolders})", + {"stage": "full_scan", "folder_id": folder_id, "include_subfolders": include_subfolders}, + ) + + indexed = 0 + skipped = 0 + failed = 0 + files_processed = 0 + last_heartbeat = time.time() + folders_to_process = [(folder_id, folder_name)] + first_error: str | None = None + + while folders_to_process and files_processed < max_files: + cur_id, cur_name = folders_to_process.pop(0) + page_token = None + + while files_processed < max_files: + files, next_token, error = await get_files_in_folder( + drive_client, cur_id, include_subfolders=True, page_token=page_token, + ) + if error: + logger.error(f"Error listing files in {cur_name}: {error}") + if first_error is None: + first_error = error + break + if not files: + break + + for file in files: + if files_processed >= max_files: + break + + mime = file.get("mimeType", "") + if mime == "application/vnd.google-apps.folder": + if include_subfolders: + folders_to_process.append((file["id"], file.get("name", "Unknown"))) + continue + + files_processed += 1 + + if on_heartbeat_callback: + now = time.time() + if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS: + await on_heartbeat_callback(indexed) + last_heartbeat = now + + i, s, f = await _process_single_file( + drive_client, session, file, + connector_id, search_space_id, user_id, enable_summary, + ) + indexed += i + skipped += s + failed += f + + if indexed > 0 and indexed % 10 == 0: + await session.commit() + + page_token = next_token + if not page_token: + break + + if not files_processed and first_error: + err_lower = first_error.lower() + if "401" in first_error or "invalid credentials" in err_lower or "authError" in first_error: + raise Exception( + f"Google Drive authentication failed. Please re-authenticate. (Error: {first_error})" + ) + raise Exception(f"Failed to list Google Drive files: {first_error}") + + logger.info(f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed") + return indexed, skipped + + +async def _index_with_delta_sync( + drive_client: GoogleDriveClient, + session: AsyncSession, + connector: object, + connector_id: int, + search_space_id: int, + user_id: str, + folder_id: str | None, + start_page_token: str, + task_logger: TaskLoggingService, + log_entry: object, + max_files: int, + include_subfolders: bool = False, + on_heartbeat_callback: HeartbeatCallbackType | None = None, + enable_summary: bool = True, +) -> tuple[int, int]: + """Delta sync using change tracking.""" + await task_logger.log_task_progress( + log_entry, + f"Starting delta sync from token: {start_page_token[:20]}...", + {"stage": "delta_sync", "start_token": start_page_token}, + ) + + changes, _final_token, error = await fetch_all_changes(drive_client, start_page_token, folder_id) + if error: + err_lower = error.lower() + if "401" in error or "invalid credentials" in err_lower or "authError" in error: + raise Exception( + f"Google Drive authentication failed. Please re-authenticate. (Error: {error})" + ) + raise Exception(f"Failed to fetch Google Drive changes: {error}") + + if not changes: + logger.info("No changes detected since last sync") + return 0, 0 + + logger.info(f"Processing {len(changes)} changes") + indexed = 0 + skipped = 0 + failed = 0 + files_processed = 0 + last_heartbeat = time.time() + + for change in changes: + if files_processed >= max_files: + break + files_processed += 1 + change_type = categorize_change(change) + + if change_type in ["removed", "trashed"]: + fid = change.get("fileId") + if fid: + await _remove_document(session, fid, search_space_id) + continue + + file = change.get("file") + if not file: + continue + + if on_heartbeat_callback: + now = time.time() + if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS: + await on_heartbeat_callback(indexed) + last_heartbeat = now + + i, s, f = await _process_single_file( + drive_client, session, file, + connector_id, search_space_id, user_id, enable_summary, + ) + indexed += i + skipped += s + failed += f + + if indexed > 0 and indexed % 10 == 0: + await session.commit() + + logger.info(f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed") + return indexed, skipped + + +# --------------------------------------------------------------------------- +# Public entry points +# --------------------------------------------------------------------------- + async def index_google_drive_files( session: AsyncSession, connector_id: int, @@ -63,234 +476,125 @@ async def index_google_drive_files( include_subfolders: bool = False, on_heartbeat_callback: HeartbeatCallbackType | None = None, ) -> tuple[int, int, str | None]: - """ - Index Google Drive files for a specific connector. - - Args: - session: Database session - connector_id: ID of the Drive connector - search_space_id: ID of the search space - user_id: ID of the user - folder_id: Specific folder to index (from UI/request, takes precedence) - folder_name: Folder name for display (from UI/request) - use_delta_sync: Whether to use change tracking for incremental sync - update_last_indexed: Whether to update last_indexed_at timestamp - max_files: Maximum number of files to index - include_subfolders: Whether to recursively index files in subfolders - on_heartbeat_callback: Optional callback to update notification during long-running indexing. - - Returns: - Tuple of (number_of_indexed_files, number_of_skipped_files, error_message) - """ + """Index Google Drive files for a specific connector.""" task_logger = TaskLoggingService(session, search_space_id) - log_entry = await task_logger.log_task_start( task_name="google_drive_files_indexing", source="connector_indexing_task", message=f"Starting Google Drive indexing for connector {connector_id}", metadata={ - "connector_id": connector_id, - "user_id": str(user_id), - "folder_id": folder_id, - "use_delta_sync": use_delta_sync, - "max_files": max_files, + "connector_id": connector_id, "user_id": str(user_id), + "folder_id": folder_id, "use_delta_sync": use_delta_sync, "max_files": max_files, }, ) try: - # Accept both native and Composio Drive connectors connector = None for ct in ACCEPTED_DRIVE_CONNECTOR_TYPES: connector = await get_connector_by_id(session, connector_id, ct) if connector: break - if not connector: error_msg = f"Google Drive connector with ID {connector_id} not found" - await task_logger.log_task_failure( - log_entry, error_msg, None, {"error_type": "ConnectorNotFound"} - ) + await task_logger.log_task_failure(log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}) return 0, 0, error_msg await task_logger.log_task_progress( - log_entry, - f"Initializing Google Drive client for connector {connector_id}", + log_entry, f"Initializing Google Drive client for connector {connector_id}", {"stage": "client_initialization"}, ) - # Build credentials based on connector type pre_built_credentials = None if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: connected_account_id = connector.config.get("composio_connected_account_id") if not connected_account_id: error_msg = f"Composio connected_account_id not found for connector {connector_id}" - await task_logger.log_task_failure( - log_entry, - error_msg, - "Missing Composio account", - {"error_type": "MissingComposioAccount"}, - ) + await task_logger.log_task_failure(log_entry, error_msg, "Missing Composio account", {"error_type": "MissingComposioAccount"}) return 0, 0, error_msg pre_built_credentials = build_composio_credentials(connected_account_id) else: token_encrypted = connector.config.get("_token_encrypted", False) - if token_encrypted: - if not config.SECRET_KEY: - await task_logger.log_task_failure( - log_entry, - f"SECRET_KEY not configured but credentials are marked as encrypted for connector {connector_id}", - "Missing SECRET_KEY for token decryption", - {"error_type": "MissingSecretKey"}, - ) - return ( - 0, - 0, - "SECRET_KEY not configured but credentials are marked as encrypted", - ) - logger.info( - f"Google Drive credentials are encrypted for connector {connector_id}, will decrypt during client initialization" + if token_encrypted and not config.SECRET_KEY: + await task_logger.log_task_failure( + log_entry, "SECRET_KEY not configured but credentials are encrypted", + "Missing SECRET_KEY", {"error_type": "MissingSecretKey"}, ) + return 0, 0, "SECRET_KEY not configured but credentials are marked as encrypted" connector_enable_summary = getattr(connector, "enable_summary", True) - - drive_client = GoogleDriveClient( - session, connector_id, credentials=pre_built_credentials - ) + drive_client = GoogleDriveClient(session, connector_id, credentials=pre_built_credentials) if not folder_id: error_msg = "folder_id is required for Google Drive indexing" - await task_logger.log_task_failure( - log_entry, error_msg, {"error_type": "MissingParameter"} - ) + await task_logger.log_task_failure(log_entry, error_msg, {"error_type": "MissingParameter"}) return 0, 0, error_msg target_folder_id = folder_id target_folder_name = folder_name or "Selected Folder" - logger.info( - f"Indexing Google Drive folder: {target_folder_name} ({target_folder_id})" - ) - folder_tokens = connector.config.get("folder_tokens", {}) start_page_token = folder_tokens.get(target_folder_id) - can_use_delta_sync = ( - use_delta_sync and start_page_token and connector.last_indexed_at - ) + can_use_delta = use_delta_sync and start_page_token and connector.last_indexed_at - if can_use_delta_sync: + if can_use_delta: logger.info(f"Using delta sync for connector {connector_id}") - result = await _index_with_delta_sync( - drive_client=drive_client, - session=session, - connector=connector, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - folder_id=target_folder_id, - start_page_token=start_page_token, - task_logger=task_logger, - log_entry=log_entry, - max_files=max_files, - include_subfolders=include_subfolders, - on_heartbeat_callback=on_heartbeat_callback, - enable_summary=connector_enable_summary, + documents_indexed, documents_skipped = await _index_with_delta_sync( + drive_client, session, connector, connector_id, search_space_id, user_id, + target_folder_id, start_page_token, task_logger, log_entry, max_files, + include_subfolders, on_heartbeat_callback, connector_enable_summary, ) - documents_indexed, documents_skipped = result - - # Reconciliation: full scan re-indexes documents that were manually - # deleted from SurfSense but still exist in Google Drive. - # Already-indexed files are skipped via md5/modifiedTime checks, - # so the overhead is just one API listing call + fast DB lookups. logger.info("Running reconciliation scan after delta sync") - reconcile_result = await _index_full_scan( - drive_client=drive_client, - session=session, - connector=connector, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - folder_id=target_folder_id, - folder_name=target_folder_name, - task_logger=task_logger, - log_entry=log_entry, - max_files=max_files, - include_subfolders=include_subfolders, - on_heartbeat_callback=on_heartbeat_callback, - enable_summary=connector_enable_summary, + ri, rs = await _index_full_scan( + drive_client, session, connector, connector_id, search_space_id, user_id, + target_folder_id, target_folder_name, task_logger, log_entry, max_files, + include_subfolders, on_heartbeat_callback, connector_enable_summary, ) - documents_indexed += reconcile_result[0] - documents_skipped += reconcile_result[1] + documents_indexed += ri + documents_skipped += rs else: logger.info(f"Using full scan for connector {connector_id}") - result = await _index_full_scan( - drive_client=drive_client, - session=session, - connector=connector, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - folder_id=target_folder_id, - folder_name=target_folder_name, - task_logger=task_logger, - log_entry=log_entry, - max_files=max_files, - include_subfolders=include_subfolders, - on_heartbeat_callback=on_heartbeat_callback, - enable_summary=connector_enable_summary, + documents_indexed, documents_skipped = await _index_full_scan( + drive_client, session, connector, connector_id, search_space_id, user_id, + target_folder_id, target_folder_name, task_logger, log_entry, max_files, + include_subfolders, on_heartbeat_callback, connector_enable_summary, ) - documents_indexed, documents_skipped = result - if documents_indexed > 0 or can_use_delta_sync: + if documents_indexed > 0 or can_use_delta: new_token, token_error = await get_start_page_token(drive_client) if new_token and not token_error: - from sqlalchemy.orm.attributes import flag_modified - - # Refresh connector to reload attributes that may have been expired by earlier commits await session.refresh(connector) - if "folder_tokens" not in connector.config: connector.config["folder_tokens"] = {} connector.config["folder_tokens"][target_folder_id] = new_token flag_modified(connector, "config") - await update_connector_last_indexed(session, connector, update_last_indexed) await session.commit() - logger.info("Successfully committed Google Drive indexing changes to database") await task_logger.log_task_success( log_entry, f"Successfully completed Google Drive indexing for connector {connector_id}", { - "files_processed": documents_indexed, - "files_skipped": documents_skipped, - "sync_type": "delta" if can_use_delta_sync else "full", - "folder": target_folder_name, + "files_processed": documents_indexed, "files_skipped": documents_skipped, + "sync_type": "delta" if can_use_delta else "full", "folder": target_folder_name, }, ) - - logger.info( - f"Google Drive indexing completed: {documents_indexed} files indexed, {documents_skipped} skipped" - ) + logger.info(f"Google Drive indexing completed: {documents_indexed} indexed, {documents_skipped} skipped") return documents_indexed, documents_skipped, None except SQLAlchemyError as db_error: await session.rollback() await task_logger.log_task_failure( - log_entry, - f"Database error during Google Drive indexing for connector {connector_id}", - str(db_error), - {"error_type": "SQLAlchemyError"}, + log_entry, f"Database error during Google Drive indexing for connector {connector_id}", + str(db_error), {"error_type": "SQLAlchemyError"}, ) logger.error(f"Database error: {db_error!s}", exc_info=True) return 0, 0, f"Database error: {db_error!s}" except Exception as e: await session.rollback() await task_logger.log_task_failure( - log_entry, - f"Failed to index Google Drive files for connector {connector_id}", - str(e), - {"error_type": type(e).__name__}, + log_entry, f"Failed to index Google Drive files for connector {connector_id}", + str(e), {"error_type": type(e).__name__}, ) logger.error(f"Failed to index Google Drive files: {e!s}", exc_info=True) return 0, 0, f"Failed to index Google Drive files: {e!s}" @@ -304,964 +608,81 @@ async def index_google_drive_single_file( file_id: str, file_name: str | None = None, ) -> tuple[int, str | None]: - """ - Index a single Google Drive file by its ID. - - Args: - session: Database session - connector_id: ID of the Drive connector - search_space_id: ID of the search space - user_id: ID of the user - file_id: Specific file ID to index - file_name: File name for display (optional) - - Returns: - Tuple of (number_of_indexed_files, error_message) - """ + """Index a single Google Drive file by its ID.""" task_logger = TaskLoggingService(session, search_space_id) - log_entry = await task_logger.log_task_start( task_name="google_drive_single_file_indexing", source="connector_indexing_task", message=f"Starting Google Drive single file indexing for file {file_id}", - metadata={ - "connector_id": connector_id, - "user_id": str(user_id), - "file_id": file_id, - "file_name": file_name, - }, + metadata={"connector_id": connector_id, "user_id": str(user_id), "file_id": file_id, "file_name": file_name}, ) try: - # Accept both native and Composio Drive connectors connector = None for ct in ACCEPTED_DRIVE_CONNECTOR_TYPES: connector = await get_connector_by_id(session, connector_id, ct) if connector: break - if not connector: error_msg = f"Google Drive connector with ID {connector_id} not found" - await task_logger.log_task_failure( - log_entry, error_msg, None, {"error_type": "ConnectorNotFound"} - ) + await task_logger.log_task_failure(log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}) return 0, error_msg - await task_logger.log_task_progress( - log_entry, - f"Initializing Google Drive client for connector {connector_id}", - {"stage": "client_initialization"}, - ) - pre_built_credentials = None if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: connected_account_id = connector.config.get("composio_connected_account_id") if not connected_account_id: error_msg = f"Composio connected_account_id not found for connector {connector_id}" - await task_logger.log_task_failure( - log_entry, - error_msg, - "Missing Composio account", - {"error_type": "MissingComposioAccount"}, - ) + await task_logger.log_task_failure(log_entry, error_msg, "Missing Composio account", {"error_type": "MissingComposioAccount"}) return 0, error_msg pre_built_credentials = build_composio_credentials(connected_account_id) else: token_encrypted = connector.config.get("_token_encrypted", False) - if token_encrypted: - if not config.SECRET_KEY: - await task_logger.log_task_failure( - log_entry, - f"SECRET_KEY not configured but credentials are marked as encrypted for connector {connector_id}", - "Missing SECRET_KEY for token decryption", - {"error_type": "MissingSecretKey"}, - ) - return ( - 0, - "SECRET_KEY not configured but credentials are marked as encrypted", - ) - logger.info( - f"Google Drive credentials are encrypted for connector {connector_id}, will decrypt during client initialization" + if token_encrypted and not config.SECRET_KEY: + await task_logger.log_task_failure( + log_entry, "SECRET_KEY not configured but credentials are encrypted", + "Missing SECRET_KEY", {"error_type": "MissingSecretKey"}, ) + return 0, "SECRET_KEY not configured but credentials are marked as encrypted" connector_enable_summary = getattr(connector, "enable_summary", True) + drive_client = GoogleDriveClient(session, connector_id, credentials=pre_built_credentials) - drive_client = GoogleDriveClient( - session, connector_id, credentials=pre_built_credentials - ) - - # Fetch the file metadata file, error = await get_file_by_id(drive_client, file_id) - if error or not file: error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}" - await task_logger.log_task_failure( - log_entry, error_msg, {"error_type": "FileNotFound"} - ) + await task_logger.log_task_failure(log_entry, error_msg, {"error_type": "FileNotFound"}) return 0, error_msg display_name = file_name or file.get("name", "Unknown") - logger.info(f"Indexing Google Drive file: {display_name} ({file_id})") - # Create pending document for status visibility - pending_doc, should_skip = await _create_pending_document_for_file( - session=session, - file=file, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - - if should_skip: - await task_logger.log_task_progress( - log_entry, - f"File {display_name} is unchanged or not indexable", - {"status": "skipped"}, - ) - return 0, None - - # Commit pending document so it appears in UI - if pending_doc and pending_doc.id is None: - await session.commit() - - # Process the file indexed, _skipped, failed = await _process_single_file( - drive_client=drive_client, - session=session, - file=file, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - task_logger=task_logger, - log_entry=log_entry, - pending_document=pending_doc, - enable_summary=connector_enable_summary, + drive_client, session, file, + connector_id, search_space_id, user_id, connector_enable_summary, ) - await session.commit() - logger.info( - "Successfully committed Google Drive file indexing changes to database" - ) if failed > 0: error_msg = f"Failed to index file {display_name}" - await task_logger.log_task_failure( - log_entry, - error_msg, - {"file_name": display_name, "file_id": file_id}, - ) + await task_logger.log_task_failure(log_entry, error_msg, {"file_name": display_name, "file_id": file_id}) return 0, error_msg if indexed > 0: await task_logger.log_task_success( - log_entry, - f"Successfully indexed file {display_name}", - { - "file_name": display_name, - "file_id": file_id, - }, + log_entry, f"Successfully indexed file {display_name}", + {"file_name": display_name, "file_id": file_id}, ) - logger.info(f"Google Drive file indexing completed: {display_name}") return 1, None - else: - await task_logger.log_task_progress( - log_entry, - f"File {display_name} was skipped", - {"status": "skipped"}, - ) - return 0, None + + return 0, None except SQLAlchemyError as db_error: await session.rollback() - await task_logger.log_task_failure( - log_entry, - "Database error during file indexing", - str(db_error), - {"error_type": "SQLAlchemyError"}, - ) + await task_logger.log_task_failure(log_entry, "Database error during file indexing", str(db_error), {"error_type": "SQLAlchemyError"}) logger.error(f"Database error: {db_error!s}", exc_info=True) return 0, f"Database error: {db_error!s}" except Exception as e: await session.rollback() - await task_logger.log_task_failure( - log_entry, - "Failed to index Google Drive file", - str(e), - {"error_type": type(e).__name__}, - ) + await task_logger.log_task_failure(log_entry, "Failed to index Google Drive file", str(e), {"error_type": type(e).__name__}) logger.error(f"Failed to index Google Drive file: {e!s}", exc_info=True) return 0, f"Failed to index Google Drive file: {e!s}" - - -async def _index_full_scan( - drive_client: GoogleDriveClient, - session: AsyncSession, - connector: any, - connector_id: int, - search_space_id: int, - user_id: str, - folder_id: str | None, - folder_name: str, - task_logger: TaskLoggingService, - log_entry: any, - max_files: int, - include_subfolders: bool = False, - on_heartbeat_callback: HeartbeatCallbackType | None = None, - enable_summary: bool = True, -) -> tuple[int, int]: - """Perform full scan indexing of a folder. - - Implements 2-phase document status updates for real-time UI feedback: - - Phase 1: Collect all files and create pending documents (visible in UI immediately) - - Phase 2: Process each file: pending → processing → ready/failed - """ - await task_logger.log_task_progress( - log_entry, - f"Starting full scan of folder: {folder_name} (include_subfolders={include_subfolders})", - { - "stage": "full_scan", - "folder_id": folder_id, - "include_subfolders": include_subfolders, - }, - ) - - documents_indexed = 0 - documents_skipped = 0 - documents_failed = 0 - files_processed = 0 - - # Heartbeat tracking - update notification periodically to prevent appearing stuck - last_heartbeat_time = time.time() - - # ======================================================================= - # PHASE 1: Collect all files and create pending documents - # This makes ALL documents visible in the UI immediately with pending status - # ======================================================================= - files_to_process = [] # List of (file, pending_document or None) - new_documents_created = False - - # Queue of folders to process: (folder_id, folder_name) - folders_to_process = [(folder_id, folder_name)] - first_listing_error: str | None = None - - logger.info("Phase 1: Collecting files and creating pending documents") - - while folders_to_process and files_processed < max_files: - current_folder_id, current_folder_name = folders_to_process.pop(0) - logger.info(f"Scanning folder: {current_folder_name} ({current_folder_id})") - page_token = None - - while files_processed < max_files: - # Get files and folders in current folder - files, next_token, error = await get_files_in_folder( - drive_client, - current_folder_id, - include_subfolders=True, - page_token=page_token, - ) - - if error: - logger.error(f"Error listing files in {current_folder_name}: {error}") - if first_listing_error is None: - first_listing_error = error - break - - if not files: - break - - for file in files: - if files_processed >= max_files: - break - - mime_type = file.get("mimeType", "") - - # If this is a folder and include_subfolders is enabled, queue it for processing - if mime_type == "application/vnd.google-apps.folder": - if include_subfolders: - folders_to_process.append( - (file["id"], file.get("name", "Unknown")) - ) - logger.debug(f"Queued subfolder: {file.get('name', 'Unknown')}") - continue - - files_processed += 1 - - # Create pending document for this file - pending_doc, should_skip = await _create_pending_document_for_file( - session=session, - file=file, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - - if should_skip: - documents_skipped += 1 - continue - - if pending_doc and pending_doc.id is None: - # New document was created - new_documents_created = True - - files_to_process.append((file, pending_doc)) - - page_token = next_token - if not page_token: - break - - if not files_to_process and first_listing_error: - error_lower = first_listing_error.lower() - if ( - "401" in first_listing_error - or "invalid credentials" in error_lower - or "authError" in first_listing_error - ): - raise Exception( - f"Google Drive authentication failed. Please re-authenticate. " - f"(Error: {first_listing_error})" - ) - raise Exception(f"Failed to list Google Drive files: {first_listing_error}") - - # Commit all pending documents - they all appear in UI now - if new_documents_created: - logger.info( - f"Phase 1: Committing {len([f for f in files_to_process if f[1] and f[1].id is None])} pending documents" - ) - await session.commit() - - # ======================================================================= - # PHASE 2: Process each file one by one - # Each document transitions: pending → processing → ready/failed - # ======================================================================= - logger.info(f"Phase 2: Processing {len(files_to_process)} files") - - for file, pending_doc in files_to_process: - # Check if it's time for a heartbeat update - if on_heartbeat_callback: - current_time = time.time() - if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS: - await on_heartbeat_callback(documents_indexed) - last_heartbeat_time = current_time - - indexed, skipped, failed = await _process_single_file( - drive_client=drive_client, - session=session, - file=file, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - task_logger=task_logger, - log_entry=log_entry, - pending_document=pending_doc, - enable_summary=enable_summary, - ) - - documents_indexed += indexed - documents_skipped += skipped - documents_failed += failed - - if documents_indexed % 10 == 0 and documents_indexed > 0: - await session.commit() - logger.info(f"Committed batch: {documents_indexed} files indexed so far") - - logger.info( - f"Full scan complete: {documents_indexed} indexed, {documents_skipped} skipped, {documents_failed} failed" - ) - return documents_indexed, documents_skipped - - -async def _index_with_delta_sync( - drive_client: GoogleDriveClient, - session: AsyncSession, - connector: any, - connector_id: int, - search_space_id: int, - user_id: str, - folder_id: str | None, - start_page_token: str, - task_logger: TaskLoggingService, - log_entry: any, - max_files: int, - include_subfolders: bool = False, - on_heartbeat_callback: HeartbeatCallbackType | None = None, - enable_summary: bool = True, -) -> tuple[int, int]: - """Perform delta sync indexing using change tracking. - - Note: include_subfolders is accepted for API consistency but delta sync - automatically tracks changes across all folders including subfolders. - - Implements 2-phase document status updates for real-time UI feedback: - - Phase 1: Collect all changes and create pending documents (visible in UI immediately) - - Phase 2: Process each file: pending → processing → ready/failed - """ - await task_logger.log_task_progress( - log_entry, - f"Starting delta sync from token: {start_page_token[:20]}...", - {"stage": "delta_sync", "start_token": start_page_token}, - ) - - changes, _final_token, error = await fetch_all_changes( - drive_client, start_page_token, folder_id - ) - - if error: - logger.error(f"Error fetching changes: {error}") - error_lower = error.lower() - if ( - "401" in error - or "invalid credentials" in error_lower - or "authError" in error - ): - raise Exception( - f"Google Drive authentication failed. Please re-authenticate. " - f"(Error: {error})" - ) - raise Exception(f"Failed to fetch Google Drive changes: {error}") - - if not changes: - logger.info("No changes detected since last sync") - return 0, 0 - - logger.info(f"Processing {len(changes)} changes") - - documents_indexed = 0 - documents_skipped = 0 - documents_failed = 0 - files_processed = 0 - - # Heartbeat tracking - update notification periodically to prevent appearing stuck - last_heartbeat_time = time.time() - - # ======================================================================= - # PHASE 1: Analyze changes and create pending documents for new/modified files - # ======================================================================= - changes_to_process = [] # List of (change, file, pending_document or None) - new_documents_created = False - - logger.info("Phase 1: Analyzing changes and creating pending documents") - - for change in changes: - if files_processed >= max_files: - break - - files_processed += 1 - change_type = categorize_change(change) - - if change_type in ["removed", "trashed"]: - file_id = change.get("fileId") - if file_id: - await _remove_document(session, file_id, search_space_id) - continue - - file = change.get("file") - if not file: - continue - - # Create pending document for this file - pending_doc, should_skip = await _create_pending_document_for_file( - session=session, - file=file, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - ) - - if should_skip: - documents_skipped += 1 - continue - - if pending_doc and pending_doc.id is None: - # New document was created - new_documents_created = True - - changes_to_process.append((change, file, pending_doc)) - - # Commit all pending documents - they all appear in UI now - if new_documents_created: - logger.info("Phase 1: Committing pending documents") - await session.commit() - - # ======================================================================= - # PHASE 2: Process each file one by one - # Each document transitions: pending → processing → ready/failed - # ======================================================================= - logger.info(f"Phase 2: Processing {len(changes_to_process)} changes") - - for _, file, pending_doc in changes_to_process: - # Check if it's time for a heartbeat update - if on_heartbeat_callback: - current_time = time.time() - if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS: - await on_heartbeat_callback(documents_indexed) - last_heartbeat_time = current_time - - indexed, skipped, failed = await _process_single_file( - drive_client=drive_client, - session=session, - file=file, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - task_logger=task_logger, - log_entry=log_entry, - pending_document=pending_doc, - enable_summary=enable_summary, - ) - - documents_indexed += indexed - documents_skipped += skipped - documents_failed += failed - - if documents_indexed % 10 == 0 and documents_indexed > 0: - await session.commit() - logger.info(f"Committed batch: {documents_indexed} changes processed") - - logger.info( - f"Delta sync complete: {documents_indexed} indexed, {documents_skipped} skipped, {documents_failed} failed" - ) - return documents_indexed, documents_skipped - - -async def _create_pending_document_for_file( - session: AsyncSession, - file: dict, - connector_id: int, - search_space_id: int, - user_id: str, -) -> tuple[Document | None, bool]: - """ - Create a pending document for a Google Drive file if it doesn't exist. - - This is Phase 1 of the 2-phase document status update pattern. - Creates documents with 'pending' status so they appear in UI immediately. - - Args: - session: Database session - file: File metadata from Google Drive API - connector_id: ID of the Drive connector - search_space_id: ID of the search space - user_id: ID of the user - - Returns: - Tuple of (document, should_skip): - - (existing_doc, False): Existing document that needs update - - (new_pending_doc, False): New pending document created - - (None, True): File should be skipped (unchanged, rename-only, or folder) - """ - from app.connectors.google_drive.file_types import should_skip_file - - file_id = file.get("id") - file_name = file.get("name", "Unknown") - mime_type = file.get("mimeType", "") - - # Skip folders and shortcuts - if should_skip_file(mime_type): - return None, True - - if not file_id: - return None, True - - # Generate unique identifier hash for this file - unique_identifier_hash = generate_unique_identifier_hash( - DocumentType.GOOGLE_DRIVE_FILE, file_id, search_space_id - ) - - # Check if document exists (primary hash first, then legacy Composio hash) - existing_document = await check_document_by_unique_identifier( - session, unique_identifier_hash - ) - if not existing_document: - legacy_hash = generate_unique_identifier_hash( - DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, file_id, search_space_id - ) - existing_document = await check_document_by_unique_identifier( - session, legacy_hash - ) - if existing_document: - existing_document.unique_identifier_hash = unique_identifier_hash - if ( - existing_document.document_type - == DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ): - existing_document.document_type = DocumentType.GOOGLE_DRIVE_FILE - logger.info(f"Migrated legacy Composio document to native type: {file_id}") - - if existing_document: - # Check if this is a rename-only update (content unchanged) - incoming_md5 = file.get("md5Checksum") - incoming_modified_time = file.get("modifiedTime") - doc_metadata = existing_document.document_metadata or {} - stored_md5 = doc_metadata.get("md5_checksum") - stored_modified_time = doc_metadata.get("modified_time") - - # Determine if content changed - content_unchanged = False - if incoming_md5 and stored_md5: - content_unchanged = incoming_md5 == stored_md5 - elif not incoming_md5 and incoming_modified_time and stored_modified_time: - # Google Workspace file - use modifiedTime as fallback - content_unchanged = incoming_modified_time == stored_modified_time - - if content_unchanged: - # Ensure status is ready (might have been stuck in processing/pending) - if not DocumentStatus.is_state( - existing_document.status, DocumentStatus.READY - ): - existing_document.status = DocumentStatus.ready() - return None, True - - # Content changed - return existing document for update - return existing_document, False - - # Create new pending document - document = Document( - search_space_id=search_space_id, - title=file_name, - document_type=DocumentType.GOOGLE_DRIVE_FILE, - document_metadata={ - "google_drive_file_id": file_id, - "google_drive_file_name": file_name, - "google_drive_mime_type": mime_type, - "connector_id": connector_id, - }, - content="Pending...", # Placeholder until processed - content_hash=unique_identifier_hash, # Temporary unique value - updated when ready - unique_identifier_hash=unique_identifier_hash, - embedding=None, - chunks=[], # Empty at creation - status=DocumentStatus.pending(), # Pending until processing starts - updated_at=get_current_timestamp(), - created_by_id=user_id, - connector_id=connector_id, - ) - session.add(document) - - return document, False - - -async def _check_rename_only_update( - session: AsyncSession, - file: dict, - search_space_id: int, -) -> tuple[bool, str | None]: - """ - Check if a file only needs a rename update (no content change). - - Uses md5Checksum comparison (preferred) or modifiedTime (fallback for Google Workspace files) - to detect if content has changed. This optimization prevents unnecessary ETL API calls - (Docling/LlamaCloud) for rename-only operations. - - Args: - session: Database session - file: File metadata from Google Drive API - search_space_id: ID of the search space - - Returns: - Tuple of (is_rename_only, message) - - (True, message): Only filename changed, document was updated - - (False, None): Content changed or new file, needs full processing - """ - from sqlalchemy import String, cast, select - from sqlalchemy.orm.attributes import flag_modified - - from app.db import Document - - file_id = file.get("id") - file_name = file.get("name", "Unknown") - incoming_md5 = file.get("md5Checksum") # None for Google Workspace files - incoming_modified_time = file.get("modifiedTime") - - if not file_id: - return False, None - - # Try to find existing document by file_id-based hash (primary method) - primary_hash = generate_unique_identifier_hash( - DocumentType.GOOGLE_DRIVE_FILE, file_id, search_space_id - ) - existing_document = await check_document_by_unique_identifier(session, primary_hash) - - # Fallback: legacy Composio hash - if not existing_document: - legacy_hash = generate_unique_identifier_hash( - DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, file_id, search_space_id - ) - existing_document = await check_document_by_unique_identifier( - session, legacy_hash - ) - - # Fallback: metadata search (covers old filename-based hashes) - if not existing_document: - result = await session.execute( - select(Document).where( - Document.search_space_id == search_space_id, - Document.document_type.in_( - [ - DocumentType.GOOGLE_DRIVE_FILE, - DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, - ] - ), - cast(Document.document_metadata["google_drive_file_id"], String) - == file_id, - ) - ) - existing_document = result.scalar_one_or_none() - if existing_document: - logger.debug(f"Found legacy document by metadata for file_id: {file_id}") - - # Migrate legacy Composio document to native type - if existing_document: - if existing_document.unique_identifier_hash != primary_hash: - existing_document.unique_identifier_hash = primary_hash - if ( - existing_document.document_type - == DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR - ): - existing_document.document_type = DocumentType.GOOGLE_DRIVE_FILE - logger.info(f"Migrated legacy Composio Drive document: {file_id}") - - if not existing_document: - # New file, needs full processing - return False, None - - # Get stored checksums/timestamps from document metadata - doc_metadata = existing_document.document_metadata or {} - stored_md5 = doc_metadata.get("md5_checksum") - stored_modified_time = doc_metadata.get("modified_time") - - # Determine if content changed using md5Checksum (preferred) or modifiedTime (fallback) - content_unchanged = False - - if incoming_md5 and stored_md5: - # Best case: Compare md5 checksums (only changes when content changes, not on rename) - content_unchanged = incoming_md5 == stored_md5 - logger.debug(f"MD5 comparison for {file_name}: unchanged={content_unchanged}") - elif incoming_md5 and not stored_md5: - # Have incoming md5 but no stored md5 (legacy doc) - need to reprocess to store it - logger.debug( - f"No stored md5 for {file_name}, will reprocess to store md5_checksum" - ) - return False, None - elif not incoming_md5: - # Google Workspace file (no md5Checksum available) - fall back to modifiedTime - # Note: modifiedTime is less reliable as it changes on rename too, but it's the best we have - if incoming_modified_time and stored_modified_time: - content_unchanged = incoming_modified_time == stored_modified_time - logger.debug( - f"ModifiedTime fallback for Google Workspace file {file_name}: unchanged={content_unchanged}" - ) - else: - # No stored modifiedTime (legacy) - reprocess to store it - return False, None - - if content_unchanged: - # Content hasn't changed - check if filename changed - old_name = doc_metadata.get("FILE_NAME") or doc_metadata.get( - "google_drive_file_name" - ) - - if old_name and old_name != file_name: - # Rename-only update - update the document without re-processing - existing_document.title = file_name - if not existing_document.document_metadata: - existing_document.document_metadata = {} - existing_document.document_metadata["FILE_NAME"] = file_name - existing_document.document_metadata["google_drive_file_name"] = file_name - # Also update modified_time for Google Workspace files (since it changed on rename) - if incoming_modified_time: - existing_document.document_metadata["modified_time"] = ( - incoming_modified_time - ) - flag_modified(existing_document, "document_metadata") - await session.commit() - - logger.info( - f"Rename-only update: '{old_name}' → '{file_name}' (skipped ETL)" - ) - return ( - True, - f"File renamed: '{old_name}' → '{file_name}' (no content change)", - ) - else: - # Neither content nor name changed - logger.debug(f"File unchanged: {file_name}") - return True, "File unchanged (same content and name)" - - # Content changed - needs full processing - return False, None - - -async def _process_single_file( - drive_client: GoogleDriveClient, - session: AsyncSession, - file: dict, - connector_id: int, - search_space_id: int, - user_id: str, - task_logger: TaskLoggingService, - log_entry: any, - pending_document: Document | None = None, - enable_summary: bool = True, -) -> tuple[int, int, int]: - """ - Process a single file by downloading and using Surfsense's file processor. - - Implements Phase 2 of the 2-phase document status update pattern. - Updates document status: pending → processing → ready/failed - - Args: - drive_client: Google Drive client - session: Database session - file: File metadata from Google Drive API - connector_id: ID of the connector - search_space_id: ID of the search space - user_id: ID of the user - task_logger: Task logging service - log_entry: Log entry for tracking - pending_document: Optional pending document created in Phase 1 - - Returns: - Tuple of (indexed_count, skipped_count, failed_count) - """ - file_name = file.get("name", "Unknown") - mime_type = file.get("mimeType", "") - file_id = file.get("id") - - try: - logger.info(f"Processing file: {file_name} ({mime_type})") - - # Early check: Is this a rename-only update? - # This optimization prevents downloading and ETL processing for files - # where only the name changed but content is the same. - is_rename_only, rename_message = await _check_rename_only_update( - session=session, - file=file, - search_space_id=search_space_id, - ) - - if is_rename_only: - await task_logger.log_task_progress( - log_entry, - f"Skipped ETL for {file_name}: {rename_message}", - {"status": "rename_only", "reason": rename_message}, - ) - # Return 1 for renamed files (they are "indexed" in the sense that they're updated) - # Return 0 for unchanged files - if "renamed" in (rename_message or "").lower(): - return 1, 0, 0 - return 0, 1, 0 - - # Set document to PROCESSING status if we have a pending document - if pending_document: - pending_document.status = DocumentStatus.processing() - await session.commit() - - _, error, _metadata = await download_and_process_file( - client=drive_client, - file=file, - search_space_id=search_space_id, - user_id=user_id, - session=session, - task_logger=task_logger, - log_entry=log_entry, - connector_id=connector_id, - enable_summary=enable_summary, - ) - - if error: - await task_logger.log_task_progress( - log_entry, - f"Skipped {file_name}: {error}", - {"status": "skipped", "reason": error}, - ) - # Mark pending document as failed if it exists - if pending_document: - pending_document.status = DocumentStatus.failed(error) - pending_document.updated_at = get_current_timestamp() - await session.commit() - return 0, 1, 0 - - # The document was created/updated by download_and_process_file - # Find the document and ensure it has READY status - if file_id: - unique_identifier_hash = generate_unique_identifier_hash( - DocumentType.GOOGLE_DRIVE_FILE, file_id, search_space_id - ) - processed_doc = await check_document_by_unique_identifier( - session, unique_identifier_hash - ) - # Ensure status is READY - if processed_doc and not DocumentStatus.is_state( - processed_doc.status, DocumentStatus.READY - ): - processed_doc.status = DocumentStatus.ready() - processed_doc.updated_at = get_current_timestamp() - await session.commit() - - logger.info(f"Successfully indexed Google Drive file: {file_name}") - return 1, 0, 0 - - except Exception as e: - logger.error(f"Error processing file {file_name}: {e!s}", exc_info=True) - # Mark pending document as failed if it exists - if pending_document: - try: - pending_document.status = DocumentStatus.failed(str(e)) - pending_document.updated_at = get_current_timestamp() - await session.commit() - except Exception as status_error: - logger.error( - f"Failed to update document status to failed: {status_error}" - ) - return 0, 0, 1 - - -async def _remove_document(session: AsyncSession, file_id: str, search_space_id: int): - """Remove a document that was deleted in Drive. - - Handles both new (file_id-based) and legacy (filename-based) hash schemes. - """ - from sqlalchemy import String, cast, select - - from app.db import Document - - # First try with file_id-based hash (new method) - unique_identifier_hash = generate_unique_identifier_hash( - DocumentType.GOOGLE_DRIVE_FILE, file_id, search_space_id - ) - - existing_document = await check_document_by_unique_identifier( - session, unique_identifier_hash - ) - - # Fallback: legacy Composio hash - if not existing_document: - legacy_hash = generate_unique_identifier_hash( - DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, file_id, search_space_id - ) - existing_document = await check_document_by_unique_identifier( - session, legacy_hash - ) - - # Fallback: metadata search (covers old filename-based hashes, both native and Composio) - if not existing_document: - result = await session.execute( - select(Document).where( - Document.search_space_id == search_space_id, - Document.document_type.in_( - [ - DocumentType.GOOGLE_DRIVE_FILE, - DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, - ] - ), - cast(Document.document_metadata["google_drive_file_id"], String) - == file_id, - ) - ) - existing_document = result.scalar_one_or_none() - if existing_document: - logger.info(f"Found legacy document by metadata for file_id: {file_id}") - - if existing_document: - await session.delete(existing_document) - logger.info(f"Removed deleted file document: {file_id}") diff --git a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py index 384ad85e2..96cc1cbb4 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py @@ -1,11 +1,11 @@ """ Google Gmail connector indexer. -Implements 2-phase document status updates for real-time UI feedback: -- Phase 1: Create all documents with 'pending' status (visible in UI immediately) -- Phase 2: Process each document: pending → processing → ready/failed +Uses the shared IndexingPipelineService for document deduplication, +summarization, chunking, and embedding. """ +import logging import time from collections.abc import Awaitable, Callable from datetime import datetime @@ -15,21 +15,15 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.google_gmail_connector import GoogleGmailConnector -from app.db import ( - Document, - DocumentStatus, - DocumentType, - SearchSourceConnectorType, +from app.db import DocumentType, SearchSourceConnectorType +from app.indexing_pipeline.connector_document import ConnectorDocument +from app.indexing_pipeline.document_hashing import ( + compute_content_hash, + compute_unique_identifier_hash, ) +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.utils.document_converters import ( - create_document_chunks, - embed_text, - generate_content_hash, - generate_document_summary, - generate_unique_identifier_hash, -) from app.utils.google_credentials import ( COMPOSIO_GOOGLE_CONNECTOR_TYPES, build_composio_credentials, @@ -37,12 +31,9 @@ from app.utils.google_credentials import ( from .base import ( calculate_date_range, - check_document_by_unique_identifier, check_duplicate_document_by_hash, get_connector_by_id, - get_current_timestamp, logger, - safe_set_chunks, update_connector_last_indexed, ) @@ -51,13 +42,70 @@ ACCEPTED_GMAIL_CONNECTOR_TYPES = { SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, } -# Type hint for heartbeat callback HeartbeatCallbackType = Callable[[int], Awaitable[None]] - -# Heartbeat interval in seconds HEARTBEAT_INTERVAL_SECONDS = 30 +def _build_connector_doc( + message: dict, + markdown_content: str, + *, + connector_id: int, + search_space_id: int, + user_id: str, + enable_summary: bool, +) -> ConnectorDocument: + """Map a raw Gmail API message dict to a ConnectorDocument.""" + message_id = message.get("id", "") + thread_id = message.get("threadId", "") + payload = message.get("payload", {}) + headers = payload.get("headers", []) + + subject = "No Subject" + sender = "Unknown Sender" + date_str = "Unknown Date" + + for header in headers: + name = header.get("name", "").lower() + value = header.get("value", "") + if name == "subject": + subject = value + elif name == "from": + sender = value + elif name == "date": + date_str = value + + metadata = { + "message_id": message_id, + "thread_id": thread_id, + "subject": subject, + "sender": sender, + "date": date_str, + "connector_id": connector_id, + "document_type": "Gmail Message", + "connector_type": "Google Gmail", + } + + fallback_summary = ( + f"Google Gmail Message: {subject}\n\n" + f"From: {sender}\nDate: {date_str}\n\n" + f"{markdown_content}" + ) + + return ConnectorDocument( + title=subject, + source_markdown=markdown_content, + unique_id=message_id, + document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR, + search_space_id=search_space_id, + connector_id=connector_id, + created_by_id=user_id, + should_summarize=enable_summary, + fallback_summary=fallback_summary, + metadata=metadata, + ) + + async def index_google_gmail_messages( session: AsyncSession, connector_id: int, @@ -80,7 +128,7 @@ async def index_google_gmail_messages( start_date: Start date for filtering messages (YYYY-MM-DD format) end_date: End date for filtering messages (YYYY-MM-DD format) update_last_indexed: Whether to update the last_indexed_at timestamp (default: True) - max_messages: Maximum number of messages to fetch (default: 100) + max_messages: Maximum number of messages to fetch (default: 1000) on_heartbeat_callback: Optional callback to update notification during long-running indexing. Returns: @@ -88,7 +136,6 @@ async def index_google_gmail_messages( """ task_logger = TaskLoggingService(session, search_space_id) - # Log task start log_entry = await task_logger.log_task_start( task_name="google_gmail_messages_indexing", source="connector_indexing_task", @@ -103,7 +150,7 @@ async def index_google_gmail_messages( ) try: - # Accept both native and Composio Gmail connectors + # ── Connector lookup ────────────────────────────────────────── connector = None for ct in ACCEPTED_GMAIL_CONNECTOR_TYPES: connector = await get_connector_by_id(session, connector_id, ct) @@ -117,7 +164,7 @@ async def index_google_gmail_messages( ) return 0, 0, error_msg - # Build credentials based on connector type + # ── Credential building ─────────────────────────────────────── if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: connected_account_id = connector.config.get("composio_connected_account_id") if not connected_account_id: @@ -189,6 +236,7 @@ async def index_google_gmail_messages( ) return 0, 0, "Google gmail credentials not found in connector config" + # ── Gmail client init ───────────────────────────────────────── await task_logger.log_task_progress( log_entry, f"Initializing Google gmail client for connector {connector_id}", @@ -199,14 +247,11 @@ async def index_google_gmail_messages( credentials, session, user_id, connector_id ) - # Calculate date range using last_indexed_at if dates not provided - # This ensures Gmail uses the same date logic as other connectors - # (uses last_indexed_at → now, or 365 days back for first-time indexing) calculated_start_date, calculated_end_date = calculate_date_range( connector, start_date, end_date, default_days_back=365 ) - # Fetch recent Google gmail messages + # ── Fetch messages ──────────────────────────────────────────── logger.info( f"Fetching emails for connector {connector_id} " f"from {calculated_start_date} to {calculated_end_date}" @@ -218,7 +263,6 @@ async def index_google_gmail_messages( ) if error: - # Check if this is an authentication error that requires re-authentication error_message = error error_type = "APIError" if ( @@ -243,263 +287,92 @@ async def index_google_gmail_messages( logger.info(f"Found {len(messages)} Google gmail messages to index") - documents_indexed = 0 + # ── Build ConnectorDocuments ────────────────────────────────── + connector_docs: list[ConnectorDocument] = [] documents_skipped = 0 - documents_failed = 0 # Track messages that failed processing - duplicate_content_count = ( - 0 # Track messages skipped due to duplicate content_hash - ) - - # Heartbeat tracking - update notification periodically to prevent appearing stuck - last_heartbeat_time = time.time() - - # ======================================================================= - # PHASE 1: Analyze all messages, create pending documents - # This makes ALL documents visible in the UI immediately with pending status - # ======================================================================= - messages_to_process = [] # List of dicts with document and message data - new_documents_created = False + duplicate_content_count = 0 for message in messages: try: - # Extract message information message_id = message.get("id", "") - thread_id = message.get("threadId", "") - - # Extract headers for subject and sender - payload = message.get("payload", {}) - headers = payload.get("headers", []) - - subject = "No Subject" - sender = "Unknown Sender" - date_str = "Unknown Date" - - for header in headers: - name = header.get("name", "").lower() - value = header.get("value", "") - if name == "subject": - subject = value - elif name == "from": - sender = value - elif name == "date": - date_str = value - if not message_id: - logger.warning(f"Skipping message with missing ID: {subject}") + logger.warning("Skipping message with missing ID") documents_skipped += 1 continue - # Format message to markdown markdown_content = gmail_connector.format_message_to_markdown(message) - if not markdown_content.strip(): - logger.warning(f"Skipping message with no content: {subject}") + logger.warning(f"Skipping message with no content: {message_id}") documents_skipped += 1 continue - # Generate unique identifier hash for this Gmail message - unique_identifier_hash = generate_unique_identifier_hash( - DocumentType.GOOGLE_GMAIL_CONNECTOR, message_id, search_space_id + doc = _build_connector_doc( + message, + markdown_content, + connector_id=connector_id, + search_space_id=search_space_id, + user_id=user_id, + enable_summary=connector.enable_summary, ) - # Generate content hash - content_hash = generate_content_hash(markdown_content, search_space_id) - - # Check if document with this unique identifier already exists - existing_document = await check_document_by_unique_identifier( - session, unique_identifier_hash - ) - - # Fallback: legacy Composio hash - if not existing_document: - legacy_hash = generate_unique_identifier_hash( - DocumentType.COMPOSIO_GMAIL_CONNECTOR, - message_id, - search_space_id, - ) - existing_document = await check_document_by_unique_identifier( - session, legacy_hash - ) - if existing_document: - existing_document.unique_identifier_hash = ( - unique_identifier_hash - ) - if ( - existing_document.document_type - == DocumentType.COMPOSIO_GMAIL_CONNECTOR - ): - existing_document.document_type = ( - DocumentType.GOOGLE_GMAIL_CONNECTOR - ) - logger.info( - f"Migrated legacy Composio Gmail document: {message_id}" - ) - - if existing_document: - # Document exists - check if content has changed - if existing_document.content_hash == content_hash: - # Ensure status is ready (might have been stuck in processing/pending) - if not DocumentStatus.is_state( - existing_document.status, DocumentStatus.READY - ): - existing_document.status = DocumentStatus.ready() - documents_skipped += 1 - continue - - # Queue existing document for update (will be set to processing in Phase 2) - messages_to_process.append( - { - "document": existing_document, - "is_new": False, - "markdown_content": markdown_content, - "content_hash": content_hash, - "message_id": message_id, - "thread_id": thread_id, - "subject": subject, - "sender": sender, - "date_str": date_str, - } - ) - continue - - # Document doesn't exist by unique_identifier_hash - # Check if a document with the same content_hash exists (from another connector) with session.no_autoflush: - duplicate_by_content = await check_duplicate_document_by_hash( - session, content_hash + duplicate = await check_duplicate_document_by_hash( + session, compute_content_hash(doc) ) - - if duplicate_by_content: + if duplicate: logger.info( - f"Gmail message {subject} already indexed by another connector " - f"(existing document ID: {duplicate_by_content.id}, " - f"type: {duplicate_by_content.document_type}). Skipping." + f"Gmail message {doc.title} already indexed by another connector " + f"(existing document ID: {duplicate.id}, " + f"type: {duplicate.document_type}). Skipping." ) duplicate_content_count += 1 documents_skipped += 1 continue - # Create new document with PENDING status (visible in UI immediately) - document = Document( - search_space_id=search_space_id, - title=subject, - document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR, - document_metadata={ - "message_id": message_id, - "thread_id": thread_id, - "subject": subject, - "sender": sender, - "date": date_str, - "connector_id": connector_id, - }, - content="Pending...", # Placeholder until processed - content_hash=unique_identifier_hash, # Temporary unique value - updated when ready - unique_identifier_hash=unique_identifier_hash, - embedding=None, - chunks=[], # Empty at creation - safe for async - status=DocumentStatus.pending(), # Pending until processing starts - updated_at=get_current_timestamp(), - created_by_id=user_id, - connector_id=connector_id, - ) - session.add(document) - new_documents_created = True - - messages_to_process.append( - { - "document": document, - "is_new": True, - "markdown_content": markdown_content, - "content_hash": content_hash, - "message_id": message_id, - "thread_id": thread_id, - "subject": subject, - "sender": sender, - "date_str": date_str, - } - ) + connector_docs.append(doc) except Exception as e: - logger.error(f"Error in Phase 1 for message: {e!s}", exc_info=True) - documents_failed += 1 + logger.error(f"Error building ConnectorDocument for message: {e!s}", exc_info=True) + documents_skipped += 1 continue - # Commit all pending documents - they all appear in UI now - if new_documents_created: - logger.info( - f"Phase 1: Committing {len([m for m in messages_to_process if m['is_new']])} pending documents" - ) - await session.commit() + # ── Pipeline: migrate legacy docs + prepare + index ─────────── + pipeline = IndexingPipelineService(session) - # ======================================================================= - # PHASE 2: Process each document one by one - # Each document transitions: pending → processing → ready/failed - # ======================================================================= - logger.info(f"Phase 2: Processing {len(messages_to_process)} documents") + await pipeline.migrate_legacy_docs(connector_docs) - for item in messages_to_process: - # Send heartbeat periodically + documents = await pipeline.prepare_for_indexing(connector_docs) + + doc_map = { + compute_unique_identifier_hash(cd): cd for cd in connector_docs + } + + documents_indexed = 0 + documents_failed = 0 + last_heartbeat_time = time.time() + + for document in documents: if on_heartbeat_callback: current_time = time.time() if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS: await on_heartbeat_callback(documents_indexed) last_heartbeat_time = current_time - document = item["document"] - try: - # Set to PROCESSING and commit - shows "processing" in UI for THIS document only - document.status = DocumentStatus.processing() - await session.commit() + connector_doc = doc_map.get(document.unique_identifier_hash) + if connector_doc is None: + logger.warning( + f"No matching ConnectorDocument for document {document.id}, skipping" + ) + documents_failed += 1 + continue - # Heavy processing (LLM, embeddings, chunks) + try: user_llm = await get_user_long_context_llm( session, user_id, search_space_id ) - - if user_llm and connector.enable_summary: - document_metadata_for_summary = { - "message_id": item["message_id"], - "thread_id": item["thread_id"], - "subject": item["subject"], - "sender": item["sender"], - "date": item["date_str"], - "document_type": "Gmail Message", - "connector_type": "Google Gmail", - } - ( - summary_content, - summary_embedding, - ) = await generate_document_summary( - item["markdown_content"], - user_llm, - document_metadata_for_summary, - ) - else: - summary_content = f"Google Gmail Message: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}\n\n{item['markdown_content']}" - summary_embedding = embed_text(summary_content) - - chunks = await create_document_chunks(item["markdown_content"]) - - # Update document to READY with actual content - document.title = item["subject"] - document.content = summary_content - document.content_hash = item["content_hash"] - document.embedding = summary_embedding - document.document_metadata = { - "message_id": item["message_id"], - "thread_id": item["thread_id"], - "subject": item["subject"], - "sender": item["sender"], - "date": item["date_str"], - "connector_id": connector_id, - } - await safe_set_chunks(session, document, chunks) - document.updated_at = get_current_timestamp() - document.status = DocumentStatus.ready() - + await pipeline.index(document, connector_doc, user_llm) documents_indexed += 1 - # Batch commit every 10 documents (for ready status updates) if documents_indexed % 10 == 0: logger.info( f"Committing batch: {documents_indexed} Gmail messages processed so far" @@ -508,21 +381,12 @@ async def index_google_gmail_messages( except Exception as e: logger.error(f"Error processing Gmail message: {e!s}", exc_info=True) - # Mark document as failed with reason (visible in UI) - try: - document.status = DocumentStatus.failed(str(e)) - document.updated_at = get_current_timestamp() - except Exception as status_error: - logger.error( - f"Failed to update document status to failed: {status_error}" - ) documents_failed += 1 continue - # CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs + # ── Finalize ────────────────────────────────────────────────── await update_connector_last_indexed(session, connector, update_last_indexed) - # Final commit for any remaining documents not yet committed in batches logger.info(f"Final commit: Total {documents_indexed} Gmail messages processed") try: await session.commit() @@ -530,22 +394,18 @@ async def index_google_gmail_messages( "Successfully committed all Google Gmail document changes to database" ) except Exception as e: - # Handle any remaining integrity errors gracefully (race conditions, etc.) if ( "duplicate key value violates unique constraint" in str(e).lower() or "uniqueviolationerror" in str(e).lower() ): logger.warning( f"Duplicate content_hash detected during final commit. " - f"This may occur if the same message was indexed by multiple connectors. " f"Rolling back and continuing. Error: {e!s}" ) await session.rollback() - # Don't fail the entire task - some documents may have been successfully indexed else: raise - # Build warning message if there were issues warning_parts = [] if duplicate_content_count > 0: warning_parts.append(f"{duplicate_content_count} duplicate") @@ -555,7 +415,6 @@ async def index_google_gmail_messages( total_processed = documents_indexed - # Log success await task_logger.log_task_success( log_entry, f"Successfully completed Google Gmail indexing for connector {connector_id}", diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_document_hashing.py b/surfsense_backend/tests/unit/indexing_pipeline/test_document_hashing.py index fe536b066..d04d8b048 100644 --- a/surfsense_backend/tests/unit/indexing_pipeline/test_document_hashing.py +++ b/surfsense_backend/tests/unit/indexing_pipeline/test_document_hashing.py @@ -3,6 +3,7 @@ import pytest from app.db import DocumentType from app.indexing_pipeline.document_hashing import ( compute_content_hash, + compute_identifier_hash, compute_unique_identifier_hash, ) @@ -61,3 +62,23 @@ def test_different_content_produces_different_content_hash(make_connector_docume doc_a = make_connector_document(source_markdown="Original content") doc_b = make_connector_document(source_markdown="Updated content") assert compute_content_hash(doc_a) != compute_content_hash(doc_b) + + +def test_compute_identifier_hash_matches_connector_doc_hash(make_connector_document): + """Raw-args hash equals ConnectorDocument hash for equivalent inputs.""" + doc = make_connector_document( + document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR, + unique_id="msg-123", + search_space_id=5, + ) + raw_hash = compute_identifier_hash("GOOGLE_GMAIL_CONNECTOR", "msg-123", 5) + assert raw_hash == compute_unique_identifier_hash(doc) + + +def test_compute_identifier_hash_differs_for_different_inputs(): + """Different arguments produce different hashes.""" + h1 = compute_identifier_hash("GOOGLE_DRIVE_FILE", "file-1", 1) + h2 = compute_identifier_hash("GOOGLE_DRIVE_FILE", "file-2", 1) + h3 = compute_identifier_hash("GOOGLE_DRIVE_FILE", "file-1", 2) + h4 = compute_identifier_hash("COMPOSIO_GOOGLE_DRIVE_CONNECTOR", "file-1", 1) + assert len({h1, h2, h3, h4}) == 4 From 8c41fd91bafc01347e4ab35da3cd79c9c4e7b104 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 25 Mar 2026 18:34:02 +0530 Subject: [PATCH 02/31] feat: add integration tests for indexing pipeline components - Introduced integration tests for Calendar, Drive, and Gmail indexers to ensure proper document creation and migration. - Added tests for batch indexing functionality to validate the processing of multiple documents. - Implemented tests for legacy document migration to verify updates to document types and hashes. - Enhanced test coverage for the IndexingPipelineService to ensure robust functionality across various document types. --- .../test_calendar_pipeline.py | 111 +++++++++++++++ .../indexing_pipeline/test_drive_pipeline.py | 110 +++++++++++++++ .../indexing_pipeline/test_gmail_pipeline.py | 116 ++++++++++++++++ .../indexing_pipeline/test_index_batch.py | 55 ++++++++ .../test_migrate_legacy_docs.py | 92 +++++++++++++ .../indexing_pipeline/test_index_batch.py | 82 +++++++++++ .../test_migrate_legacy_docs.py | 127 ++++++++++++++++++ 7 files changed, 693 insertions(+) create mode 100644 surfsense_backend/tests/integration/indexing_pipeline/test_calendar_pipeline.py create mode 100644 surfsense_backend/tests/integration/indexing_pipeline/test_drive_pipeline.py create mode 100644 surfsense_backend/tests/integration/indexing_pipeline/test_gmail_pipeline.py create mode 100644 surfsense_backend/tests/integration/indexing_pipeline/test_index_batch.py create mode 100644 surfsense_backend/tests/integration/indexing_pipeline/test_migrate_legacy_docs.py create mode 100644 surfsense_backend/tests/unit/indexing_pipeline/test_index_batch.py create mode 100644 surfsense_backend/tests/unit/indexing_pipeline/test_migrate_legacy_docs.py diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_calendar_pipeline.py b/surfsense_backend/tests/integration/indexing_pipeline/test_calendar_pipeline.py new file mode 100644 index 000000000..6a60c5cc1 --- /dev/null +++ b/surfsense_backend/tests/integration/indexing_pipeline/test_calendar_pipeline.py @@ -0,0 +1,111 @@ +"""Integration tests: Calendar indexer builds ConnectorDocuments that flow through the pipeline.""" + +import pytest +from sqlalchemy import select + +from app.config import config as app_config +from app.db import Document, DocumentStatus, DocumentType +from app.indexing_pipeline.connector_document import ConnectorDocument +from app.indexing_pipeline.document_hashing import compute_identifier_hash +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService + +_EMBEDDING_DIM = app_config.embedding_model_instance.dimension + +pytestmark = pytest.mark.integration + + +def _cal_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument: + return ConnectorDocument( + title=f"Event {unique_id}", + source_markdown=f"## Calendar Event\n\nDetails for {unique_id}", + unique_id=unique_id, + document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR, + search_space_id=search_space_id, + connector_id=connector_id, + created_by_id=user_id, + should_summarize=True, + fallback_summary=f"Calendar: Event {unique_id}", + metadata={ + "event_id": unique_id, + "start_time": "2025-01-15T10:00:00", + "end_time": "2025-01-15T11:00:00", + "document_type": "Google Calendar Event", + }, + ) + + +@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text") +async def test_calendar_pipeline_creates_ready_document( + db_session, db_search_space, db_connector, db_user, mocker +): + """A Calendar ConnectorDocument flows through prepare + index to a READY document.""" + space_id = db_search_space.id + doc = _cal_doc( + unique_id="evt-1", + search_space_id=space_id, + connector_id=db_connector.id, + user_id=str(db_user.id), + ) + + service = IndexingPipelineService(session=db_session) + prepared = await service.prepare_for_indexing([doc]) + assert len(prepared) == 1 + + await service.index(prepared[0], doc, llm=mocker.Mock()) + + result = await db_session.execute( + select(Document).filter(Document.search_space_id == space_id) + ) + row = result.scalars().first() + + assert row is not None + assert row.document_type == DocumentType.GOOGLE_CALENDAR_CONNECTOR + assert DocumentStatus.is_state(row.status, DocumentStatus.READY) + + +@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text") +async def test_calendar_legacy_doc_migrated( + db_session, db_search_space, db_connector, db_user, mocker +): + """A legacy Composio Calendar doc is migrated and reused.""" + space_id = db_search_space.id + user_id = str(db_user.id) + evt_id = "evt-legacy-cal" + + legacy_hash = compute_identifier_hash( + DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR.value, evt_id, space_id + ) + legacy_doc = Document( + title="Old Calendar Event", + document_type=DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, + content="old summary", + content_hash=f"ch-{legacy_hash[:12]}", + unique_identifier_hash=legacy_hash, + source_markdown="## Old event", + search_space_id=space_id, + created_by_id=user_id, + embedding=[0.1] * _EMBEDDING_DIM, + status={"state": "ready"}, + ) + db_session.add(legacy_doc) + await db_session.flush() + original_id = legacy_doc.id + + connector_doc = _cal_doc( + unique_id=evt_id, + search_space_id=space_id, + connector_id=db_connector.id, + user_id=user_id, + ) + + service = IndexingPipelineService(session=db_session) + await service.migrate_legacy_docs([connector_doc]) + + result = await db_session.execute(select(Document).filter(Document.id == original_id)) + row = result.scalars().first() + + assert row.document_type == DocumentType.GOOGLE_CALENDAR_CONNECTOR + native_hash = compute_identifier_hash( + DocumentType.GOOGLE_CALENDAR_CONNECTOR.value, evt_id, space_id + ) + assert row.unique_identifier_hash == native_hash diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_drive_pipeline.py b/surfsense_backend/tests/integration/indexing_pipeline/test_drive_pipeline.py new file mode 100644 index 000000000..32af0b8c1 --- /dev/null +++ b/surfsense_backend/tests/integration/indexing_pipeline/test_drive_pipeline.py @@ -0,0 +1,110 @@ +"""Integration tests: Drive indexer builds ConnectorDocuments that flow through the pipeline.""" + +import pytest +from sqlalchemy import select + +from app.config import config as app_config +from app.db import Document, DocumentStatus, DocumentType +from app.indexing_pipeline.connector_document import ConnectorDocument +from app.indexing_pipeline.document_hashing import compute_identifier_hash +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService + +_EMBEDDING_DIM = app_config.embedding_model_instance.dimension + +pytestmark = pytest.mark.integration + + +def _drive_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument: + return ConnectorDocument( + title=f"File {unique_id}.pdf", + source_markdown=f"## Document Content\n\nText from file {unique_id}", + unique_id=unique_id, + document_type=DocumentType.GOOGLE_DRIVE_FILE, + search_space_id=search_space_id, + connector_id=connector_id, + created_by_id=user_id, + should_summarize=True, + fallback_summary=f"File: {unique_id}.pdf", + metadata={ + "google_drive_file_id": unique_id, + "google_drive_file_name": f"{unique_id}.pdf", + "document_type": "Google Drive File", + }, + ) + + +@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text") +async def test_drive_pipeline_creates_ready_document( + db_session, db_search_space, db_connector, db_user, mocker +): + """A Drive ConnectorDocument flows through prepare + index to a READY document.""" + space_id = db_search_space.id + doc = _drive_doc( + unique_id="file-abc", + search_space_id=space_id, + connector_id=db_connector.id, + user_id=str(db_user.id), + ) + + service = IndexingPipelineService(session=db_session) + prepared = await service.prepare_for_indexing([doc]) + assert len(prepared) == 1 + + await service.index(prepared[0], doc, llm=mocker.Mock()) + + result = await db_session.execute( + select(Document).filter(Document.search_space_id == space_id) + ) + row = result.scalars().first() + + assert row is not None + assert row.document_type == DocumentType.GOOGLE_DRIVE_FILE + assert DocumentStatus.is_state(row.status, DocumentStatus.READY) + + +@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text") +async def test_drive_legacy_doc_migrated( + db_session, db_search_space, db_connector, db_user, mocker +): + """A legacy Composio Drive doc is migrated and reused.""" + space_id = db_search_space.id + user_id = str(db_user.id) + file_id = "file-legacy-drive" + + legacy_hash = compute_identifier_hash( + DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR.value, file_id, space_id + ) + legacy_doc = Document( + title="Old Drive File", + document_type=DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, + content="old file summary", + content_hash=f"ch-{legacy_hash[:12]}", + unique_identifier_hash=legacy_hash, + source_markdown="## Old file content", + search_space_id=space_id, + created_by_id=user_id, + embedding=[0.1] * _EMBEDDING_DIM, + status={"state": "ready"}, + ) + db_session.add(legacy_doc) + await db_session.flush() + original_id = legacy_doc.id + + connector_doc = _drive_doc( + unique_id=file_id, + search_space_id=space_id, + connector_id=db_connector.id, + user_id=user_id, + ) + + service = IndexingPipelineService(session=db_session) + await service.migrate_legacy_docs([connector_doc]) + + result = await db_session.execute(select(Document).filter(Document.id == original_id)) + row = result.scalars().first() + + assert row.document_type == DocumentType.GOOGLE_DRIVE_FILE + native_hash = compute_identifier_hash( + DocumentType.GOOGLE_DRIVE_FILE.value, file_id, space_id + ) + assert row.unique_identifier_hash == native_hash diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_gmail_pipeline.py b/surfsense_backend/tests/integration/indexing_pipeline/test_gmail_pipeline.py new file mode 100644 index 000000000..d67420cb7 --- /dev/null +++ b/surfsense_backend/tests/integration/indexing_pipeline/test_gmail_pipeline.py @@ -0,0 +1,116 @@ +"""Integration tests: Gmail indexer builds ConnectorDocuments that flow through the pipeline.""" + +import pytest +from sqlalchemy import select + +from app.config import config as app_config +from app.db import Document, DocumentStatus, DocumentType +from app.indexing_pipeline.connector_document import ConnectorDocument +from app.indexing_pipeline.document_hashing import ( + compute_identifier_hash, + compute_unique_identifier_hash, +) +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService + +_EMBEDDING_DIM = app_config.embedding_model_instance.dimension + +pytestmark = pytest.mark.integration + + +def _gmail_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument: + """Build a Gmail-style ConnectorDocument like the real indexer does.""" + return ConnectorDocument( + title=f"Subject for {unique_id}", + source_markdown=f"## Email\n\nBody of {unique_id}", + unique_id=unique_id, + document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR, + search_space_id=search_space_id, + connector_id=connector_id, + created_by_id=user_id, + should_summarize=True, + fallback_summary=f"Gmail: Subject for {unique_id}", + metadata={ + "message_id": unique_id, + "from": "sender@example.com", + "document_type": "Gmail Message", + }, + ) + + +@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text") +async def test_gmail_pipeline_creates_ready_document( + db_session, db_search_space, db_connector, db_user, mocker +): + """A Gmail ConnectorDocument flows through prepare + index to a READY document.""" + space_id = db_search_space.id + doc = _gmail_doc( + unique_id="msg-pipeline-1", + search_space_id=space_id, + connector_id=db_connector.id, + user_id=str(db_user.id), + ) + + service = IndexingPipelineService(session=db_session) + prepared = await service.prepare_for_indexing([doc]) + assert len(prepared) == 1 + + await service.index(prepared[0], doc, llm=mocker.Mock()) + + result = await db_session.execute( + select(Document).filter(Document.search_space_id == space_id) + ) + row = result.scalars().first() + + assert row is not None + assert row.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR + assert DocumentStatus.is_state(row.status, DocumentStatus.READY) + assert row.source_markdown == doc.source_markdown + + +@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text") +async def test_gmail_legacy_doc_migrated_then_reused( + db_session, db_search_space, db_connector, db_user, mocker +): + """A legacy Composio Gmail doc is migrated then reused by the pipeline.""" + space_id = db_search_space.id + user_id = str(db_user.id) + msg_id = "msg-legacy-gmail" + + legacy_hash = compute_identifier_hash( + DocumentType.COMPOSIO_GMAIL_CONNECTOR.value, msg_id, space_id + ) + legacy_doc = Document( + title="Old Gmail", + document_type=DocumentType.COMPOSIO_GMAIL_CONNECTOR, + content="old summary", + content_hash=f"ch-{legacy_hash[:12]}", + unique_identifier_hash=legacy_hash, + source_markdown="## Old content", + search_space_id=space_id, + created_by_id=user_id, + embedding=[0.1] * _EMBEDDING_DIM, + status={"state": "ready"}, + ) + db_session.add(legacy_doc) + await db_session.flush() + original_id = legacy_doc.id + + connector_doc = _gmail_doc( + unique_id=msg_id, + search_space_id=space_id, + connector_id=db_connector.id, + user_id=user_id, + ) + + service = IndexingPipelineService(session=db_session) + await service.migrate_legacy_docs([connector_doc]) + + prepared = await service.prepare_for_indexing([connector_doc]) + assert len(prepared) == 1 + assert prepared[0].id == original_id + assert prepared[0].document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR + + native_hash = compute_identifier_hash( + DocumentType.GOOGLE_GMAIL_CONNECTOR.value, msg_id, space_id + ) + assert prepared[0].unique_identifier_hash == native_hash diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_index_batch.py b/surfsense_backend/tests/integration/indexing_pipeline/test_index_batch.py new file mode 100644 index 000000000..a40498769 --- /dev/null +++ b/surfsense_backend/tests/integration/indexing_pipeline/test_index_batch.py @@ -0,0 +1,55 @@ +"""Integration tests for IndexingPipelineService.index_batch().""" + +import pytest +from sqlalchemy import select + +from app.db import Document, DocumentStatus, DocumentType +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService + +pytestmark = pytest.mark.integration + + +@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text") +async def test_index_batch_creates_ready_documents( + db_session, db_search_space, make_connector_document, mocker +): + """index_batch prepares and indexes a batch, resulting in READY documents.""" + space_id = db_search_space.id + docs = [ + make_connector_document( + document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR, + unique_id="msg-batch-1", + search_space_id=space_id, + source_markdown="## Email 1\n\nBody", + ), + make_connector_document( + document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR, + unique_id="msg-batch-2", + search_space_id=space_id, + source_markdown="## Email 2\n\nDifferent body", + ), + ] + + service = IndexingPipelineService(session=db_session) + results = await service.index_batch(docs, llm=mocker.Mock()) + + assert len(results) == 2 + + result = await db_session.execute( + select(Document).filter(Document.search_space_id == space_id) + ) + rows = result.scalars().all() + assert len(rows) == 2 + + for row in rows: + assert DocumentStatus.is_state(row.status, DocumentStatus.READY) + assert row.content is not None + assert row.embedding is not None + + +@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text") +async def test_index_batch_empty_returns_empty(db_session, mocker): + """index_batch with empty input returns an empty list.""" + service = IndexingPipelineService(session=db_session) + results = await service.index_batch([], llm=mocker.Mock()) + assert results == [] diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_migrate_legacy_docs.py b/surfsense_backend/tests/integration/indexing_pipeline/test_migrate_legacy_docs.py new file mode 100644 index 000000000..8fc0e7586 --- /dev/null +++ b/surfsense_backend/tests/integration/indexing_pipeline/test_migrate_legacy_docs.py @@ -0,0 +1,92 @@ +"""Integration tests for IndexingPipelineService.migrate_legacy_docs().""" + +import pytest +from sqlalchemy import select + +from app.config import config as app_config +from app.db import Document, DocumentType +from app.indexing_pipeline.document_hashing import compute_identifier_hash +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService + +_EMBEDDING_DIM = app_config.embedding_model_instance.dimension + +pytestmark = pytest.mark.integration + + +async def test_legacy_composio_gmail_doc_migrated_in_db( + db_session, db_search_space, db_user, make_connector_document +): + """A Composio Gmail doc in the DB gets its hash and type updated to native.""" + space_id = db_search_space.id + user_id = str(db_user.id) + unique_id = "msg-legacy-123" + + legacy_hash = compute_identifier_hash( + DocumentType.COMPOSIO_GMAIL_CONNECTOR.value, unique_id, space_id + ) + native_hash = compute_identifier_hash( + DocumentType.GOOGLE_GMAIL_CONNECTOR.value, unique_id, space_id + ) + + legacy_doc = Document( + title="Old Gmail", + document_type=DocumentType.COMPOSIO_GMAIL_CONNECTOR, + content="legacy content", + content_hash=f"ch-{legacy_hash[:12]}", + unique_identifier_hash=legacy_hash, + search_space_id=space_id, + created_by_id=user_id, + embedding=[0.1] * _EMBEDDING_DIM, + status={"state": "ready"}, + ) + db_session.add(legacy_doc) + await db_session.flush() + doc_id = legacy_doc.id + + connector_doc = make_connector_document( + document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR, + unique_id=unique_id, + search_space_id=space_id, + ) + + service = IndexingPipelineService(session=db_session) + await service.migrate_legacy_docs([connector_doc]) + + result = await db_session.execute(select(Document).filter(Document.id == doc_id)) + reloaded = result.scalars().first() + + assert reloaded.unique_identifier_hash == native_hash + assert reloaded.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR + + +async def test_no_legacy_doc_is_noop( + db_session, db_search_space, make_connector_document +): + """When no legacy document exists, migrate_legacy_docs does nothing.""" + connector_doc = make_connector_document( + document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR, + unique_id="evt-no-legacy", + search_space_id=db_search_space.id, + ) + + service = IndexingPipelineService(session=db_session) + await service.migrate_legacy_docs([connector_doc]) + + result = await db_session.execute( + select(Document).filter(Document.search_space_id == db_search_space.id) + ) + assert result.scalars().all() == [] + + +async def test_non_google_type_is_skipped( + db_session, db_search_space, make_connector_document +): + """migrate_legacy_docs skips ConnectorDocuments that are not Google types.""" + connector_doc = make_connector_document( + document_type=DocumentType.CLICKUP_CONNECTOR, + unique_id="task-1", + search_space_id=db_search_space.id, + ) + + service = IndexingPipelineService(session=db_session) + await service.migrate_legacy_docs([connector_doc]) diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch.py b/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch.py new file mode 100644 index 000000000..dcd097d20 --- /dev/null +++ b/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch.py @@ -0,0 +1,82 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.db import Document, DocumentType +from app.indexing_pipeline.document_hashing import compute_unique_identifier_hash +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService + +pytestmark = pytest.mark.unit + + +@pytest.fixture +def mock_session(): + return AsyncMock() + + +@pytest.fixture +def pipeline(mock_session): + return IndexingPipelineService(mock_session) + + +async def test_calls_prepare_then_index_per_document( + pipeline, make_connector_document +): + """index_batch calls prepare_for_indexing, then index() for each returned doc.""" + doc1 = make_connector_document( + document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR, + unique_id="msg-1", + search_space_id=1, + ) + doc2 = make_connector_document( + document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR, + unique_id="msg-2", + search_space_id=1, + ) + + orm1 = MagicMock(spec=Document) + orm1.unique_identifier_hash = compute_unique_identifier_hash(doc1) + orm2 = MagicMock(spec=Document) + orm2.unique_identifier_hash = compute_unique_identifier_hash(doc2) + + mock_llm = MagicMock() + + pipeline.prepare_for_indexing = AsyncMock(return_value=[orm1, orm2]) + pipeline.index = AsyncMock(side_effect=lambda doc, cdoc, llm: doc) + + results = await pipeline.index_batch([doc1, doc2], mock_llm) + + pipeline.prepare_for_indexing.assert_awaited_once_with([doc1, doc2]) + assert pipeline.index.await_count == 2 + assert results == [orm1, orm2] + + +async def test_empty_input_returns_empty(pipeline): + """Empty connector_docs list returns empty result.""" + pipeline.prepare_for_indexing = AsyncMock(return_value=[]) + + results = await pipeline.index_batch([], MagicMock()) + + assert results == [] + + +async def test_skips_document_without_matching_connector_doc( + pipeline, make_connector_document +): + """If prepare returns a doc whose hash has no matching ConnectorDocument, it's skipped.""" + doc1 = make_connector_document( + document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR, + unique_id="msg-1", + search_space_id=1, + ) + + orphan_orm = MagicMock(spec=Document) + orphan_orm.unique_identifier_hash = "nonexistent-hash" + + pipeline.prepare_for_indexing = AsyncMock(return_value=[orphan_orm]) + pipeline.index = AsyncMock() + + results = await pipeline.index_batch([doc1], MagicMock()) + + pipeline.index.assert_not_awaited() + assert results == [] diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_migrate_legacy_docs.py b/surfsense_backend/tests/unit/indexing_pipeline/test_migrate_legacy_docs.py new file mode 100644 index 000000000..9334fe678 --- /dev/null +++ b/surfsense_backend/tests/unit/indexing_pipeline/test_migrate_legacy_docs.py @@ -0,0 +1,127 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.db import Document, DocumentType +from app.indexing_pipeline.document_hashing import compute_identifier_hash +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService + +pytestmark = pytest.mark.unit + + +@pytest.fixture +def mock_session(): + session = AsyncMock() + return session + + +@pytest.fixture +def pipeline(mock_session): + return IndexingPipelineService(mock_session) + + +def _make_execute_side_effect(doc_by_hash: dict): + """Return a side_effect for session.execute that resolves documents by hash.""" + + async def _side_effect(stmt): + result = MagicMock() + for h, doc in doc_by_hash.items(): + if h in str(stmt.compile(compile_kwargs={"literal_binds": True})): + result.scalars.return_value.first.return_value = doc + return result + result.scalars.return_value.first.return_value = None + return result + + return _side_effect + + +async def test_updates_hash_and_type_for_legacy_document( + pipeline, mock_session, make_connector_document +): + """Legacy Composio document gets unique_identifier_hash and document_type updated.""" + doc = make_connector_document( + document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR, + unique_id="msg-abc", + search_space_id=1, + ) + + legacy_hash = compute_identifier_hash("COMPOSIO_GMAIL_CONNECTOR", "msg-abc", 1) + native_hash = compute_identifier_hash("GOOGLE_GMAIL_CONNECTOR", "msg-abc", 1) + + existing = MagicMock(spec=Document) + existing.unique_identifier_hash = legacy_hash + existing.document_type = DocumentType.COMPOSIO_GMAIL_CONNECTOR + + result_mock = MagicMock() + result_mock.scalars.return_value.first.return_value = existing + mock_session.execute = AsyncMock(return_value=result_mock) + + await pipeline.migrate_legacy_docs([doc]) + + assert existing.unique_identifier_hash == native_hash + assert existing.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR + mock_session.commit.assert_awaited_once() + + +async def test_noop_when_no_legacy_document_exists( + pipeline, mock_session, make_connector_document +): + """No updates when no legacy Composio document is found in DB.""" + doc = make_connector_document( + document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR, + unique_id="msg-xyz", + search_space_id=1, + ) + + result_mock = MagicMock() + result_mock.scalars.return_value.first.return_value = None + mock_session.execute = AsyncMock(return_value=result_mock) + + await pipeline.migrate_legacy_docs([doc]) + + mock_session.commit.assert_awaited_once() + + +async def test_skips_non_google_doc_types( + pipeline, mock_session, make_connector_document +): + """Non-Google doc types have no legacy mapping and trigger no DB query.""" + doc = make_connector_document( + document_type=DocumentType.SLACK_CONNECTOR, + unique_id="slack-123", + search_space_id=1, + ) + + await pipeline.migrate_legacy_docs([doc]) + + mock_session.execute.assert_not_awaited() + mock_session.commit.assert_awaited_once() + + +async def test_handles_all_three_google_types( + pipeline, mock_session, make_connector_document +): + """Each native Google type correctly maps to its Composio legacy type.""" + mappings = [ + (DocumentType.GOOGLE_GMAIL_CONNECTOR, "COMPOSIO_GMAIL_CONNECTOR"), + (DocumentType.GOOGLE_CALENDAR_CONNECTOR, "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"), + (DocumentType.GOOGLE_DRIVE_FILE, "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"), + ] + for native_type, expected_legacy in mappings: + doc = make_connector_document( + document_type=native_type, + unique_id="id-1", + search_space_id=1, + ) + + existing = MagicMock(spec=Document) + existing.document_type = DocumentType(expected_legacy) + + result_mock = MagicMock() + result_mock.scalars.return_value.first.return_value = existing + mock_session.execute = AsyncMock(return_value=result_mock) + mock_session.commit = AsyncMock() + + await pipeline.migrate_legacy_docs([doc]) + + assert existing.document_type == native_type From c3d5c865fdb18b6ea5c039a2e160a5e4cbdbd64f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 25 Mar 2026 18:51:40 +0530 Subject: [PATCH 03/31] fix: update file skipping logic in Google Drive indexer - Modified the `_should_skip_file` function to prevent skipping of documents with a FAILED status, ensuring they are reprocessed even if their content remains unchanged. - Added a new integration test to verify that FAILED documents are not skipped during the indexing process. --- .../google_drive_indexer.py | 2 +- .../indexing_pipeline/test_drive_pipeline.py | 59 +++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py index 92c074812..af9528bb7 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py @@ -149,7 +149,7 @@ async def _should_skip_file( return True, f"File renamed: '{old_name}' → '{file_name}'" if not DocumentStatus.is_state(existing.status, DocumentStatus.READY): - existing.status = DocumentStatus.ready() + return False, None return True, "unchanged" diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_drive_pipeline.py b/surfsense_backend/tests/integration/indexing_pipeline/test_drive_pipeline.py index 32af0b8c1..77128ebd9 100644 --- a/surfsense_backend/tests/integration/indexing_pipeline/test_drive_pipeline.py +++ b/surfsense_backend/tests/integration/indexing_pipeline/test_drive_pipeline.py @@ -108,3 +108,62 @@ async def test_drive_legacy_doc_migrated( DocumentType.GOOGLE_DRIVE_FILE.value, file_id, space_id ) assert row.unique_identifier_hash == native_hash + + +async def test_should_skip_file_does_not_skip_failed_document( + db_session, db_search_space, db_user, +): + """A FAILED document with unchanged md5 must NOT be skipped — it needs reprocessing.""" + import importlib + import sys + import types + + pkg = "app.tasks.connector_indexers" + stub = pkg not in sys.modules + if stub: + mod = types.ModuleType(pkg) + mod.__path__ = ["app/tasks/connector_indexers"] + mod.__package__ = pkg + sys.modules[pkg] = mod + + try: + gdm = importlib.import_module( + "app.tasks.connector_indexers.google_drive_indexer" + ) + _should_skip_file = gdm._should_skip_file + finally: + if stub: + sys.modules.pop(pkg, None) + + space_id = db_search_space.id + file_id = "file-failed-drive" + md5 = "abc123deadbeef" + + doc_hash = compute_identifier_hash( + DocumentType.GOOGLE_DRIVE_FILE.value, file_id, space_id + ) + failed_doc = Document( + title="Failed File.pdf", + document_type=DocumentType.GOOGLE_DRIVE_FILE, + content="LLM rate limit exceeded", + content_hash=f"ch-{doc_hash[:12]}", + unique_identifier_hash=doc_hash, + source_markdown="## Real content", + search_space_id=space_id, + created_by_id=str(db_user.id), + embedding=[0.1] * _EMBEDDING_DIM, + status=DocumentStatus.failed("LLM rate limit exceeded"), + document_metadata={ + "google_drive_file_id": file_id, + "google_drive_file_name": "Failed File.pdf", + "md5_checksum": md5, + }, + ) + db_session.add(failed_doc) + await db_session.flush() + + incoming_file = {"id": file_id, "name": "Failed File.pdf", "mimeType": "application/pdf", "md5Checksum": md5} + + should_skip, _msg = await _should_skip_file(db_session, incoming_file, space_id) + + assert not should_skip, "FAILED documents must not be skipped even when content is unchanged" From bbd5ee8a1979c67a4ab43b1cadca904445a4008f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 25 Mar 2026 20:35:23 +0530 Subject: [PATCH 04/31] feat: enhance Google Calendar event update functionality - Introduced helper functions `_is_date_only` and `_build_time_body` to streamline the construction of event start and end times for all-day and timed events. - Refactored the `create_update_calendar_event_tool` to utilize the new helper functions, improving code readability and maintainability. - Updated the Google Calendar sync service to ensure proper handling of calendar IDs with a default fallback to "primary". - Modified the ApprovalCard component to simplify the construction of event update arguments, enhancing clarity and reducing redundancy. --- .../tools/google_calendar/update_event.py | 34 +++++++------- .../google_calendar/kb_sync_service.py | 4 +- .../hitl-edit-panel/hitl-edit-panel.tsx | 2 +- .../tool-ui/google-calendar/update-event.tsx | 46 ++++++++++++++----- 4 files changed, 55 insertions(+), 31 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py index 4b57cf2e3..a114c84f4 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/update_event.py @@ -14,6 +14,20 @@ from app.services.google_calendar import GoogleCalendarToolMetadataService logger = logging.getLogger(__name__) +def _is_date_only(value: str) -> bool: + """Return True when *value* looks like a bare date (YYYY-MM-DD) with no time component.""" + return len(value) <= 10 and "T" not in value + + +def _build_time_body(value: str, context: dict[str, Any] | Any) -> dict[str, str]: + """Build a Google Calendar start/end body using ``date`` for all-day + events and ``dateTime`` for timed events.""" + if _is_date_only(value): + return {"date": value} + tz = context.get("timezone", "UTC") if isinstance(context, dict) else "UTC" + return {"dateTime": value, "timeZone": tz} + + def create_update_calendar_event_tool( db_session: AsyncSession | None = None, search_space_id: int | None = None, @@ -255,25 +269,13 @@ def create_update_calendar_event_tool( if final_new_summary is not None: update_body["summary"] = final_new_summary if final_new_start_datetime is not None: - tz = ( - context.get("timezone", "UTC") - if isinstance(context, dict) - else "UTC" + update_body["start"] = _build_time_body( + final_new_start_datetime, context ) - update_body["start"] = { - "dateTime": final_new_start_datetime, - "timeZone": tz, - } if final_new_end_datetime is not None: - tz = ( - context.get("timezone", "UTC") - if isinstance(context, dict) - else "UTC" + update_body["end"] = _build_time_body( + final_new_end_datetime, context ) - update_body["end"] = { - "dateTime": final_new_end_datetime, - "timeZone": tz, - } if final_new_description is not None: update_body["description"] = final_new_description if final_new_location is not None: diff --git a/surfsense_backend/app/services/google_calendar/kb_sync_service.py b/surfsense_backend/app/services/google_calendar/kb_sync_service.py index 59afa116e..3cda02b9b 100644 --- a/surfsense_backend/app/services/google_calendar/kb_sync_service.py +++ b/surfsense_backend/app/services/google_calendar/kb_sync_service.py @@ -209,8 +209,8 @@ class GoogleCalendarKBSyncService: ) calendar_id = (document.document_metadata or {}).get( - "calendar_id", "primary" - ) + "calendar_id" + ) or "primary" live_event = await loop.run_in_executor( None, lambda: ( diff --git a/surfsense_web/components/hitl-edit-panel/hitl-edit-panel.tsx b/surfsense_web/components/hitl-edit-panel/hitl-edit-panel.tsx index 25e896842..e8bc1a6cd 100644 --- a/surfsense_web/components/hitl-edit-panel/hitl-edit-panel.tsx +++ b/surfsense_web/components/hitl-edit-panel/hitl-edit-panel.tsx @@ -185,7 +185,7 @@ function DateTimePickerField({ type="time" value={time} onChange={handleTimeChange} - className="w-[120px] text-sm shrink-0 pl-1.5 [&::-webkit-calendar-picker-indicator]:order-first [&::-webkit-calendar-picker-indicator]:mr-1" + className="w-[120px] text-sm shrink-0 appearance-none [&::-webkit-calendar-picker-indicator]:hidden [&::-webkit-calendar-picker-indicator]:appearance-none" /> ); diff --git a/surfsense_web/components/tool-ui/google-calendar/update-event.tsx b/surfsense_web/components/tool-ui/google-calendar/update-event.tsx index cc941bab8..661032628 100644 --- a/surfsense_web/components/tool-ui/google-calendar/update-event.tsx +++ b/surfsense_web/components/tool-ui/google-calendar/update-event.tsx @@ -253,6 +253,12 @@ function ApprovalCard({ String(effectiveNewDescription ?? "") !== (event?.description ?? ""); const buildFinalArgs = useCallback(() => { + const base = { + event_id: event?.event_id, + document_id: event?.document_id, + connector_id: account?.id, + }; + if (pendingEdits) { const attendeesArr = pendingEdits.attendees ? pendingEdits.attendees @@ -260,22 +266,38 @@ function ApprovalCard({ .map((e) => e.trim()) .filter(Boolean) : null; + const origAttendees = event?.attendees?.map((a) => a.email) ?? []; + return { - event_id: event?.event_id, - document_id: event?.document_id, - connector_id: account?.id, - new_summary: pendingEdits.summary || null, - new_description: pendingEdits.description || null, - new_start_datetime: pendingEdits.start_datetime || null, - new_end_datetime: pendingEdits.end_datetime || null, - new_location: pendingEdits.location || null, - new_attendees: attendeesArr, + ...base, + new_summary: + pendingEdits.summary && pendingEdits.summary !== (event?.summary ?? "") + ? pendingEdits.summary + : null, + new_description: + pendingEdits.description !== (event?.description ?? "") + ? pendingEdits.description || null + : null, + new_start_datetime: + pendingEdits.start_datetime && pendingEdits.start_datetime !== (event?.start ?? "") + ? pendingEdits.start_datetime + : null, + new_end_datetime: + pendingEdits.end_datetime && pendingEdits.end_datetime !== (event?.end ?? "") + ? pendingEdits.end_datetime + : null, + new_location: + pendingEdits.location !== (event?.location ?? "") + ? pendingEdits.location || null + : null, + new_attendees: + attendeesArr && attendeesArr.join(",") !== origAttendees.join(",") + ? attendeesArr + : null, }; } return { - event_id: event?.event_id, - document_id: event?.document_id, - connector_id: account?.id, + ...base, new_summary: actionArgs.new_summary ?? null, new_description: actionArgs.new_description ?? null, new_start_datetime: actionArgs.new_start_datetime ?? null, From e5cb6bfacf8c43f14372c058e5507cc7ab537771 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 26 Mar 2026 19:33:49 +0530 Subject: [PATCH 05/31] feat: implement parallel document indexing in IndexingPipelineService - Added `index_batch_parallel` method to enable concurrent indexing of documents with bounded concurrency, improving performance and efficiency. - Refactored existing indexing logic to utilize `asyncio.to_thread` for non-blocking execution of embedding and chunking functions. - Introduced unit tests to validate the functionality of the new parallel indexing method, ensuring robustness and error handling during document processing. --- .../indexing_pipeline_service.py | 6 +- .../test_index_batch_parallel.py | 70 +++++++++++++++++++ surfsense_web/contracts/enums/toolIcons.tsx | 4 +- 3 files changed, 76 insertions(+), 4 deletions(-) create mode 100644 surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index c6a29f204..1a61e779e 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -1,3 +1,4 @@ +import asyncio import contextlib import time from datetime import UTC, datetime @@ -257,13 +258,14 @@ class IndexingPipelineService: ) t_step = time.perf_counter() - chunk_texts = chunk_text( + chunk_texts = await asyncio.to_thread( + chunk_text, connector_doc.source_markdown, use_code_chunker=connector_doc.should_use_code_chunker, ) texts_to_embed = [content, *chunk_texts] - embeddings = embed_texts(texts_to_embed) + embeddings = await asyncio.to_thread(embed_texts, texts_to_embed) summary_embedding, *chunk_embeddings = embeddings chunks = [ diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py b/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py new file mode 100644 index 000000000..7e23383ac --- /dev/null +++ b/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py @@ -0,0 +1,70 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.config import config as app_config +from app.db import Document, DocumentStatus, DocumentType +from app.indexing_pipeline.document_hashing import compute_unique_identifier_hash +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService + +_EMBEDDING_DIM = app_config.embedding_model_instance.dimension + +pytestmark = pytest.mark.unit + + +@pytest.fixture +def mock_session(): + session = AsyncMock() + session.refresh = AsyncMock() + return session + + +@pytest.fixture +def pipeline(mock_session): + return IndexingPipelineService(mock_session) + + +async def test_index_calls_embed_and_chunk_via_to_thread( + pipeline, make_connector_document, monkeypatch +): + """index() runs embed_texts and chunk_text via asyncio.to_thread, not blocking the loop.""" + to_thread_calls = [] + original_to_thread = asyncio.to_thread + + async def tracking_to_thread(func, *args, **kwargs): + to_thread_calls.append(func.__name__) + return await original_to_thread(func, *args, **kwargs) + + monkeypatch.setattr(asyncio, "to_thread", tracking_to_thread) + + monkeypatch.setattr( + "app.indexing_pipeline.indexing_pipeline_service.summarize_document", + AsyncMock(return_value="Summary."), + ) + mock_chunk = MagicMock(return_value=["chunk1"]) + mock_chunk.__name__ = "chunk_text" + monkeypatch.setattr( + "app.indexing_pipeline.indexing_pipeline_service.chunk_text", + mock_chunk, + ) + mock_embed = MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]) + mock_embed.__name__ = "embed_texts" + monkeypatch.setattr( + "app.indexing_pipeline.indexing_pipeline_service.embed_texts", + mock_embed, + ) + + connector_doc = make_connector_document( + document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR, + unique_id="msg-1", + search_space_id=1, + ) + document = MagicMock(spec=Document) + document.id = 1 + document.status = DocumentStatus.pending() + + await pipeline.index(document, connector_doc, llm=MagicMock()) + + assert "chunk_text" in to_thread_calls + assert "embed_texts" in to_thread_calls diff --git a/surfsense_web/contracts/enums/toolIcons.tsx b/surfsense_web/contracts/enums/toolIcons.tsx index 90ec7a544..6ca6550b5 100644 --- a/surfsense_web/contracts/enums/toolIcons.tsx +++ b/surfsense_web/contracts/enums/toolIcons.tsx @@ -5,10 +5,10 @@ import { FileText, Film, Globe, + ImageIcon, type LucideIcon, Podcast, ScanLine, - Sparkles, Wrench, } from "lucide-react"; @@ -17,7 +17,7 @@ const TOOL_ICONS: Record = { generate_podcast: Podcast, generate_video_presentation: Film, generate_report: FileText, - generate_image: Sparkles, + generate_image: ImageIcon, scrape_webpage: ScanLine, web_search: Globe, search_surfsense_docs: BookOpen, From 4fd776e7ef9ef6cc5786194494fe20012d1c7046 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 26 Mar 2026 19:34:04 +0530 Subject: [PATCH 06/31] feat: implement parallel indexing for Google Calendar and Gmail connectors - Refactored Google Calendar and Gmail indexers to utilize the new `index_batch_parallel` method for concurrent document indexing, enhancing performance. - Updated the indexing logic to replace serial processing with parallel execution, allowing for improved efficiency in handling multiple documents. - Adjusted logging and error handling to accommodate the new parallel processing approach, ensuring robust operation during indexing. - Enhanced unit tests to validate the functionality of the parallel indexing method and its integration with existing workflows. --- .../indexing_pipeline_service.py | 104 ++++++++++++++++ .../google_calendar_indexer.py | 58 ++------- .../google_gmail_indexer.py | 59 ++------- .../test_index_batch_parallel.py | 116 ++++++++++++++++++ 4 files changed, 242 insertions(+), 95 deletions(-) diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index 1a61e779e..bd6086892 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -1,6 +1,8 @@ import asyncio import contextlib +import logging import time +from collections.abc import Awaitable, Callable from datetime import UTC, datetime from sqlalchemy import delete, select @@ -327,3 +329,105 @@ class IndexingPipelineService: await self.session.refresh(document) return document + + async def index_batch_parallel( + self, + connector_docs: list[ConnectorDocument], + get_llm: Callable[[AsyncSession], Awaitable], + *, + max_concurrency: int = 4, + on_heartbeat: Callable[[int], Awaitable[None]] | None = None, + heartbeat_interval: float = 30.0, + ) -> tuple[list[Document], int, int]: + """Index documents in parallel with bounded concurrency. + + Phase 1 (serial): prepare_for_indexing using self.session. + Phase 2 (parallel): index each document in an isolated session, + bounded by a semaphore to avoid overwhelming APIs/DB. + """ + logger = logging.getLogger(__name__) + + doc_map = { + compute_unique_identifier_hash(cd): cd for cd in connector_docs + } + documents = await self.prepare_for_indexing(connector_docs) + + if not documents: + return [], 0, 0 + + from app.tasks.celery_tasks import get_celery_session_maker + + sem = asyncio.Semaphore(max_concurrency) + lock = asyncio.Lock() + indexed_count = 0 + failed_count = 0 + results: list[Document] = [] + last_heartbeat = time.time() + + async def _index_one(document: Document) -> Document | Exception: + nonlocal indexed_count, failed_count, last_heartbeat + + connector_doc = doc_map.get(document.unique_identifier_hash) + if connector_doc is None: + logger.warning( + "No matching ConnectorDocument for document %s, skipping", + document.id, + ) + async with lock: + failed_count += 1 + return document + + async with sem: + session_maker = get_celery_session_maker() + async with session_maker() as isolated_session: + try: + refetched = await isolated_session.get( + Document, document.id + ) + if refetched is None: + async with lock: + failed_count += 1 + return document + + llm = await get_llm(isolated_session) + iso_pipeline = IndexingPipelineService(isolated_session) + result = await iso_pipeline.index( + refetched, connector_doc, llm + ) + + async with lock: + if DocumentStatus.is_state( + result.status, DocumentStatus.READY + ): + indexed_count += 1 + else: + failed_count += 1 + + if on_heartbeat: + now = time.time() + if now - last_heartbeat >= heartbeat_interval: + await on_heartbeat(indexed_count) + last_heartbeat = now + + return result + except Exception as exc: + logger.error( + "Parallel index failed for doc %s: %s", + document.id, + exc, + exc_info=True, + ) + async with lock: + failed_count += 1 + return exc + + tasks = [_index_one(doc) for doc in documents] + outcomes = await asyncio.gather(*tasks, return_exceptions=True) + + for outcome in outcomes: + if isinstance(outcome, Document): + results.append(outcome) + elif isinstance(outcome, Exception): + pass + + return results, indexed_count, failed_count diff --git a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py index a69b33bdc..61b1ccb2b 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py @@ -5,7 +5,6 @@ Uses the shared IndexingPipelineService for document deduplication, summarization, chunking, and embedding. """ -import time from collections.abc import Awaitable, Callable from datetime import datetime, timedelta @@ -16,10 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.google_calendar_connector import GoogleCalendarConnector from app.db import DocumentType, SearchSourceConnectorType from app.indexing_pipeline.connector_document import ConnectorDocument -from app.indexing_pipeline.document_hashing import ( - compute_content_hash, - compute_unique_identifier_hash, -) +from app.indexing_pipeline.document_hashing import compute_content_hash from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService @@ -399,53 +395,21 @@ async def index_google_calendar_events( documents_skipped += 1 continue - # ── Pipeline: migrate legacy docs + prepare + index ─────────── + # ── Pipeline: migrate legacy docs + parallel index ───────────── pipeline = IndexingPipelineService(session) await pipeline.migrate_legacy_docs(connector_docs) - documents = await pipeline.prepare_for_indexing(connector_docs) + async def _get_llm(s): + return await get_user_long_context_llm(s, user_id, search_space_id) - doc_map = { - compute_unique_identifier_hash(cd): cd for cd in connector_docs - } - - documents_indexed = 0 - documents_failed = 0 - last_heartbeat_time = time.time() - - for document in documents: - if on_heartbeat_callback: - current_time = time.time() - if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS: - await on_heartbeat_callback(documents_indexed) - last_heartbeat_time = current_time - - connector_doc = doc_map.get(document.unique_identifier_hash) - if connector_doc is None: - logger.warning( - f"No matching ConnectorDocument for document {document.id}, skipping" - ) - documents_failed += 1 - continue - - try: - user_llm = await get_user_long_context_llm( - session, user_id, search_space_id - ) - await pipeline.index(document, connector_doc, user_llm) - documents_indexed += 1 - - if documents_indexed % 10 == 0: - logger.info( - f"Committing batch: {documents_indexed} Google Calendar events processed so far" - ) - await session.commit() - - except Exception as e: - logger.error(f"Error processing Calendar event: {e!s}", exc_info=True) - documents_failed += 1 - continue + _, documents_indexed, documents_failed = await pipeline.index_batch_parallel( + connector_docs, + _get_llm, + max_concurrency=3, + on_heartbeat=on_heartbeat_callback, + heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS, + ) # ── Finalize ────────────────────────────────────────────────── await update_connector_last_indexed(session, connector, update_last_indexed) diff --git a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py index 96cc1cbb4..0d77ad3cd 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py @@ -5,8 +5,6 @@ Uses the shared IndexingPipelineService for document deduplication, summarization, chunking, and embedding. """ -import logging -import time from collections.abc import Awaitable, Callable from datetime import datetime @@ -17,10 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.google_gmail_connector import GoogleGmailConnector from app.db import DocumentType, SearchSourceConnectorType from app.indexing_pipeline.connector_document import ConnectorDocument -from app.indexing_pipeline.document_hashing import ( - compute_content_hash, - compute_unique_identifier_hash, -) +from app.indexing_pipeline.document_hashing import compute_content_hash from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService @@ -336,53 +331,21 @@ async def index_google_gmail_messages( documents_skipped += 1 continue - # ── Pipeline: migrate legacy docs + prepare + index ─────────── + # ── Pipeline: migrate legacy docs + parallel index ───────────── pipeline = IndexingPipelineService(session) await pipeline.migrate_legacy_docs(connector_docs) - documents = await pipeline.prepare_for_indexing(connector_docs) + async def _get_llm(s): + return await get_user_long_context_llm(s, user_id, search_space_id) - doc_map = { - compute_unique_identifier_hash(cd): cd for cd in connector_docs - } - - documents_indexed = 0 - documents_failed = 0 - last_heartbeat_time = time.time() - - for document in documents: - if on_heartbeat_callback: - current_time = time.time() - if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS: - await on_heartbeat_callback(documents_indexed) - last_heartbeat_time = current_time - - connector_doc = doc_map.get(document.unique_identifier_hash) - if connector_doc is None: - logger.warning( - f"No matching ConnectorDocument for document {document.id}, skipping" - ) - documents_failed += 1 - continue - - try: - user_llm = await get_user_long_context_llm( - session, user_id, search_space_id - ) - await pipeline.index(document, connector_doc, user_llm) - documents_indexed += 1 - - if documents_indexed % 10 == 0: - logger.info( - f"Committing batch: {documents_indexed} Gmail messages processed so far" - ) - await session.commit() - - except Exception as e: - logger.error(f"Error processing Gmail message: {e!s}", exc_info=True) - documents_failed += 1 - continue + _, documents_indexed, documents_failed = await pipeline.index_batch_parallel( + connector_docs, + _get_llm, + max_concurrency=3, + on_heartbeat=on_heartbeat_callback, + heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS, + ) # ── Finalize ────────────────────────────────────────────────── await update_connector_last_indexed(session, connector, update_last_indexed) diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py b/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py index 7e23383ac..3148812f8 100644 --- a/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py +++ b/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py @@ -25,6 +25,15 @@ def pipeline(mock_session): return IndexingPipelineService(mock_session) +def _make_orm_doc(connector_doc, doc_id): + """Create a MagicMock Document bound to a ConnectorDocument's hash.""" + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.unique_identifier_hash = compute_unique_identifier_hash(connector_doc) + doc.status = DocumentStatus.pending() + return doc + + async def test_index_calls_embed_and_chunk_via_to_thread( pipeline, make_connector_document, monkeypatch ): @@ -68,3 +77,110 @@ async def test_index_calls_embed_and_chunk_via_to_thread( assert "chunk_text" in to_thread_calls assert "embed_texts" in to_thread_calls + + +def _mock_session_factory(orm_docs_by_id): + """Replace get_celery_session_maker with a two-level callable. + + get_celery_session_maker() -> session_maker + session_maker() -> async context manager yielding a mock session + """ + + def _get_maker(): + def _make_session(): + session = MagicMock() + session.get = AsyncMock( + side_effect=lambda model, doc_id: orm_docs_by_id.get(doc_id) + ) + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=session) + ctx.__aexit__ = AsyncMock(return_value=False) + return ctx + + return _make_session + + return _get_maker + + +async def test_batch_parallel_indexes_all_documents( + pipeline, make_connector_document, monkeypatch +): + """index_batch_parallel indexes all documents and returns correct counts.""" + docs = [ + make_connector_document( + document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR, + unique_id=f"msg-{i}", + search_space_id=1, + ) + for i in range(3) + ] + + orm_docs = [_make_orm_doc(cd, doc_id=i + 1) for i, cd in enumerate(docs)] + pipeline.prepare_for_indexing = AsyncMock(return_value=orm_docs) + + orm_by_id = {d.id: d for d in orm_docs} + monkeypatch.setattr( + "app.tasks.celery_tasks.get_celery_session_maker", + _mock_session_factory(orm_by_id), + ) + + index_calls = [] + + async def fake_index(self, document, connector_doc, llm): + index_calls.append(document.id) + document.status = DocumentStatus.ready() + return document + + monkeypatch.setattr(IndexingPipelineService, "index", fake_index) + + async def mock_get_llm(session): + return MagicMock() + + _, indexed, failed = await pipeline.index_batch_parallel( + docs, mock_get_llm, max_concurrency=2 + ) + + assert indexed == 3 + assert failed == 0 + assert sorted(index_calls) == [1, 2, 3] + + +async def test_batch_parallel_one_failure_does_not_affect_others( + pipeline, make_connector_document, monkeypatch +): + """One document failure doesn't prevent other documents from being indexed.""" + docs = [ + make_connector_document( + document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR, + unique_id=f"msg-{i}", + search_space_id=1, + ) + for i in range(3) + ] + + orm_docs = [_make_orm_doc(cd, doc_id=i + 1) for i, cd in enumerate(docs)] + pipeline.prepare_for_indexing = AsyncMock(return_value=orm_docs) + + orm_by_id = {d.id: d for d in orm_docs} + monkeypatch.setattr( + "app.tasks.celery_tasks.get_celery_session_maker", + _mock_session_factory(orm_by_id), + ) + + async def failing_index(self, document, connector_doc, llm): + if document.id == 2: + raise RuntimeError("LLM exploded") + document.status = DocumentStatus.ready() + return document + + monkeypatch.setattr(IndexingPipelineService, "index", failing_index) + + async def mock_get_llm(session): + return MagicMock() + + _, indexed, failed = await pipeline.index_batch_parallel( + docs, mock_get_llm, max_concurrency=4 + ) + + assert indexed == 2 + assert failed == 1 From bd6e335cb3d0f227bc59e70877297b1ea8221a10 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 26 Mar 2026 23:10:49 +0530 Subject: [PATCH 07/31] feat: enhance performance logging in indexing pipeline - Added performance logging to the `index_batch_parallel` method, capturing metrics for document indexing duration and concurrency. - Introduced timing measurements for both the overall indexing process and the parallel document gathering phase, improving observability of the indexing workflow. - Updated logging statements to provide detailed insights into the number of documents processed, indexed, and failed during the indexing operation. --- .../indexing_pipeline_service.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index bd6086892..9a945dd25 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -346,6 +346,8 @@ class IndexingPipelineService: bounded by a semaphore to avoid overwhelming APIs/DB. """ logger = logging.getLogger(__name__) + perf = get_perf_logger() + t_total = time.perf_counter() doc_map = { compute_unique_identifier_hash(cd): cd for cd in connector_docs @@ -422,7 +424,17 @@ class IndexingPipelineService: return exc tasks = [_index_one(doc) for doc in documents] + t_parallel = time.perf_counter() outcomes = await asyncio.gather(*tasks, return_exceptions=True) + perf.info( + "[indexing] index_batch_parallel gather docs=%d concurrency=%d " + "indexed=%d failed=%d in %.3fs", + len(documents), + max_concurrency, + indexed_count, + failed_count, + time.perf_counter() - t_parallel, + ) for outcome in outcomes: if isinstance(outcome, Document): @@ -430,4 +442,13 @@ class IndexingPipelineService: elif isinstance(outcome, Exception): pass + perf.info( + "[indexing] index_batch_parallel TOTAL input=%d prepared=%d " + "indexed=%d failed=%d in %.3fs", + len(connector_docs), + len(documents), + indexed_count, + failed_count, + time.perf_counter() - t_total, + ) return results, indexed_count, failed_count From c0169620641e30758a0211150a5a8cb6e64c2db3 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 26 Mar 2026 23:53:26 +0530 Subject: [PATCH 08/31] feat: implement parallel file downloading and indexing in Google Drive indexer - Added `_download_files_parallel` function to enable concurrent downloading of files from Google Drive, improving efficiency in document processing. - Introduced `_download_and_index` function to handle the parallel downloading and indexing phases, streamlining the overall workflow. - Updated `_index_full_scan` and `_index_with_delta_sync` methods to utilize the new parallel downloading functionality, enhancing performance. - Added unit tests to validate the new parallel downloading and indexing logic, ensuring robustness and error handling during document processing. --- .../google_drive_indexer.py | 187 +++++-- .../tests/unit/connector_indexers/__init__.py | 0 .../tests/unit/connector_indexers/conftest.py | 34 ++ .../test_google_drive_parallel.py | 466 ++++++++++++++++++ 4 files changed, 652 insertions(+), 35 deletions(-) create mode 100644 surfsense_backend/tests/unit/connector_indexers/__init__.py create mode 100644 surfsense_backend/tests/unit/connector_indexers/conftest.py create mode 100644 surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py index af9528bb7..8ba08533f 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py @@ -5,6 +5,7 @@ checks and rename-only detection. download_and_extract_content() returns markdown which is fed into ConnectorDocument -> pipeline. """ +import asyncio import logging import time from collections.abc import Awaitable, Callable @@ -190,6 +191,68 @@ def _build_connector_doc( ) +async def _download_files_parallel( + drive_client: GoogleDriveClient, + files: list[dict], + *, + connector_id: int, + search_space_id: int, + user_id: str, + enable_summary: bool, + max_concurrency: int = 5, + on_heartbeat: HeartbeatCallbackType | None = None, +) -> tuple[list[ConnectorDocument], int]: + """Download and ETL files in parallel, returning ConnectorDocuments. + + Returns (connector_docs, download_failed_count). + """ + results: list[ConnectorDocument] = [] + sem = asyncio.Semaphore(max_concurrency) + last_heartbeat = time.time() + completed_count = 0 + hb_lock = asyncio.Lock() + + async def _download_one(file: dict) -> ConnectorDocument | None: + nonlocal last_heartbeat, completed_count + async with sem: + markdown, drive_metadata, error = await download_and_extract_content( + drive_client, file + ) + if error or not markdown: + return None + doc = _build_connector_doc( + file, + markdown, + drive_metadata, + connector_id=connector_id, + search_space_id=search_space_id, + user_id=user_id, + enable_summary=enable_summary, + ) + async with hb_lock: + completed_count += 1 + if on_heartbeat: + now = time.time() + if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS: + await on_heartbeat(completed_count) + last_heartbeat = now + return doc + + tasks = [_download_one(f) for f in files] + outcomes = await asyncio.gather(*tasks, return_exceptions=True) + + failed = 0 + for outcome in outcomes: + if isinstance(outcome, Exception): + failed += 1 + elif outcome is None: + failed += 1 + else: + results.append(outcome) + + return results, failed + + async def _process_single_file( drive_client: GoogleDriveClient, session: AsyncSession, @@ -283,6 +346,47 @@ async def _remove_document(session: AsyncSession, file_id: str, search_space_id: logger.info(f"Removed deleted file document: {file_id}") +async def _download_and_index( + drive_client: GoogleDriveClient, + session: AsyncSession, + files: list[dict], + *, + connector_id: int, + search_space_id: int, + user_id: str, + enable_summary: bool, + on_heartbeat: HeartbeatCallbackType | None = None, +) -> tuple[int, int]: + """Phase 2+3: parallel download then parallel indexing. + + Returns (batch_indexed, total_failed). + """ + connector_docs, download_failed = await _download_files_parallel( + drive_client, + files, + connector_id=connector_id, + search_space_id=search_space_id, + user_id=user_id, + enable_summary=enable_summary, + on_heartbeat=on_heartbeat, + ) + + batch_indexed = 0 + batch_failed = 0 + if connector_docs: + pipeline = IndexingPipelineService(session) + + async def _get_llm(s): + return await get_user_long_context_llm(s, user_id, search_space_id) + + _, batch_indexed, batch_failed = await pipeline.index_batch_parallel( + connector_docs, _get_llm, max_concurrency=3, + on_heartbeat=on_heartbeat, + ) + + return batch_indexed, download_failed + batch_failed + + # --------------------------------------------------------------------------- # Scan strategies # --------------------------------------------------------------------------- @@ -310,11 +414,13 @@ async def _index_full_scan( {"stage": "full_scan", "folder_id": folder_id, "include_subfolders": include_subfolders}, ) - indexed = 0 + # ------------------------------------------------------------------ + # Phase 1 (serial): collect files, run skip checks, track renames + # ------------------------------------------------------------------ + renamed_count = 0 skipped = 0 - failed = 0 files_processed = 0 - last_heartbeat = time.time() + files_to_download: list[dict] = [] folders_to_process = [(folder_id, folder_name)] first_error: str | None = None @@ -346,22 +452,15 @@ async def _index_full_scan( files_processed += 1 - if on_heartbeat_callback: - now = time.time() - if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS: - await on_heartbeat_callback(indexed) - last_heartbeat = now + skip, msg = await _should_skip_file(session, file, search_space_id) + if skip: + if msg and "renamed" in msg.lower(): + renamed_count += 1 + else: + skipped += 1 + continue - i, s, f = await _process_single_file( - drive_client, session, file, - connector_id, search_space_id, user_id, enable_summary, - ) - indexed += i - skipped += s - failed += f - - if indexed > 0 and indexed % 10 == 0: - await session.commit() + files_to_download.append(file) page_token = next_token if not page_token: @@ -375,6 +474,17 @@ async def _index_full_scan( ) raise Exception(f"Failed to list Google Drive files: {first_error}") + # ------------------------------------------------------------------ + # Phase 2+3 (parallel): download, ETL, index + # ------------------------------------------------------------------ + batch_indexed, failed = await _download_and_index( + drive_client, session, files_to_download, + connector_id=connector_id, search_space_id=search_space_id, + user_id=user_id, enable_summary=enable_summary, + on_heartbeat=on_heartbeat_callback, + ) + + indexed = renamed_count + batch_indexed logger.info(f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed") return indexed, skipped @@ -416,11 +526,14 @@ async def _index_with_delta_sync( return 0, 0 logger.info(f"Processing {len(changes)} changes") - indexed = 0 + + # ------------------------------------------------------------------ + # Phase 1 (serial): handle removals, collect files for download + # ------------------------------------------------------------------ + renamed_count = 0 skipped = 0 - failed = 0 + files_to_download: list[dict] = [] files_processed = 0 - last_heartbeat = time.time() for change in changes: if files_processed >= max_files: @@ -438,23 +551,27 @@ async def _index_with_delta_sync( if not file: continue - if on_heartbeat_callback: - now = time.time() - if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS: - await on_heartbeat_callback(indexed) - last_heartbeat = now + skip, msg = await _should_skip_file(session, file, search_space_id) + if skip: + if msg and "renamed" in msg.lower(): + renamed_count += 1 + else: + skipped += 1 + continue - i, s, f = await _process_single_file( - drive_client, session, file, - connector_id, search_space_id, user_id, enable_summary, - ) - indexed += i - skipped += s - failed += f + files_to_download.append(file) - if indexed > 0 and indexed % 10 == 0: - await session.commit() + # ------------------------------------------------------------------ + # Phase 2+3 (parallel): download, ETL, index + # ------------------------------------------------------------------ + batch_indexed, failed = await _download_and_index( + drive_client, session, files_to_download, + connector_id=connector_id, search_space_id=search_space_id, + user_id=user_id, enable_summary=enable_summary, + on_heartbeat=on_heartbeat_callback, + ) + indexed = renamed_count + batch_indexed logger.info(f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed") return indexed, skipped diff --git a/surfsense_backend/tests/unit/connector_indexers/__init__.py b/surfsense_backend/tests/unit/connector_indexers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/connector_indexers/conftest.py b/surfsense_backend/tests/unit/connector_indexers/conftest.py new file mode 100644 index 000000000..3e27eaf74 --- /dev/null +++ b/surfsense_backend/tests/unit/connector_indexers/conftest.py @@ -0,0 +1,34 @@ +"""Pre-register the connector_indexers package to bypass a circular import +in its ``__init__.py`` (airtable_indexer -> routes -> connector_indexers). + +This lets tests import individual indexer modules (e.g. +``google_drive_indexer``) without triggering the full package init. +""" + +import sys +import types +from pathlib import Path + +_BACKEND = Path(__file__).resolve().parents[3] + + +def _stub_package(dotted: str, fs_dir: Path) -> None: + if dotted not in sys.modules: + mod = types.ModuleType(dotted) + mod.__path__ = [str(fs_dir)] + mod.__package__ = dotted + sys.modules[dotted] = mod + + parts = dotted.split(".") + if len(parts) > 1: + parent_dotted = ".".join(parts[:-1]) + parent = sys.modules.get(parent_dotted) + if parent is not None: + setattr(parent, parts[-1], sys.modules[dotted]) + + +_stub_package("app.tasks", _BACKEND / "app" / "tasks") +_stub_package( + "app.tasks.connector_indexers", + _BACKEND / "app" / "tasks" / "connector_indexers", +) diff --git a/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py b/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py new file mode 100644 index 000000000..22e900406 --- /dev/null +++ b/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py @@ -0,0 +1,466 @@ +"""Tests for parallel download + indexing in the Google Drive indexer.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.tasks.connector_indexers.google_drive_indexer import ( + _download_files_parallel, + _index_full_scan, + _index_with_delta_sync, +) + +pytestmark = pytest.mark.unit + +_USER_ID = "00000000-0000-0000-0000-000000000001" +_CONNECTOR_ID = 42 +_SEARCH_SPACE_ID = 1 + + +def _make_file_dict(file_id: str, name: str, mime: str = "text/plain") -> dict: + return {"id": file_id, "name": name, "mimeType": mime} + + +def _mock_extract_ok(file_id: str, file_name: str): + """Return a successful (markdown, metadata, None) tuple.""" + return ( + f"# Content of {file_name}", + {"google_drive_file_id": file_id, "google_drive_file_name": file_name}, + None, + ) + + +@pytest.fixture +def mock_drive_client(): + return MagicMock() + + +@pytest.fixture +def patch_extract(monkeypatch): + """Provide a helper to set the download_and_extract_content mock.""" + def _patch(side_effect=None, return_value=None): + mock = AsyncMock(side_effect=side_effect, return_value=return_value) + monkeypatch.setattr( + "app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content", + mock, + ) + return mock + return _patch + + +async def test_single_file_returns_one_connector_document( + mock_drive_client, patch_extract, +): + """Tracer bullet: downloading one file produces one ConnectorDocument.""" + patch_extract(return_value=_mock_extract_ok("f1", "test.txt")) + + docs, failed = await _download_files_parallel( + mock_drive_client, + [_make_file_dict("f1", "test.txt")], + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=True, + ) + + assert len(docs) == 1 + assert failed == 0 + assert docs[0].title == "test.txt" + assert docs[0].unique_id == "f1" + + +async def test_multiple_files_all_produce_documents( + mock_drive_client, patch_extract, +): + """All files are downloaded and converted to ConnectorDocuments.""" + files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)] + patch_extract( + side_effect=[_mock_extract_ok(f"f{i}", f"file{i}.txt") for i in range(3)] + ) + + docs, failed = await _download_files_parallel( + mock_drive_client, + files, + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=True, + ) + + assert len(docs) == 3 + assert failed == 0 + assert {d.unique_id for d in docs} == {"f0", "f1", "f2"} + + +async def test_one_download_exception_does_not_block_others( + mock_drive_client, patch_extract, +): + """A RuntimeError in one download still lets the other files succeed.""" + files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)] + patch_extract( + side_effect=[ + _mock_extract_ok("f0", "file0.txt"), + RuntimeError("network timeout"), + _mock_extract_ok("f2", "file2.txt"), + ] + ) + + docs, failed = await _download_files_parallel( + mock_drive_client, + files, + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=True, + ) + + assert len(docs) == 2 + assert failed == 1 + assert {d.unique_id for d in docs} == {"f0", "f2"} + + +async def test_etl_error_counts_as_download_failure( + mock_drive_client, patch_extract, +): + """download_and_extract_content returning an error is counted as failed.""" + files = [_make_file_dict("f0", "good.txt"), _make_file_dict("f1", "bad.txt")] + patch_extract( + side_effect=[ + _mock_extract_ok("f0", "good.txt"), + (None, {}, "ETL failed"), + ] + ) + + docs, failed = await _download_files_parallel( + mock_drive_client, + files, + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=True, + ) + + assert len(docs) == 1 + assert failed == 1 + + +async def test_concurrency_bounded_by_semaphore( + mock_drive_client, monkeypatch, +): + """Peak concurrent downloads never exceeds max_concurrency.""" + lock = asyncio.Lock() + active = 0 + peak = 0 + + async def _slow_extract(client, file): + nonlocal active, peak + async with lock: + active += 1 + peak = max(peak, active) + await asyncio.sleep(0.05) + async with lock: + active -= 1 + fid = file["id"] + return _mock_extract_ok(fid, file["name"]) + + monkeypatch.setattr( + "app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content", + _slow_extract, + ) + + files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(6)] + + docs, failed = await _download_files_parallel( + mock_drive_client, + files, + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=True, + max_concurrency=2, + ) + + assert len(docs) == 6 + assert failed == 0 + assert peak <= 2, f"Peak concurrency was {peak}, expected <= 2" + + +async def test_heartbeat_fires_during_parallel_downloads( + mock_drive_client, monkeypatch, +): + """on_heartbeat is called at least once when downloads take time.""" + import app.tasks.connector_indexers.google_drive_indexer as _mod + + monkeypatch.setattr(_mod, "HEARTBEAT_INTERVAL_SECONDS", 0) + + async def _slow_extract(client, file): + await asyncio.sleep(0.05) + return _mock_extract_ok(file["id"], file["name"]) + + monkeypatch.setattr( + "app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content", + _slow_extract, + ) + + heartbeat_calls: list[int] = [] + + async def _on_heartbeat(count: int): + heartbeat_calls.append(count) + + files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)] + + docs, failed = await _download_files_parallel( + mock_drive_client, + files, + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=True, + on_heartbeat=_on_heartbeat, + ) + + assert len(docs) == 3 + assert failed == 0 + assert len(heartbeat_calls) >= 1, "Heartbeat should have fired at least once" + + +# --------------------------------------------------------------------------- +# Slice 6, 6b, 6c -- _index_full_scan three-phase pipeline +# --------------------------------------------------------------------------- + +def _folder_dict(file_id: str, name: str) -> dict: + return {"id": file_id, "name": name, "mimeType": "application/vnd.google-apps.folder"} + + +@pytest.fixture +def full_scan_mocks(mock_drive_client, monkeypatch): + """Wire up all mocks needed to call _index_full_scan in isolation.""" + import app.tasks.connector_indexers.google_drive_indexer as _mod + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_task_logger = MagicMock() + mock_task_logger.log_task_progress = AsyncMock() + mock_log_entry = MagicMock() + + skip_results: dict[str, tuple[bool, str | None]] = {} + + async def _fake_skip(session, file, search_space_id): + return skip_results.get(file["id"], (False, None)) + + monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip) + + download_mock = AsyncMock(return_value=([], 0)) + monkeypatch.setattr(_mod, "_download_files_parallel", download_mock) + + batch_mock = AsyncMock(return_value=([], 0, 0)) + pipeline_mock = MagicMock() + pipeline_mock.index_batch_parallel = batch_mock + monkeypatch.setattr( + _mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock), + ) + + monkeypatch.setattr( + _mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()), + ) + + return { + "drive_client": mock_drive_client, + "session": mock_session, + "connector": mock_connector, + "task_logger": mock_task_logger, + "log_entry": mock_log_entry, + "skip_results": skip_results, + "download_mock": download_mock, + "batch_mock": batch_mock, + "pipeline_mock": pipeline_mock, + } + + +async def _run_full_scan(mocks, *, max_files=500, include_subfolders=False): + return await _index_full_scan( + mocks["drive_client"], + mocks["session"], + mocks["connector"], + _CONNECTOR_ID, + _SEARCH_SPACE_ID, + _USER_ID, + "folder-root", + "My Folder", + mocks["task_logger"], + mocks["log_entry"], + max_files, + include_subfolders=include_subfolders, + enable_summary=True, + ) + + +async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch): + """Full scan collects files serially, downloads and indexes in parallel, + and returns correct (indexed, skipped) with renames counted as indexed.""" + import app.tasks.connector_indexers.google_drive_indexer as _mod + + page_files = [ + _folder_dict("folder1", "SubFolder"), + _make_file_dict("skip1", "unchanged.txt"), + _make_file_dict("rename1", "renamed.txt"), + _make_file_dict("new1", "new1.txt"), + _make_file_dict("new2", "new2.txt"), + ] + + monkeypatch.setattr( + _mod, "get_files_in_folder", + AsyncMock(return_value=(page_files, None, None)), + ) + + full_scan_mocks["skip_results"]["skip1"] = (True, "unchanged") + full_scan_mocks["skip_results"]["rename1"] = (True, "File renamed: 'old' → 'renamed.txt'") + + mock_docs = [MagicMock(), MagicMock()] + full_scan_mocks["download_mock"].return_value = (mock_docs, 0) + full_scan_mocks["batch_mock"].return_value = ([], 2, 0) + + indexed, skipped = await _run_full_scan(full_scan_mocks) + + assert indexed == 3 # 1 renamed + 2 from batch + assert skipped == 1 # 1 unchanged + + full_scan_mocks["download_mock"].assert_called_once() + call_files = full_scan_mocks["download_mock"].call_args[0][1] + assert len(call_files) == 2 + assert {f["id"] for f in call_files} == {"new1", "new2"} + + +async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch): + """Only max_files non-folder files are processed; the rest are ignored.""" + import app.tasks.connector_indexers.google_drive_indexer as _mod + + page_files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(10)] + + monkeypatch.setattr( + _mod, "get_files_in_folder", + AsyncMock(return_value=(page_files, None, None)), + ) + + full_scan_mocks["download_mock"].return_value = ([], 0) + full_scan_mocks["batch_mock"].return_value = ([], 0, 0) + + await _run_full_scan(full_scan_mocks, max_files=3) + + download_call_files = full_scan_mocks["download_mock"].call_args[0][1] + assert len(download_call_files) == 3 + + +async def test_full_scan_uses_max_concurrency_3_for_indexing( + full_scan_mocks, monkeypatch, +): + """index_batch_parallel is called with max_concurrency=3.""" + import app.tasks.connector_indexers.google_drive_indexer as _mod + + page_files = [_make_file_dict("f1", "file1.txt")] + monkeypatch.setattr( + _mod, "get_files_in_folder", + AsyncMock(return_value=(page_files, None, None)), + ) + + mock_docs = [MagicMock()] + full_scan_mocks["download_mock"].return_value = (mock_docs, 0) + full_scan_mocks["batch_mock"].return_value = ([], 1, 0) + + await _run_full_scan(full_scan_mocks) + + call_kwargs = full_scan_mocks["batch_mock"].call_args + assert call_kwargs[1].get("max_concurrency") == 3 or ( + len(call_kwargs[0]) > 2 and call_kwargs[0][2] == 3 + ) + + +# --------------------------------------------------------------------------- +# Slice 7 -- _index_with_delta_sync three-phase pipeline +# --------------------------------------------------------------------------- + +async def test_delta_sync_removals_serial_rest_parallel(monkeypatch): + """Removed/trashed changes call _remove_document; the rest go through + _download_files_parallel and index_batch_parallel.""" + import app.tasks.connector_indexers.google_drive_indexer as _mod + + changes = [ + {"fileId": "del1", "removed": True}, + {"fileId": "del2", "file": {"id": "del2", "trashed": True}}, + {"fileId": "trash1", "file": {"id": "trash1", "trashed": True}}, + {"fileId": "mod1", "file": _make_file_dict("mod1", "modified1.txt")}, + {"fileId": "mod2", "file": _make_file_dict("mod2", "modified2.txt")}, + ] + + monkeypatch.setattr( + _mod, "fetch_all_changes", + AsyncMock(return_value=(changes, "new-token", None)), + ) + + change_types = { + "del1": "removed", + "del2": "removed", + "trash1": "trashed", + "mod1": "modified", + "mod2": "modified", + } + monkeypatch.setattr( + _mod, "categorize_change", + lambda change: change_types[change["fileId"]], + ) + + remove_calls: list[str] = [] + + async def _fake_remove(session, file_id, search_space_id): + remove_calls.append(file_id) + + monkeypatch.setattr(_mod, "_remove_document", _fake_remove) + + monkeypatch.setattr( + _mod, "_should_skip_file", + AsyncMock(return_value=(False, None)), + ) + + mock_docs = [MagicMock(), MagicMock()] + download_mock = AsyncMock(return_value=(mock_docs, 0)) + monkeypatch.setattr(_mod, "_download_files_parallel", download_mock) + + batch_mock = AsyncMock(return_value=([], 2, 0)) + pipeline_mock = MagicMock() + pipeline_mock.index_batch_parallel = batch_mock + monkeypatch.setattr( + _mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock), + ) + monkeypatch.setattr( + _mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()), + ) + + mock_session = AsyncMock() + mock_task_logger = MagicMock() + mock_task_logger.log_task_progress = AsyncMock() + + indexed, skipped = await _index_with_delta_sync( + MagicMock(), + mock_session, + MagicMock(), + _CONNECTOR_ID, + _SEARCH_SPACE_ID, + _USER_ID, + "folder-root", + "start-token-abc", + mock_task_logger, + MagicMock(), + max_files=500, + enable_summary=True, + ) + + assert sorted(remove_calls) == ["del1", "del2", "trash1"] + + download_mock.assert_called_once() + downloaded_files = download_mock.call_args[0][1] + assert len(downloaded_files) == 2 + assert {f["id"] for f in downloaded_files} == {"mod1", "mod2"} + + assert indexed == 2 + assert skipped == 0 From 2f30e48e9070b305d10736e56306d895cfc9464f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 00:06:21 +0530 Subject: [PATCH 09/31] feat: implement async service locking in Google Drive client - Introduced an asyncio lock to the GoogleDriveClient to ensure thread-safe access to the service instance. - Refactored the get_service method to utilize the lock, preventing concurrent attempts to create the service and improving stability in multi-threaded environments. --- .../app/connectors/google_drive/client.py | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/surfsense_backend/app/connectors/google_drive/client.py b/surfsense_backend/app/connectors/google_drive/client.py index 8cba34d19..697e3e760 100644 --- a/surfsense_backend/app/connectors/google_drive/client.py +++ b/surfsense_backend/app/connectors/google_drive/client.py @@ -1,5 +1,6 @@ """Google Drive API client.""" +import asyncio import io from typing import Any @@ -35,6 +36,7 @@ class GoogleDriveClient: self.connector_id = connector_id self._credentials = credentials self.service = None + self._service_lock = asyncio.Lock() async def get_service(self): """ @@ -49,17 +51,21 @@ class GoogleDriveClient: if self.service: return self.service - try: - if self._credentials: - credentials = self._credentials - else: - credentials = await get_valid_credentials( - self.session, self.connector_id - ) - self.service = build("drive", "v3", credentials=credentials) - return self.service - except Exception as e: - raise Exception(f"Failed to create Google Drive service: {e!s}") from e + async with self._service_lock: + if self.service: + return self.service + + try: + if self._credentials: + credentials = self._credentials + else: + credentials = await get_valid_credentials( + self.session, self.connector_id + ) + self.service = build("drive", "v3", credentials=credentials) + return self.service + except Exception as e: + raise Exception(f"Failed to create Google Drive service: {e!s}") from e async def list_files( self, From 7c7f8b216c14267c76324949efd573d9fac90999 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 00:17:07 +0530 Subject: [PATCH 10/31] feat: implement batch indexing for selected Google Drive files - Introduced `index_google_drive_selected_files` function to enable indexing of multiple user-selected files in parallel, improving efficiency. - Refactored existing indexing logic to handle batch processing, including error handling for individual file failures. - Added unit tests for the new batch indexing functionality, ensuring robustness and proper error collection during the indexing process. --- .../routes/search_source_connectors_routes.py | 22 ++- .../google_drive_indexer.py | 144 ++++++++++++++++++ .../test_google_drive_parallel.py | 122 +++++++++++++++ 3 files changed, 276 insertions(+), 12 deletions(-) diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 1ffc6341f..bef2329d8 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -2329,7 +2329,7 @@ async def run_google_drive_indexing( try: from app.tasks.connector_indexers.google_drive_indexer import ( index_google_drive_files, - index_google_drive_single_file, + index_google_drive_selected_files, ) # Parse the structured data @@ -2402,25 +2402,23 @@ async def run_google_drive_indexing( exc_info=True, ) - # Index each individual file - for file in items.files: + # Index all selected files together via the parallel pipeline + if items.files: try: - indexed_count, error_message = await index_google_drive_single_file( + file_tuples = [(f.id, f.name) for f in items.files] + indexed_count, _skipped, file_errors = await index_google_drive_selected_files( session, connector_id, search_space_id, user_id, - file_id=file.id, - file_name=file.name, + files=file_tuples, ) - if error_message: - errors.append(f"File '{file.name}': {error_message}") - else: - total_indexed += indexed_count + total_indexed += indexed_count + errors.extend(file_errors) except Exception as e: - errors.append(f"File '{file.name}': {e!s}") + errors.append(f"File batch indexing: {e!s}") logger.error( - f"Error indexing file {file.name} ({file.id}): {e}", + f"Error batch indexing files: {e}", exc_info=True, ) diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py index 8ba08533f..2d3139343 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py @@ -387,6 +387,56 @@ async def _download_and_index( return batch_indexed, download_failed + batch_failed +async def _index_selected_files( + drive_client: GoogleDriveClient, + session: AsyncSession, + file_ids: list[tuple[str, str | None]], + *, + connector_id: int, + search_space_id: int, + user_id: str, + enable_summary: bool, + on_heartbeat: HeartbeatCallbackType | None = None, +) -> tuple[int, int, list[str]]: + """Index user-selected files using the parallel pipeline. + + Phase 1 (serial): fetch metadata + skip checks. + Phase 2+3 (parallel): download, ETL, index via _download_and_index. + + Returns (indexed_count, skipped_count, errors). + """ + files_to_download: list[dict] = [] + errors: list[str] = [] + renamed_count = 0 + skipped = 0 + + for file_id, file_name in file_ids: + file, error = await get_file_by_id(drive_client, file_id) + if error or not file: + display = file_name or file_id + errors.append(f"File '{display}': {error or 'File not found'}") + continue + + skip, msg = await _should_skip_file(session, file, search_space_id) + if skip: + if msg and "renamed" in msg.lower(): + renamed_count += 1 + else: + skipped += 1 + continue + + files_to_download.append(file) + + batch_indexed, failed = await _download_and_index( + drive_client, session, files_to_download, + connector_id=connector_id, search_space_id=search_space_id, + user_id=user_id, enable_summary=enable_summary, + on_heartbeat=on_heartbeat, + ) + + return renamed_count + batch_indexed, skipped, errors + + # --------------------------------------------------------------------------- # Scan strategies # --------------------------------------------------------------------------- @@ -803,3 +853,97 @@ async def index_google_drive_single_file( await task_logger.log_task_failure(log_entry, "Failed to index Google Drive file", str(e), {"error_type": type(e).__name__}) logger.error(f"Failed to index Google Drive file: {e!s}", exc_info=True) return 0, f"Failed to index Google Drive file: {e!s}" + + +async def index_google_drive_selected_files( + session: AsyncSession, + connector_id: int, + search_space_id: int, + user_id: str, + files: list[tuple[str, str | None]], + on_heartbeat_callback: HeartbeatCallbackType | None = None, +) -> tuple[int, int, list[str]]: + """Index multiple user-selected Google Drive files in parallel. + + Sets up the connector/credentials once, then delegates to + _index_selected_files for the three-phase parallel pipeline. + + Returns (indexed_count, skipped_count, errors). + """ + task_logger = TaskLoggingService(session, search_space_id) + log_entry = await task_logger.log_task_start( + task_name="google_drive_selected_files_indexing", + source="connector_indexing_task", + message=f"Starting Google Drive batch file indexing for {len(files)} files", + metadata={"connector_id": connector_id, "user_id": str(user_id), "file_count": len(files)}, + ) + + try: + connector = None + for ct in ACCEPTED_DRIVE_CONNECTOR_TYPES: + connector = await get_connector_by_id(session, connector_id, ct) + if connector: + break + if not connector: + error_msg = f"Google Drive connector with ID {connector_id} not found" + await task_logger.log_task_failure(log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}) + return 0, 0, [error_msg] + + pre_built_credentials = None + if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: + connected_account_id = connector.config.get("composio_connected_account_id") + if not connected_account_id: + error_msg = f"Composio connected_account_id not found for connector {connector_id}" + await task_logger.log_task_failure(log_entry, error_msg, "Missing Composio account", {"error_type": "MissingComposioAccount"}) + return 0, 0, [error_msg] + pre_built_credentials = build_composio_credentials(connected_account_id) + else: + token_encrypted = connector.config.get("_token_encrypted", False) + if token_encrypted and not config.SECRET_KEY: + error_msg = "SECRET_KEY not configured but credentials are marked as encrypted" + await task_logger.log_task_failure( + log_entry, error_msg, "Missing SECRET_KEY", {"error_type": "MissingSecretKey"}, + ) + return 0, 0, [error_msg] + + connector_enable_summary = getattr(connector, "enable_summary", True) + drive_client = GoogleDriveClient(session, connector_id, credentials=pre_built_credentials) + + indexed, skipped, errors = await _index_selected_files( + drive_client, session, files, + connector_id=connector_id, search_space_id=search_space_id, + user_id=user_id, enable_summary=connector_enable_summary, + on_heartbeat=on_heartbeat_callback, + ) + + await session.commit() + + if errors: + await task_logger.log_task_failure( + log_entry, + f"Batch file indexing completed with {len(errors)} error(s)", + "; ".join(errors), + {"indexed": indexed, "skipped": skipped, "error_count": len(errors)}, + ) + else: + await task_logger.log_task_success( + log_entry, + f"Successfully indexed {indexed} files ({skipped} skipped)", + {"indexed": indexed, "skipped": skipped}, + ) + + logger.info(f"Selected files indexing: {indexed} indexed, {skipped} skipped, {len(errors)} errors") + return indexed, skipped, errors + + except SQLAlchemyError as db_error: + await session.rollback() + error_msg = f"Database error: {db_error!s}" + await task_logger.log_task_failure(log_entry, error_msg, str(db_error), {"error_type": "SQLAlchemyError"}) + logger.error(error_msg, exc_info=True) + return 0, 0, [error_msg] + except Exception as e: + await session.rollback() + error_msg = f"Failed to index Google Drive files: {e!s}" + await task_logger.log_task_failure(log_entry, error_msg, str(e), {"error_type": type(e).__name__}) + logger.error(error_msg, exc_info=True) + return 0, 0, [error_msg] diff --git a/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py b/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py index 22e900406..1183efa9f 100644 --- a/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py +++ b/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py @@ -8,6 +8,7 @@ import pytest from app.tasks.connector_indexers.google_drive_indexer import ( _download_files_parallel, _index_full_scan, + _index_selected_files, _index_with_delta_sync, ) @@ -464,3 +465,124 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch): assert indexed == 2 assert skipped == 0 + + +# --------------------------------------------------------------------------- +# _index_selected_files -- parallel indexing of user-selected files +# --------------------------------------------------------------------------- + +@pytest.fixture +def selected_files_mocks(mock_drive_client, monkeypatch): + """Wire up mocks for _index_selected_files tests.""" + import app.tasks.connector_indexers.google_drive_indexer as _mod + + mock_session = AsyncMock() + + get_file_results: dict[str, tuple[dict | None, str | None]] = {} + + async def _fake_get_file(client, file_id): + return get_file_results.get(file_id, (None, f"Not configured: {file_id}")) + + monkeypatch.setattr(_mod, "get_file_by_id", _fake_get_file) + + skip_results: dict[str, tuple[bool, str | None]] = {} + + async def _fake_skip(session, file, search_space_id): + return skip_results.get(file["id"], (False, None)) + + monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip) + + download_and_index_mock = AsyncMock(return_value=(0, 0)) + monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock) + + return { + "drive_client": mock_drive_client, + "session": mock_session, + "get_file_results": get_file_results, + "skip_results": skip_results, + "download_and_index_mock": download_and_index_mock, + } + + +async def _run_selected(mocks, file_ids): + return await _index_selected_files( + mocks["drive_client"], + mocks["session"], + file_ids, + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=True, + ) + + +async def test_selected_files_single_file_indexed(selected_files_mocks): + """Tracer bullet: one file fetched, not skipped, indexed via parallel pipeline.""" + selected_files_mocks["get_file_results"]["f1"] = ( + _make_file_dict("f1", "report.pdf"), + None, + ) + selected_files_mocks["download_and_index_mock"].return_value = (1, 0) + + indexed, skipped, errors = await _run_selected( + selected_files_mocks, [("f1", "report.pdf")], + ) + + assert indexed == 1 + assert skipped == 0 + assert errors == [] + selected_files_mocks["download_and_index_mock"].assert_called_once() + + +async def test_selected_files_fetch_failure_isolation(selected_files_mocks): + """get_file_by_id failing for one file collects an error; others still indexed.""" + selected_files_mocks["get_file_results"]["f1"] = ( + _make_file_dict("f1", "first.txt"), None, + ) + selected_files_mocks["get_file_results"]["f2"] = (None, "HTTP 404") + selected_files_mocks["get_file_results"]["f3"] = ( + _make_file_dict("f3", "third.txt"), None, + ) + selected_files_mocks["download_and_index_mock"].return_value = (2, 0) + + indexed, skipped, errors = await _run_selected( + selected_files_mocks, + [("f1", "first.txt"), ("f2", "mid.txt"), ("f3", "third.txt")], + ) + + assert indexed == 2 + assert skipped == 0 + assert len(errors) == 1 + assert "mid.txt" in errors[0] + assert "HTTP 404" in errors[0] + + +async def test_selected_files_skip_rename_counting(selected_files_mocks): + """Unchanged files are skipped, renames counted as indexed, + and only new files are sent to _download_and_index.""" + for fid, fname in [("s1", "unchanged.txt"), ("r1", "renamed.txt"), + ("n1", "new1.txt"), ("n2", "new2.txt")]: + selected_files_mocks["get_file_results"][fid] = ( + _make_file_dict(fid, fname), None, + ) + + selected_files_mocks["skip_results"]["s1"] = (True, "unchanged") + selected_files_mocks["skip_results"]["r1"] = (True, "File renamed: 'old' \u2192 'renamed.txt'") + + selected_files_mocks["download_and_index_mock"].return_value = (2, 0) + + indexed, skipped, errors = await _run_selected( + selected_files_mocks, + [("s1", "unchanged.txt"), ("r1", "renamed.txt"), + ("n1", "new1.txt"), ("n2", "new2.txt")], + ) + + assert indexed == 3 # 1 renamed + 2 batch + assert skipped == 1 # 1 unchanged + assert errors == [] + + mock = selected_files_mocks["download_and_index_mock"] + mock.assert_called_once() + call_files = mock.call_args[1].get("files") if "files" in (mock.call_args[1] or {}) else mock.call_args[0][2] + assert len(call_files) == 2 + assert {f["id"] for f in call_files} == {"n1", "n2"} From da6bbcfe39da73a60e4e2741c1e98e10cb1a88ea Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 08:54:06 +0530 Subject: [PATCH 11/31] feat: add file streaming download functionality to Google Drive client - Introduced `download_file_to_disk` method to stream files directly to disk in chunks, reducing memory usage during downloads. - Updated `download_and_extract_content` function to utilize the new streaming download method for binary files, enhancing efficiency in handling large files. - Improved error handling for download operations, providing clearer feedback on failures. --- .../app/connectors/google_drive/client.py | 25 +++++++++++++++++++ .../google_drive/content_extractor.py | 21 ++++++++++------ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/surfsense_backend/app/connectors/google_drive/client.py b/surfsense_backend/app/connectors/google_drive/client.py index 697e3e760..4e4240e91 100644 --- a/surfsense_backend/app/connectors/google_drive/client.py +++ b/surfsense_backend/app/connectors/google_drive/client.py @@ -172,6 +172,31 @@ class GoogleDriveClient: except Exception as e: return None, f"Error downloading file: {e!s}" + async def download_file_to_disk( + self, file_id: str, dest_path: str, chunksize: int = 5 * 1024 * 1024, + ) -> str | None: + """Stream file directly to disk in chunks, avoiding full in-memory buffering. + + Returns error message on failure, None on success. + """ + try: + service = await self.get_service() + request = service.files().get_media(fileId=file_id) + from googleapiclient.http import MediaIoBaseDownload + + with open(dest_path, "wb") as fh: + downloader = MediaIoBaseDownload(fh, request, chunksize=chunksize) + done = False + while not done: + _, done = downloader.next_chunk() + + return None + + except HttpError as e: + return f"HTTP error downloading file: {e.resp.status}" + except Exception as e: + return f"Error downloading file: {e!s}" + async def export_google_file( self, file_id: str, mime_type: str ) -> tuple[bytes | None, str | None]: diff --git a/surfsense_backend/app/connectors/google_drive/content_extractor.py b/surfsense_backend/app/connectors/google_drive/content_extractor.py index 6fa20bf8e..69f64d9ae 100644 --- a/surfsense_backend/app/connectors/google_drive/content_extractor.py +++ b/surfsense_backend/app/connectors/google_drive/content_extractor.py @@ -60,8 +60,9 @@ async def download_and_extract_content( temp_file_path = None try: - # Download / export if is_google_workspace_file(mime_type): + # Workspace files (Docs/Sheets/Slides) use export -- returns bytes + # in one shot. These are typically small (a few MB as PDF/text). export_mime = get_export_mime_type(mime_type) if not export_mime: return None, drive_metadata, f"Cannot export Google Workspace type: {mime_type}" @@ -69,17 +70,21 @@ async def download_and_extract_content( if error: return None, drive_metadata, error extension = ".pdf" if export_mime == "application/pdf" else ".txt" + + with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp: + tmp.write(content_bytes) + temp_file_path = tmp.name else: - content_bytes, error = await client.download_file(file_id) + # Binary files -- stream directly to disk in chunks to avoid + # loading the entire file into memory. + extension = Path(file_name).suffix or ".bin" + with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp: + temp_file_path = tmp.name + + error = await client.download_file_to_disk(file_id, temp_file_path) if error: return None, drive_metadata, error - extension = Path(file_name).suffix or ".bin" - with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp: - tmp.write(content_bytes) - temp_file_path = tmp.name - - # Parse to markdown markdown = await _parse_file_to_markdown(temp_file_path, file_name) return markdown, drive_metadata, None From db6dd058ddf12a24c6ea890f4dd6727e1b2da6d8 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 11:19:32 +0530 Subject: [PATCH 12/31] feat: migrate Linear and Notion indexers to unified parallel pipeline - Refactored Linear and Notion indexers to utilize the shared IndexingPipelineService for improved document deduplication, summarization, chunking, and embedding with bounded parallel indexing. - Updated the `_build_connector_doc` function in both indexers to create ConnectorDocument instances with enhanced metadata and fallback summaries. - Modified the `index_linear_issues` and `index_notion_pages` functions to return a tuple of (indexed_count, skipped_count, warning_or_error_message) for better error handling and reporting. - Added unit tests for both indexers to validate the new parallel processing logic and ensure correct document creation and indexing behavior. --- .../connector_indexers/linear_indexer.py | 417 +++++------------ .../connector_indexers/notion_indexer.py | 442 +++++------------- .../test_linear_parallel.py | 355 ++++++++++++++ .../test_notion_parallel.py | 345 ++++++++++++++ 4 files changed, 944 insertions(+), 615 deletions(-) create mode 100644 surfsense_backend/tests/unit/connector_indexers/test_linear_parallel.py create mode 100644 surfsense_backend/tests/unit/connector_indexers/test_notion_parallel.py diff --git a/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py b/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py index 6e9ccaa01..38d931588 100644 --- a/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py @@ -1,48 +1,84 @@ """ Linear connector indexer. -Implements 2-phase document status updates for real-time UI feedback: -- Phase 1: Create all documents with 'pending' status (visible in UI immediately) -- Phase 2: Process each document: pending → processing → ready/failed +Uses the shared IndexingPipelineService for document deduplication, +summarization, chunking, and embedding with bounded parallel indexing. """ -import time from collections.abc import Awaitable, Callable -from datetime import datetime from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.linear_connector import LinearConnector -from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType +from app.db import DocumentType, SearchSourceConnectorType +from app.indexing_pipeline.connector_document import ConnectorDocument +from app.indexing_pipeline.document_hashing import compute_content_hash +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.utils.document_converters import ( - create_document_chunks, - embed_text, - generate_content_hash, - generate_document_summary, - generate_unique_identifier_hash, -) from .base import ( calculate_date_range, - check_document_by_unique_identifier, check_duplicate_document_by_hash, get_connector_by_id, - get_current_timestamp, logger, - safe_set_chunks, update_connector_last_indexed, ) -# Type hint for heartbeat callback HeartbeatCallbackType = Callable[[int], Awaitable[None]] - -# Heartbeat interval in seconds - update notification every 30 seconds HEARTBEAT_INTERVAL_SECONDS = 30 +def _build_connector_doc( + issue: dict, + formatted_issue: dict, + issue_content: str, + *, + connector_id: int, + search_space_id: int, + user_id: str, + enable_summary: bool, +) -> ConnectorDocument: + """Map a raw Linear issue dict to a ConnectorDocument.""" + issue_id = issue.get("id", "") + issue_identifier = issue.get("identifier", "") + issue_title = issue.get("title", "") + state = formatted_issue.get("state", "Unknown") + priority = formatted_issue.get("priority", "Unknown") + comment_count = len(formatted_issue.get("comments", [])) + + metadata = { + "issue_id": issue_id, + "issue_identifier": issue_identifier, + "issue_title": issue_title, + "state": state, + "priority": priority, + "comment_count": comment_count, + "connector_id": connector_id, + "document_type": "Linear Issue", + "connector_type": "Linear", + } + + fallback_summary = ( + f"Linear Issue {issue_identifier}: {issue_title}\n\n" + f"Status: {state}\n\n{issue_content}" + ) + + return ConnectorDocument( + title=f"{issue_identifier}: {issue_title}", + source_markdown=issue_content, + unique_id=issue_id, + document_type=DocumentType.LINEAR_CONNECTOR, + search_space_id=search_space_id, + connector_id=connector_id, + created_by_id=user_id, + should_summarize=enable_summary, + fallback_summary=fallback_summary, + metadata=metadata, + ) + + async def index_linear_issues( session: AsyncSession, connector_id: int, @@ -52,26 +88,15 @@ async def index_linear_issues( end_date: str | None = None, update_last_indexed: bool = True, on_heartbeat_callback: HeartbeatCallbackType | None = None, -) -> tuple[int, str | None]: +) -> tuple[int, int, str | None]: """ Index Linear issues and comments. - Args: - session: Database session - connector_id: ID of the Linear connector - search_space_id: ID of the search space to store documents in - user_id: ID of the user - start_date: Start date for indexing (YYYY-MM-DD format) - end_date: End date for indexing (YYYY-MM-DD format) - update_last_indexed: Whether to update the last_indexed_at timestamp (default: True) - on_heartbeat_callback: Optional callback to update notification during long-running indexing. - Returns: - Tuple containing (number of documents indexed, error message or None) + Tuple of (indexed_count, skipped_count, warning_or_error_message) """ task_logger = TaskLoggingService(session, search_space_id) - # Log task start log_entry = await task_logger.log_task_start( task_name="linear_issues_indexing", source="connector_indexing_task", @@ -85,7 +110,7 @@ async def index_linear_issues( ) try: - # Get the connector + # ── Connector lookup ────────────────────────────────────────── await task_logger.log_task_progress( log_entry, f"Retrieving Linear connector {connector_id} from database", @@ -104,11 +129,11 @@ async def index_linear_issues( {"error_type": "ConnectorNotFound"}, ) return ( + 0, 0, f"Connector with ID {connector_id} not found or is not a Linear connector", ) - # Check if access_token exists (support both new OAuth format and old API key format) if not connector.config.get("access_token") and not connector.config.get( "LINEAR_API_KEY" ): @@ -118,26 +143,22 @@ async def index_linear_issues( "Missing Linear access token", {"error_type": "MissingToken"}, ) - return 0, "Linear access token not found in connector config" + return 0, 0, "Linear access token not found in connector config" - # Initialize Linear client with internal refresh capability + # ── Client init ─────────────────────────────────────────────── await task_logger.log_task_progress( log_entry, f"Initializing Linear client for connector {connector_id}", {"stage": "client_initialization"}, ) - # Create connector with session and connector_id for internal refresh - # Token refresh will happen automatically when needed linear_client = LinearConnector(session=session, connector_id=connector_id) - # Handle 'undefined' string from frontend (treat as None) if start_date == "undefined" or start_date == "": start_date = None if end_date == "undefined" or end_date == "": end_date = None - # Calculate date range start_date_str, end_date_str = calculate_date_range( connector, start_date, end_date, default_days_back=365 ) @@ -154,37 +175,34 @@ async def index_linear_issues( }, ) - # Get issues within date range + # ── Fetch issues ────────────────────────────────────────────── try: issues, error = await linear_client.get_issues_by_date_range( - start_date=start_date_str, end_date=end_date_str, include_comments=True + start_date=start_date_str, + end_date=end_date_str, + include_comments=True, ) if error: - # Don't treat "No issues found" as an error that should stop indexing if "No issues found" in error: logger.info(f"No Linear issues found: {error}") - logger.info( - "No issues found is not a critical error, continuing with update" - ) if update_last_indexed: await update_connector_last_indexed( session, connector, update_last_indexed ) await session.commit() - logger.info( - f"Updated last_indexed_at to {connector.last_indexed_at} despite no issues found" - ) - return 0, None + return 0, 0, None else: logger.error(f"Failed to get Linear issues: {error}") - return 0, f"Failed to get Linear issues: {error}" + return 0, 0, f"Failed to get Linear issues: {error}" logger.info(f"Retrieved {len(issues)} issues from Linear API") except Exception as e: - logger.error(f"Exception when calling Linear API: {e!s}", exc_info=True) - return 0, f"Failed to get Linear issues: {e!s}" + logger.error( + f"Exception when calling Linear API: {e!s}", exc_info=True + ) + return 0, 0, f"Failed to get Linear issues: {e!s}" if not issues: logger.info("No Linear issues found for the specified date range") @@ -193,19 +211,12 @@ async def index_linear_issues( session, connector, update_last_indexed ) await session.commit() - logger.info( - f"Updated last_indexed_at to {connector.last_indexed_at} despite no issues found" - ) - return 0, None # Return None instead of error message when no issues found + return 0, 0, None - # Track the number of documents indexed - documents_indexed = 0 + # ── Build ConnectorDocuments ────────────────────────────────── + connector_docs: list[ConnectorDocument] = [] documents_skipped = 0 - documents_failed = 0 # Track issues that failed processing - skipped_issues = [] - - # Heartbeat tracking - update notification periodically to prevent appearing stuck - last_heartbeat_time = time.time() + duplicate_content_count = 0 await task_logger.log_task_progress( log_entry, @@ -213,13 +224,6 @@ async def index_linear_issues( {"stage": "process_issues", "total_issues": len(issues)}, ) - # ======================================================================= - # PHASE 1: Analyze all issues, create pending documents - # This makes ALL documents visible in the UI immediately with pending status - # ======================================================================= - issues_to_process = [] # List of dicts with document and issue data - new_documents_created = False - for issue in issues: try: issue_id = issue.get("id", "") @@ -230,271 +234,102 @@ async def index_linear_issues( logger.warning( f"Skipping issue with missing ID or title: {issue_id or 'Unknown'}" ) - skipped_issues.append( - f"{issue_identifier or 'Unknown'} (missing data)" - ) documents_skipped += 1 continue - # Format the issue first to get well-structured data formatted_issue = linear_client.format_issue(issue) - - # Convert issue to markdown format - issue_content = linear_client.format_issue_to_markdown(formatted_issue) + issue_content = linear_client.format_issue_to_markdown( + formatted_issue + ) if not issue_content: logger.warning( f"Skipping issue with no content: {issue_identifier} - {issue_title}" ) - skipped_issues.append(f"{issue_identifier} (no content)") documents_skipped += 1 continue - # Generate unique identifier hash for this Linear issue - unique_identifier_hash = generate_unique_identifier_hash( - DocumentType.LINEAR_CONNECTOR, issue_id, search_space_id - ) - - # Generate content hash - content_hash = generate_content_hash(issue_content, search_space_id) - - # Check if document with this unique identifier already exists - existing_document = await check_document_by_unique_identifier( - session, unique_identifier_hash - ) - - state = formatted_issue.get("state", "Unknown") - description = formatted_issue.get("description", "") - comment_count = len(formatted_issue.get("comments", [])) - priority = formatted_issue.get("priority", "Unknown") - - if existing_document: - # Document exists - check if content has changed - if existing_document.content_hash == content_hash: - # Ensure status is ready (might have been stuck in processing/pending) - if not DocumentStatus.is_state( - existing_document.status, DocumentStatus.READY - ): - existing_document.status = DocumentStatus.ready() - logger.info( - f"Document for Linear issue {issue_identifier} unchanged. Skipping." - ) - documents_skipped += 1 - continue - - # Queue existing document for update (will be set to processing in Phase 2) - issues_to_process.append( - { - "document": existing_document, - "is_new": False, - "issue_content": issue_content, - "content_hash": content_hash, - "issue_id": issue_id, - "issue_identifier": issue_identifier, - "issue_title": issue_title, - "state": state, - "description": description, - "comment_count": comment_count, - "priority": priority, - } - ) - continue - - # Document doesn't exist by unique_identifier_hash - # Check if a document with the same content_hash exists (from another connector) - with session.no_autoflush: - duplicate_by_content = await check_duplicate_document_by_hash( - session, content_hash - ) - - if duplicate_by_content: - logger.info( - f"Linear issue {issue_identifier} already indexed by another connector " - f"(existing document ID: {duplicate_by_content.id}, " - f"type: {duplicate_by_content.document_type}). Skipping." - ) - documents_skipped += 1 - continue - - # Create new document with PENDING status (visible in UI immediately) - document = Document( - search_space_id=search_space_id, - title=f"{issue_identifier}: {issue_title}", - document_type=DocumentType.LINEAR_CONNECTOR, - document_metadata={ - "issue_id": issue_id, - "issue_identifier": issue_identifier, - "issue_title": issue_title, - "state": state, - "comment_count": comment_count, - "connector_id": connector_id, - }, - content="Pending...", # Placeholder until processed - content_hash=unique_identifier_hash, # Temporary unique value - updated when ready - unique_identifier_hash=unique_identifier_hash, - embedding=None, - chunks=[], # Empty at creation - safe for async - status=DocumentStatus.pending(), # Pending until processing starts - updated_at=get_current_timestamp(), - created_by_id=user_id, + doc = _build_connector_doc( + issue, + formatted_issue, + issue_content, connector_id=connector_id, - ) - session.add(document) - new_documents_created = True - - issues_to_process.append( - { - "document": document, - "is_new": True, - "issue_content": issue_content, - "content_hash": content_hash, - "issue_id": issue_id, - "issue_identifier": issue_identifier, - "issue_title": issue_title, - "state": state, - "description": description, - "comment_count": comment_count, - "priority": priority, - } + search_space_id=search_space_id, + user_id=user_id, + enable_summary=connector.enable_summary, ) - except Exception as e: - logger.error(f"Error in Phase 1 for issue: {e!s}", exc_info=True) - documents_failed += 1 - continue - - # Commit all pending documents - they all appear in UI now - if new_documents_created: - logger.info( - f"Phase 1: Committing {len([i for i in issues_to_process if i['is_new']])} pending documents" - ) - await session.commit() - - # ======================================================================= - # PHASE 2: Process each document one by one - # Each document transitions: pending → processing → ready/failed - # ======================================================================= - logger.info(f"Phase 2: Processing {len(issues_to_process)} documents") - - for item in issues_to_process: - # Send heartbeat periodically - if on_heartbeat_callback: - current_time = time.time() - if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS: - await on_heartbeat_callback(documents_indexed) - last_heartbeat_time = current_time - - document = item["document"] - try: - # Set to PROCESSING and commit - shows "processing" in UI for THIS document only - document.status = DocumentStatus.processing() - await session.commit() - - # Heavy processing (LLM, embeddings, chunks) - user_llm = await get_user_long_context_llm( - session, user_id, search_space_id - ) - - if user_llm and connector.enable_summary: - document_metadata_for_summary = { - "issue_id": item["issue_identifier"], - "issue_title": item["issue_title"], - "state": item["state"], - "priority": item["priority"], - "comment_count": item["comment_count"], - "document_type": "Linear Issue", - "connector_type": "Linear", - } - ( - summary_content, - summary_embedding, - ) = await generate_document_summary( - item["issue_content"], user_llm, document_metadata_for_summary + with session.no_autoflush: + duplicate = await check_duplicate_document_by_hash( + session, compute_content_hash(doc) ) - else: - summary_content = f"Linear Issue {item['issue_identifier']}: {item['issue_title']}\n\nStatus: {item['state']}\n\n{item['issue_content']}" - summary_embedding = embed_text(summary_content) - - chunks = await create_document_chunks(item["issue_content"]) - - # Update document to READY with actual content - document.title = f"{item['issue_identifier']}: {item['issue_title']}" - document.content = summary_content - document.content_hash = item["content_hash"] - document.embedding = summary_embedding - document.document_metadata = { - "issue_id": item["issue_id"], - "issue_identifier": item["issue_identifier"], - "issue_title": item["issue_title"], - "state": item["state"], - "comment_count": item["comment_count"], - "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "connector_id": connector_id, - } - await safe_set_chunks(session, document, chunks) - document.updated_at = get_current_timestamp() - document.status = DocumentStatus.ready() - - documents_indexed += 1 - - # Batch commit every 10 documents (for ready status updates) - if documents_indexed % 10 == 0: + if duplicate: logger.info( - f"Committing batch: {documents_indexed} Linear issues processed so far" + f"Linear issue {doc.title} already indexed by another connector " + f"(existing document ID: {duplicate.id}, " + f"type: {duplicate.document_type}). Skipping." ) - await session.commit() + duplicate_content_count += 1 + documents_skipped += 1 + continue + + connector_docs.append(doc) except Exception as e: logger.error( - f"Error processing issue {item.get('issue_identifier', 'Unknown')}: {e!s}", + f"Error building ConnectorDocument for issue: {e!s}", exc_info=True, ) - # Mark document as failed with reason (visible in UI) - try: - document.status = DocumentStatus.failed(str(e)) - document.updated_at = get_current_timestamp() - except Exception as status_error: - logger.error( - f"Failed to update document status to failed: {status_error}" - ) - skipped_issues.append( - f"{item.get('issue_identifier', 'Unknown')} (processing error)" - ) - documents_failed += 1 + documents_skipped += 1 continue - # CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs + # ── Pipeline: migrate legacy docs + parallel index ──────────── + pipeline = IndexingPipelineService(session) + + await pipeline.migrate_legacy_docs(connector_docs) + + async def _get_llm(s): + return await get_user_long_context_llm(s, user_id, search_space_id) + + _, documents_indexed, documents_failed = await pipeline.index_batch_parallel( + connector_docs, + _get_llm, + max_concurrency=3, + on_heartbeat=on_heartbeat_callback, + heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS, + ) + + # ── Finalize ────────────────────────────────────────────────── await update_connector_last_indexed(session, connector, update_last_indexed) - # Final commit for any remaining documents not yet committed in batches - logger.info(f"Final commit: Total {documents_indexed} Linear issues processed") + logger.info( + f"Final commit: Total {documents_indexed} Linear issues processed" + ) try: await session.commit() logger.info( "Successfully committed all Linear document changes to database" ) except Exception as e: - # Handle any remaining integrity errors gracefully (race conditions, etc.) if ( "duplicate key value violates unique constraint" in str(e).lower() or "uniqueviolationerror" in str(e).lower() ): logger.warning( f"Duplicate content_hash detected during final commit. " - f"This may occur if the same issue was indexed by multiple connectors. " f"Rolling back and continuing. Error: {e!s}" ) await session.rollback() else: raise - # Build warning message if there were issues - warning_parts = [] + warning_parts: list[str] = [] + if duplicate_content_count > 0: + warning_parts.append(f"{duplicate_content_count} duplicate") if documents_failed > 0: warning_parts.append(f"{documents_failed} failed") warning_message = ", ".join(warning_parts) if warning_parts else None - # Log success await task_logger.log_task_success( log_entry, f"Successfully completed Linear indexing for connector {connector_id}", @@ -503,7 +338,7 @@ async def index_linear_issues( "documents_indexed": documents_indexed, "documents_skipped": documents_skipped, "documents_failed": documents_failed, - "skipped_issues_count": len(skipped_issues), + "duplicate_content_count": duplicate_content_count, }, ) @@ -511,7 +346,7 @@ async def index_linear_issues( f"Linear indexing completed: {documents_indexed} ready, " f"{documents_skipped} skipped, {documents_failed} failed" ) - return documents_indexed, warning_message + return documents_indexed, documents_skipped, warning_message except SQLAlchemyError as db_error: await session.rollback() @@ -522,7 +357,7 @@ async def index_linear_issues( {"error_type": "SQLAlchemyError"}, ) logger.error(f"Database error: {db_error!s}", exc_info=True) - return 0, f"Database error: {db_error!s}" + return 0, 0, f"Database error: {db_error!s}" except Exception as e: await session.rollback() await task_logger.log_task_failure( @@ -532,4 +367,4 @@ async def index_linear_issues( {"error_type": type(e).__name__}, ) logger.error(f"Failed to index Linear issues: {e!s}", exc_info=True) - return 0, f"Failed to index Linear issues: {e!s}" + return 0, 0, f"Failed to index Linear issues: {e!s}" diff --git a/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py b/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py index 619b8dcd7..6614071a4 100644 --- a/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py @@ -1,12 +1,10 @@ """ Notion connector indexer. -Implements real-time document status updates using a two-phase approach: -- Phase 1: Create all documents with PENDING status (visible in UI immediately) -- Phase 2: Process each document one by one (pending → processing → ready/failed) +Uses the shared IndexingPipelineService for document deduplication, +summarization, chunking, and embedding with bounded parallel indexing. """ -import time from collections.abc import Awaitable, Callable from datetime import datetime @@ -14,42 +12,64 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.notion_history import NotionHistoryConnector -from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType +from app.db import DocumentType, SearchSourceConnectorType +from app.indexing_pipeline.connector_document import ConnectorDocument +from app.indexing_pipeline.document_hashing import compute_content_hash +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.utils.document_converters import ( - create_document_chunks, - embed_text, - generate_content_hash, - generate_document_summary, - generate_unique_identifier_hash, -) from app.utils.notion_utils import process_blocks from .base import ( - build_document_metadata_string, calculate_date_range, - check_document_by_unique_identifier, check_duplicate_document_by_hash, get_connector_by_id, - get_current_timestamp, logger, - safe_set_chunks, update_connector_last_indexed, ) -# Type alias for retry callback -# Signature: async callback(retry_reason, attempt, max_attempts, wait_seconds) -> None RetryCallbackType = Callable[[str, int, int, float], Awaitable[None]] - -# Type alias for heartbeat callback -# Signature: async callback(indexed_count) -> None HeartbeatCallbackType = Callable[[int], Awaitable[None]] - -# Heartbeat interval in seconds - update notification every 30 seconds HEARTBEAT_INTERVAL_SECONDS = 30 +def _build_connector_doc( + page: dict, + markdown_content: str, + *, + connector_id: int, + search_space_id: int, + user_id: str, + enable_summary: bool, +) -> ConnectorDocument: + """Map a raw Notion page dict to a ConnectorDocument.""" + page_id = page.get("page_id", "") + page_title = page.get("title", f"Untitled page ({page_id})") + + metadata = { + "page_title": page_title, + "page_id": page_id, + "connector_id": connector_id, + "document_type": "Notion Page", + "connector_type": "Notion", + } + + fallback_summary = f"Notion Page: {page_title}\n\n{markdown_content}" + + return ConnectorDocument( + title=page_title, + source_markdown=markdown_content, + unique_id=page_id, + document_type=DocumentType.NOTION_CONNECTOR, + search_space_id=search_space_id, + connector_id=connector_id, + created_by_id=user_id, + should_summarize=enable_summary, + fallback_summary=fallback_summary, + metadata=metadata, + ) + + async def index_notion_pages( session: AsyncSession, connector_id: int, @@ -60,30 +80,15 @@ async def index_notion_pages( update_last_indexed: bool = True, on_retry_callback: RetryCallbackType | None = None, on_heartbeat_callback: HeartbeatCallbackType | None = None, -) -> tuple[int, str | None]: +) -> tuple[int, int, str | None]: """ Index Notion pages from all accessible pages. - Args: - session: Database session - connector_id: ID of the Notion connector - search_space_id: ID of the search space to store documents in - user_id: ID of the user - start_date: Start date for indexing (YYYY-MM-DD format) - end_date: End date for indexing (YYYY-MM-DD format) - update_last_indexed: Whether to update the last_indexed_at timestamp (default: True) - on_retry_callback: Optional callback for retry progress notifications. - Signature: async callback(retry_reason, attempt, max_attempts, wait_seconds) - retry_reason is one of: 'rate_limit', 'server_error', 'timeout' - on_heartbeat_callback: Optional callback to update notification during long-running indexing. - Called periodically with (indexed_count) to prevent task appearing stuck. - Returns: - Tuple containing (number of documents indexed, error message or None) + Tuple of (indexed_count, skipped_count, warning_or_error_message) """ task_logger = TaskLoggingService(session, search_space_id) - # Log task start log_entry = await task_logger.log_task_start( task_name="notion_pages_indexing", source="connector_indexing_task", @@ -97,7 +102,7 @@ async def index_notion_pages( ) try: - # Get the connector + # ── Connector lookup ────────────────────────────────────────── await task_logger.log_task_progress( log_entry, f"Retrieving Notion connector {connector_id} from database", @@ -116,11 +121,11 @@ async def index_notion_pages( {"error_type": "ConnectorNotFound"}, ) return ( + 0, 0, f"Connector with ID {connector_id} not found or is not a Notion connector", ) - # Check if access_token exists (support both new OAuth format and old integration token format) if not connector.config.get("access_token") and not connector.config.get( "NOTION_INTEGRATION_TOKEN" ): @@ -130,9 +135,9 @@ async def index_notion_pages( "Missing Notion access token", {"error_type": "MissingToken"}, ) - return 0, "Notion access token not found in connector config" + return 0, 0, "Notion access token not found in connector config" - # Initialize Notion client with internal refresh capability + # ── Client init ─────────────────────────────────────────────── await task_logger.log_task_progress( log_entry, f"Initializing Notion client for connector {connector_id}", @@ -141,18 +146,15 @@ async def index_notion_pages( logger.info(f"Initializing Notion client for connector {connector_id}") - # Handle 'undefined' string from frontend (treat as None) if start_date == "undefined" or start_date == "": start_date = None if end_date == "undefined" or end_date == "": end_date = None - # Calculate date range using the shared utility function start_date_str, end_date_str = calculate_date_range( connector, start_date, end_date, default_days_back=365 ) - # Convert YYYY-MM-DD to ISO format for Notion API start_date_iso = datetime.strptime(start_date_str, "%Y-%m-%d").strftime( "%Y-%m-%dT%H:%M:%SZ" ) @@ -160,13 +162,10 @@ async def index_notion_pages( "%Y-%m-%dT%H:%M:%SZ" ) - # Create connector with session and connector_id for internal refresh - # Token refresh will happen automatically when needed notion_client = NotionHistoryConnector( session=session, connector_id=connector_id ) - # Set retry callback if provided (for user notifications during rate limits) if on_retry_callback: notion_client.set_retry_callback(on_retry_callback) @@ -182,21 +181,19 @@ async def index_notion_pages( }, ) - # Get all pages + # ── Fetch pages ─────────────────────────────────────────────── try: pages = await notion_client.get_all_pages( start_date=start_date_iso, end_date=end_date_iso ) logger.info(f"Found {len(pages)} Notion pages") - # Get count of pages that had unsupported content skipped pages_with_skipped_content = notion_client.get_skipped_content_count() if pages_with_skipped_content > 0: logger.info( f"{pages_with_skipped_content} pages had Notion AI content skipped (not available via API)" ) - # Check if using legacy integration token and log warning if notion_client.is_using_legacy_token(): logger.warning( f"Connector {connector_id} is using legacy integration token. " @@ -204,8 +201,6 @@ async def index_notion_pages( ) except Exception as e: error_str = str(e) - # Check if this is an unsupported block type error (transcription, ai_block, etc.) - # These are known Notion API limitations and should be logged as warnings, not errors unsupported_block_errors = [ "transcription is not supported", "ai_block is not supported", @@ -216,7 +211,6 @@ async def index_notion_pages( ) if is_unsupported_block_error: - # Log as warning since this is a known Notion API limitation logger.warning( f"Notion API limitation for connector {connector_id}: {error_str}. " "This is a known issue with Notion AI blocks (transcription, ai_block) " @@ -229,7 +223,6 @@ async def index_notion_pages( {"error_type": "UnsupportedBlockType", "is_known_limitation": True}, ) else: - # Log as error for other failures logger.error( f"Error fetching Notion pages for connector {connector_id}: {error_str}", exc_info=True, @@ -242,7 +235,7 @@ async def index_notion_pages( ) await notion_client.close() - return 0, f"Failed to get Notion pages: {e!s}" + return 0, 0, f"Failed to get Notion pages: {e!s}" if not pages: await task_logger.log_task_success( @@ -252,21 +245,17 @@ async def index_notion_pages( {"pages_found": 0}, ) logger.info("No Notion pages found to index") - # CRITICAL: Update timestamp even when no pages found so Zero syncs - await update_connector_last_indexed(session, connector, update_last_indexed) + await update_connector_last_indexed( + session, connector, update_last_indexed + ) await session.commit() await notion_client.close() - return 0, None # Success with 0 pages, not an error + return 0, 0, None - # Track the number of documents indexed - documents_indexed = 0 + # ── Build ConnectorDocuments ────────────────────────────────── + connector_docs: list[ConnectorDocument] = [] documents_skipped = 0 - documents_failed = 0 duplicate_content_count = 0 - skipped_pages = [] - - # Heartbeat tracking - update notification periodically to prevent appearing stuck - last_heartbeat_time = time.time() await task_logger.log_task_progress( log_entry, @@ -274,13 +263,6 @@ async def index_notion_pages( {"stage": "process_pages", "total_pages": len(pages)}, ) - # ======================================================================= - # PHASE 1: Analyze all pages, create pending documents - # This makes ALL documents visible in the UI immediately with pending status - # ======================================================================= - pages_to_process = [] # List of dicts with document and page data - new_documents_created = False - for page in pages: try: page_id = page.get("page_id") @@ -293,225 +275,71 @@ async def index_notion_pages( if not page_content: logger.info(f"No content found in page {page_title}. Skipping.") - skipped_pages.append(f"{page_title} (no content)") documents_skipped += 1 continue - # Convert page content to markdown format markdown_content = f"# Notion Page: {page_title}\n\n" markdown_content += process_blocks(page_content) - # Format document metadata - metadata_sections = [ - ("METADATA", [f"PAGE_TITLE: {page_title}", f"PAGE_ID: {page_id}"]), - ( - "CONTENT", - [ - "FORMAT: markdown", - "TEXT_START", - markdown_content, - "TEXT_END", - ], - ), - ] - - # Build the document string - combined_document_string = build_document_metadata_string( - metadata_sections - ) - - # Generate unique identifier hash for this Notion page - unique_identifier_hash = generate_unique_identifier_hash( - DocumentType.NOTION_CONNECTOR, page_id, search_space_id - ) - - # Generate content hash - content_hash = generate_content_hash( - combined_document_string, search_space_id - ) - - # Check if document with this unique identifier already exists - existing_document = await check_document_by_unique_identifier( - session, unique_identifier_hash - ) - - if existing_document: - # Document exists - check if content has changed - if existing_document.content_hash == content_hash: - # Ensure status is ready (might have been stuck in processing/pending) - if not DocumentStatus.is_state( - existing_document.status, DocumentStatus.READY - ): - existing_document.status = DocumentStatus.ready() - documents_skipped += 1 - continue - - # Queue existing document for update (will be set to processing in Phase 2) - pages_to_process.append( - { - "document": existing_document, - "is_new": False, - "markdown_content": markdown_content, - "content_hash": content_hash, - "page_id": page_id, - "page_title": page_title, - } + if not markdown_content.strip(): + logger.warning( + f"Skipping page with empty markdown: {page_title}" ) + documents_skipped += 1 continue - # Document doesn't exist by unique_identifier_hash - # Check if a document with the same content_hash exists (from another connector) - with session.no_autoflush: - duplicate_by_content = await check_duplicate_document_by_hash( - session, content_hash - ) + doc = _build_connector_doc( + page, + markdown_content, + connector_id=connector_id, + search_space_id=search_space_id, + user_id=user_id, + enable_summary=connector.enable_summary, + ) - if duplicate_by_content: + with session.no_autoflush: + duplicate = await check_duplicate_document_by_hash( + session, compute_content_hash(doc) + ) + if duplicate: logger.info( - f"Notion page {page_title} already indexed by another connector " - f"(existing document ID: {duplicate_by_content.id}, " - f"type: {duplicate_by_content.document_type}). Skipping." + f"Notion page {doc.title} already indexed by another connector " + f"(existing document ID: {duplicate.id}, " + f"type: {duplicate.document_type}). Skipping." ) duplicate_content_count += 1 documents_skipped += 1 continue - # Create new document with PENDING status (visible in UI immediately) - document = Document( - search_space_id=search_space_id, - title=page_title, - document_type=DocumentType.NOTION_CONNECTOR, - document_metadata={ - "page_title": page_title, - "page_id": page_id, - "connector_id": connector_id, - }, - content="Pending...", # Placeholder until processed - content_hash=unique_identifier_hash, # Temporary unique value - updated when ready - unique_identifier_hash=unique_identifier_hash, - embedding=None, - chunks=[], # Empty at creation - safe for async - status=DocumentStatus.pending(), # Pending until processing starts - updated_at=get_current_timestamp(), - created_by_id=user_id, - connector_id=connector_id, - ) - session.add(document) - new_documents_created = True - - pages_to_process.append( - { - "document": document, - "is_new": True, - "markdown_content": markdown_content, - "content_hash": content_hash, - "page_id": page_id, - "page_title": page_title, - } - ) + connector_docs.append(doc) except Exception as e: - logger.error(f"Error in Phase 1 for page: {e!s}", exc_info=True) - documents_failed += 1 - continue - - # Commit all pending documents - they all appear in UI now - if new_documents_created: - logger.info( - f"Phase 1: Committing {len([p for p in pages_to_process if p['is_new']])} pending documents" - ) - await session.commit() - - # ======================================================================= - # PHASE 2: Process each document one by one - # Each document transitions: pending → processing → ready/failed - # ======================================================================= - logger.info(f"Phase 2: Processing {len(pages_to_process)} documents") - - for item in pages_to_process: - # Send heartbeat periodically - if on_heartbeat_callback: - current_time = time.time() - if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS: - await on_heartbeat_callback(documents_indexed) - last_heartbeat_time = current_time - - document = item["document"] - try: - # Set to PROCESSING and commit - shows "processing" in UI for THIS document only - document.status = DocumentStatus.processing() - await session.commit() - - # Heavy processing (LLM, embeddings, chunks) - user_llm = await get_user_long_context_llm( - session, user_id, search_space_id + logger.error( + f"Error building ConnectorDocument for page: {e!s}", + exc_info=True, ) - - if user_llm and connector.enable_summary: - document_metadata_for_summary = { - "page_title": item["page_title"], - "page_id": item["page_id"], - "document_type": "Notion Page", - "connector_type": "Notion", - } - ( - summary_content, - summary_embedding, - ) = await generate_document_summary( - item["markdown_content"], - user_llm, - document_metadata_for_summary, - ) - else: - summary_content = f"Notion Page: {item['page_title']}\n\n{item['markdown_content']}" - summary_embedding = embed_text(summary_content) - - chunks = await create_document_chunks(item["markdown_content"]) - - # Update document to READY with actual content - document.title = item["page_title"] - document.content = summary_content - document.content_hash = item["content_hash"] - document.embedding = summary_embedding - document.document_metadata = { - "page_title": item["page_title"], - "page_id": item["page_id"], - "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "connector_id": connector_id, - } - await safe_set_chunks(session, document, chunks) - document.updated_at = get_current_timestamp() - document.status = DocumentStatus.ready() - - documents_indexed += 1 - - # Batch commit every 10 documents (for ready status updates) - if documents_indexed % 10 == 0: - logger.info( - f"Committing batch: {documents_indexed} Notion pages processed so far" - ) - await session.commit() - - except Exception as e: - logger.error(f"Error processing Notion page: {e!s}", exc_info=True) - # Mark document as failed with reason (visible in UI) - try: - document.status = DocumentStatus.failed(str(e)) - document.updated_at = get_current_timestamp() - except Exception as status_error: - logger.error( - f"Failed to update document status to failed: {status_error}" - ) - skipped_pages.append(f"{item['page_title']} (processing error)") - documents_failed += 1 + documents_skipped += 1 continue - # CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs + # ── Pipeline: migrate legacy docs + parallel index ──────────── + pipeline = IndexingPipelineService(session) + + await pipeline.migrate_legacy_docs(connector_docs) + + async def _get_llm(s): + return await get_user_long_context_llm(s, user_id, search_space_id) + + _, documents_indexed, documents_failed = await pipeline.index_batch_parallel( + connector_docs, + _get_llm, + max_concurrency=3, + on_heartbeat=on_heartbeat_callback, + heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS, + ) + + # ── Finalize ────────────────────────────────────────────────── await update_connector_last_indexed(session, connector, update_last_indexed) - total_processed = documents_indexed - - # Final commit to ensure all documents are persisted (safety net) logger.info(f"Final commit: Total {documents_indexed} documents processed") try: await session.commit() @@ -519,59 +347,53 @@ async def index_notion_pages( "Successfully committed all Notion document changes to database" ) except Exception as e: - # Handle any remaining integrity errors gracefully (race conditions, etc.) if ( "duplicate key value violates unique constraint" in str(e).lower() or "uniqueviolationerror" in str(e).lower() ): logger.warning( f"Duplicate content_hash detected during final commit. " - f"This may occur if the same page was indexed by multiple connectors. " f"Rolling back and continuing. Error: {e!s}" ) await session.rollback() - # Don't fail the entire task - some documents may have been successfully indexed else: raise - # Get final count of pages with skipped Notion AI content + # ── Build warning / notification message ────────────────────── pages_with_skipped_ai_content = notion_client.get_skipped_content_count() - # Build warning message if there were issues - warning_parts = [] + warning_parts: list[str] = [] if duplicate_content_count > 0: warning_parts.append(f"{duplicate_content_count} duplicate") if documents_failed > 0: warning_parts.append(f"{documents_failed} failed") - warning_message = ", ".join(warning_parts) if warning_parts else None - # Prepare result message with user-friendly notification about skipped content - result_message = None - if skipped_pages: - result_message = f"Processed {total_processed} pages. Skipped {len(skipped_pages)} pages: {', '.join(skipped_pages)}" - else: - result_message = f"Processed {total_processed} pages." - - # Add user-friendly message about skipped Notion AI content + notification_parts: list[str] = [] if pages_with_skipped_ai_content > 0: - result_message += ( - " Audio transcriptions and AI summaries from Notion aren't accessible " - "via their API - all other content was saved." + notification_parts.append( + "Some Notion AI content couldn't be synced (API limitation)" ) + if notion_client.is_using_legacy_token(): + notification_parts.append( + "Using legacy token. Reconnect with OAuth for better reliability." + ) + if warning_parts: + notification_parts.append(", ".join(warning_parts)) + + user_notification_message = ( + " ".join(notification_parts) if notification_parts else None + ) - # Log success await task_logger.log_task_success( log_entry, f"Successfully completed Notion indexing for connector {connector_id}", { - "pages_processed": total_processed, + "pages_processed": documents_indexed, "documents_indexed": documents_indexed, "documents_skipped": documents_skipped, "documents_failed": documents_failed, "duplicate_content_count": duplicate_content_count, - "skipped_pages_count": len(skipped_pages), "pages_with_skipped_ai_content": pages_with_skipped_ai_content, - "result_message": result_message, }, ) @@ -581,35 +403,9 @@ async def index_notion_pages( f"({duplicate_content_count} duplicate content)" ) - # Clean up the async client await notion_client.close() - # Build user-friendly notification messages - # This will be shown in the notification to inform users - notification_parts = [] - - if pages_with_skipped_ai_content > 0: - notification_parts.append( - "Some Notion AI content couldn't be synced (API limitation)" - ) - - if notion_client.is_using_legacy_token(): - notification_parts.append( - "Using legacy token. Reconnect with OAuth for better reliability." - ) - - # Include warning message if there were issues - if warning_message: - notification_parts.append(warning_message) - - user_notification_message = ( - " ".join(notification_parts) if notification_parts else None - ) - - return ( - total_processed, - user_notification_message, - ) + return documents_indexed, documents_skipped, user_notification_message except SQLAlchemyError as db_error: await session.rollback() @@ -622,10 +418,9 @@ async def index_notion_pages( logger.error( f"Database error during Notion indexing: {db_error!s}", exc_info=True ) - # Clean up the async client in case of error if "notion_client" in locals(): await notion_client.close() - return 0, f"Database error: {db_error!s}" + return 0, 0, f"Database error: {db_error!s}" except Exception as e: await session.rollback() await task_logger.log_task_failure( @@ -635,7 +430,6 @@ async def index_notion_pages( {"error_type": type(e).__name__}, ) logger.error(f"Failed to index Notion pages: {e!s}", exc_info=True) - # Clean up the async client in case of error if "notion_client" in locals(): await notion_client.close() - return 0, f"Failed to index Notion pages: {e!s}" + return 0, 0, f"Failed to index Notion pages: {e!s}" diff --git a/surfsense_backend/tests/unit/connector_indexers/test_linear_parallel.py b/surfsense_backend/tests/unit/connector_indexers/test_linear_parallel.py new file mode 100644 index 000000000..b0ea48644 --- /dev/null +++ b/surfsense_backend/tests/unit/connector_indexers/test_linear_parallel.py @@ -0,0 +1,355 @@ +"""Tests for Linear indexer migrated to the unified parallel pipeline.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +import app.tasks.connector_indexers.linear_indexer as _mod +from app.db import DocumentType +from app.tasks.connector_indexers.linear_indexer import ( + _build_connector_doc, + index_linear_issues, +) + +pytestmark = pytest.mark.unit + +_USER_ID = "00000000-0000-0000-0000-000000000001" +_CONNECTOR_ID = 42 +_SEARCH_SPACE_ID = 1 + + +def _make_issue( + issue_id: str = "issue-1", + identifier: str = "ENG-1", + title: str = "Fix bug", +): + return {"id": issue_id, "identifier": identifier, "title": title} + + +def _make_formatted_issue( + issue_id: str = "issue-1", + identifier: str = "ENG-1", + title: str = "Fix bug", + state: str = "In Progress", + priority: str = "High", + comments=None, +): + return { + "id": issue_id, + "identifier": identifier, + "title": title, + "state": state, + "priority": priority, + "description": "Some description", + "comments": comments or [], + } + + +# --------------------------------------------------------------------------- +# Slice 1: _build_connector_doc tracer bullet +# --------------------------------------------------------------------------- + + +async def test_build_connector_doc_produces_correct_fields(): + """Tracer bullet: a Linear issue produces a ConnectorDocument with correct fields.""" + issue = _make_issue(issue_id="abc-123", identifier="ENG-42", title="Fix login bug") + formatted = _make_formatted_issue( + issue_id="abc-123", + identifier="ENG-42", + title="Fix login bug", + state="Done", + priority="Urgent", + comments=[{"id": "c1"}], + ) + markdown = "# ENG-42: Fix login bug\n\nDescription here" + + doc = _build_connector_doc( + issue, + formatted, + markdown, + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=True, + ) + + assert doc.title == "ENG-42: Fix login bug" + assert doc.unique_id == "abc-123" + assert doc.document_type == DocumentType.LINEAR_CONNECTOR + assert doc.source_markdown == markdown + assert doc.search_space_id == _SEARCH_SPACE_ID + assert doc.connector_id == _CONNECTOR_ID + assert doc.created_by_id == _USER_ID + assert doc.should_summarize is True + assert doc.metadata["issue_id"] == "abc-123" + assert doc.metadata["issue_identifier"] == "ENG-42" + assert doc.metadata["issue_title"] == "Fix login bug" + assert doc.metadata["state"] == "Done" + assert doc.metadata["priority"] == "Urgent" + assert doc.metadata["comment_count"] == 1 + assert doc.metadata["connector_id"] == _CONNECTOR_ID + assert doc.metadata["document_type"] == "Linear Issue" + assert doc.metadata["connector_type"] == "Linear" + assert doc.fallback_summary is not None + assert "ENG-42" in doc.fallback_summary + assert markdown in doc.fallback_summary + + +async def test_build_connector_doc_summary_disabled(): + """When enable_summary is False, should_summarize is False.""" + doc = _build_connector_doc( + _make_issue(), + _make_formatted_issue(), + "# content", + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=False, + ) + + assert doc.should_summarize is False + + +# --------------------------------------------------------------------------- +# Shared fixtures for Slices 2-6 +# --------------------------------------------------------------------------- + + +def _mock_connector(enable_summary: bool = True): + c = MagicMock() + c.config = {"access_token": "tok"} + c.enable_summary = enable_summary + c.last_indexed_at = None + return c + + +def _mock_linear_client(issues=None, error=None): + client = MagicMock() + client.get_issues_by_date_range = AsyncMock( + return_value=(issues if issues is not None else [], error), + ) + client.format_issue = MagicMock(side_effect=lambda i: _make_formatted_issue( + issue_id=i.get("id", ""), + identifier=i.get("identifier", ""), + title=i.get("title", ""), + )) + client.format_issue_to_markdown = MagicMock( + side_effect=lambda fi: f"# {fi.get('identifier', '')}: {fi.get('title', '')}\n\nContent" + ) + return client + + +@pytest.fixture +def linear_mocks(monkeypatch): + """Wire up all external boundary mocks for index_linear_issues.""" + mock_session = AsyncMock() + mock_session.no_autoflush = MagicMock() + + mock_connector = _mock_connector() + monkeypatch.setattr( + _mod, "get_connector_by_id", AsyncMock(return_value=mock_connector), + ) + + linear_client = _mock_linear_client(issues=[_make_issue()]) + monkeypatch.setattr( + _mod, "LinearConnector", MagicMock(return_value=linear_client), + ) + + monkeypatch.setattr( + _mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None), + ) + + monkeypatch.setattr( + _mod, "update_connector_last_indexed", AsyncMock(), + ) + + monkeypatch.setattr( + _mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")), + ) + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_progress = AsyncMock() + mock_task_logger.log_task_success = AsyncMock() + mock_task_logger.log_task_failure = AsyncMock() + monkeypatch.setattr( + _mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger), + ) + + batch_mock = AsyncMock(return_value=([], 1, 0)) + pipeline_mock = MagicMock() + pipeline_mock.index_batch_parallel = batch_mock + pipeline_mock.migrate_legacy_docs = AsyncMock() + monkeypatch.setattr( + _mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock), + ) + + return { + "session": mock_session, + "connector": mock_connector, + "linear_client": linear_client, + "task_logger": mock_task_logger, + "pipeline_mock": pipeline_mock, + "batch_mock": batch_mock, + } + + +async def _run_index(mocks, **overrides): + return await index_linear_issues( + session=mocks["session"], + connector_id=overrides.get("connector_id", _CONNECTOR_ID), + search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID), + user_id=overrides.get("user_id", _USER_ID), + start_date=overrides.get("start_date", "2025-01-01"), + end_date=overrides.get("end_date", "2025-12-31"), + update_last_indexed=overrides.get("update_last_indexed", True), + on_heartbeat_callback=overrides.get("on_heartbeat_callback"), + ) + + +# --------------------------------------------------------------------------- +# Slice 2: Full pipeline wiring +# --------------------------------------------------------------------------- + + +async def test_one_issue_calls_pipeline_and_returns_indexed_count(linear_mocks): + """One valid issue is passed to the pipeline and the indexed count is returned.""" + indexed, skipped, warning = await _run_index(linear_mocks) + + assert indexed == 1 + assert skipped == 0 + assert warning is None + + linear_mocks["batch_mock"].assert_called_once() + call_args = linear_mocks["batch_mock"].call_args + connector_docs = call_args[0][0] + assert len(connector_docs) == 1 + assert connector_docs[0].document_type == DocumentType.LINEAR_CONNECTOR + + +async def test_pipeline_called_with_max_concurrency_3(linear_mocks): + """index_batch_parallel is called with max_concurrency=3.""" + await _run_index(linear_mocks) + + call_kwargs = linear_mocks["batch_mock"].call_args[1] + assert call_kwargs.get("max_concurrency") == 3 + + +async def test_migrate_legacy_docs_called_before_indexing(linear_mocks): + """migrate_legacy_docs is called on the pipeline before index_batch_parallel.""" + await _run_index(linear_mocks) + + linear_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once() + + +# --------------------------------------------------------------------------- +# Slice 3: Issue skipping (missing ID / title) +# --------------------------------------------------------------------------- + + +async def test_issues_with_missing_id_are_skipped(linear_mocks): + """Issues without id are skipped and not passed to the pipeline.""" + issues = [ + _make_issue(issue_id="valid-1", identifier="ENG-1", title="Valid"), + {"id": "", "identifier": "ENG-2", "title": "No ID"}, + ] + linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None) + + indexed, skipped, _ = await _run_index(linear_mocks) + + connector_docs = linear_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert connector_docs[0].unique_id == "valid-1" + assert skipped == 1 + + +async def test_issues_with_missing_title_are_skipped(linear_mocks): + """Issues without title are skipped.""" + issues = [ + _make_issue(issue_id="valid-1", identifier="ENG-1", title="Valid"), + {"id": "id-2", "identifier": "ENG-2", "title": ""}, + ] + linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None) + + indexed, skipped, _ = await _run_index(linear_mocks) + + connector_docs = linear_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert skipped == 1 + + +# --------------------------------------------------------------------------- +# Slice 4: Duplicate content skipping +# --------------------------------------------------------------------------- + + +async def test_duplicate_content_issues_are_skipped(linear_mocks, monkeypatch): + """Issues whose content hash matches an existing document are skipped.""" + issues = [ + _make_issue(issue_id="new-1", identifier="ENG-1", title="New"), + _make_issue(issue_id="dup-1", identifier="ENG-2", title="Dup"), + ] + linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None) + + call_count = 0 + + async def _check_dup(session, content_hash): + nonlocal call_count + call_count += 1 + if call_count == 2: + dup = MagicMock() + dup.id = 99 + dup.document_type = "OTHER" + return dup + return None + + monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup) + + indexed, skipped, _ = await _run_index(linear_mocks) + + connector_docs = linear_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert skipped == 1 + + +# --------------------------------------------------------------------------- +# Slice 5: Heartbeat callback forwarding +# --------------------------------------------------------------------------- + + +async def test_heartbeat_callback_forwarded_to_pipeline(linear_mocks): + """on_heartbeat_callback is passed through to index_batch_parallel.""" + heartbeat_cb = AsyncMock() + + await _run_index(linear_mocks, on_heartbeat_callback=heartbeat_cb) + + call_kwargs = linear_mocks["batch_mock"].call_args[1] + assert call_kwargs.get("on_heartbeat") is heartbeat_cb + + +# --------------------------------------------------------------------------- +# Slice 6: Empty issues early return +# --------------------------------------------------------------------------- + + +async def test_empty_issues_returns_zero_tuple(linear_mocks): + """When no issues are found, returns (0, 0, None) and pipeline is not called.""" + linear_mocks["linear_client"].get_issues_by_date_range.return_value = ([], None) + + indexed, skipped, warning = await _run_index(linear_mocks) + + assert indexed == 0 + assert skipped == 0 + assert warning is None + + linear_mocks["batch_mock"].assert_not_called() + + +async def test_failed_docs_warning_in_result(linear_mocks): + """When documents fail indexing, the warning includes the count.""" + linear_mocks["batch_mock"].return_value = ([], 0, 2) + + _, _, warning = await _run_index(linear_mocks) + + assert warning is not None + assert "2 failed" in warning diff --git a/surfsense_backend/tests/unit/connector_indexers/test_notion_parallel.py b/surfsense_backend/tests/unit/connector_indexers/test_notion_parallel.py new file mode 100644 index 000000000..99fb8bad7 --- /dev/null +++ b/surfsense_backend/tests/unit/connector_indexers/test_notion_parallel.py @@ -0,0 +1,345 @@ +"""Tests for Notion indexer migrated to the unified parallel pipeline.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +import app.tasks.connector_indexers.notion_indexer as _mod +from app.db import DocumentType +from app.tasks.connector_indexers.notion_indexer import ( + _build_connector_doc, + index_notion_pages, +) + +pytestmark = pytest.mark.unit + +_USER_ID = "00000000-0000-0000-0000-000000000001" +_CONNECTOR_ID = 42 +_SEARCH_SPACE_ID = 1 + + +def _make_page(page_id: str = "page-1", title: str = "Test Page", content=None): + if content is None: + content = [{"type": "paragraph", "content": "Hello world", "children": []}] + return {"page_id": page_id, "title": title, "content": content} + + +# --------------------------------------------------------------------------- +# Slice 1: _build_connector_doc tracer bullet +# --------------------------------------------------------------------------- + + +async def test_build_connector_doc_produces_correct_fields(): + """Tracer bullet: a single Notion page produces a ConnectorDocument with correct fields.""" + + page = _make_page(page_id="abc-123", title="My Notion Page") + markdown = "# My Notion Page\n\nHello world" + + doc = _build_connector_doc( + page, + markdown, + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=True, + ) + + assert doc.title == "My Notion Page" + assert doc.unique_id == "abc-123" + assert doc.document_type == DocumentType.NOTION_CONNECTOR + assert doc.source_markdown == markdown + assert doc.search_space_id == _SEARCH_SPACE_ID + assert doc.connector_id == _CONNECTOR_ID + assert doc.created_by_id == _USER_ID + assert doc.should_summarize is True + assert doc.metadata["page_title"] == "My Notion Page" + assert doc.metadata["page_id"] == "abc-123" + assert doc.metadata["connector_id"] == _CONNECTOR_ID + assert doc.metadata["document_type"] == "Notion Page" + assert doc.metadata["connector_type"] == "Notion" + assert doc.fallback_summary is not None + assert "My Notion Page" in doc.fallback_summary + assert markdown in doc.fallback_summary + + +async def test_build_connector_doc_summary_disabled(): + """When enable_summary is False, should_summarize is False.""" + doc = _build_connector_doc( + _make_page(), + "# content", + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=False, + ) + + assert doc.should_summarize is False + + +# --------------------------------------------------------------------------- +# Shared fixtures for Slices 2-7 (full index_notion_pages tests) +# --------------------------------------------------------------------------- + + +def _mock_connector(enable_summary: bool = True): + c = MagicMock() + c.config = {"access_token": "tok"} + c.enable_summary = enable_summary + c.last_indexed_at = None + return c + + +def _mock_notion_client(pages=None, skipped_count=0, legacy_token=False): + client = MagicMock() + client.get_all_pages = AsyncMock(return_value=pages if pages is not None else []) + client.get_skipped_content_count = MagicMock(return_value=skipped_count) + client.is_using_legacy_token = MagicMock(return_value=legacy_token) + client.close = AsyncMock() + client.set_retry_callback = MagicMock() + return client + + +@pytest.fixture +def notion_mocks(monkeypatch): + """Wire up all external boundary mocks for index_notion_pages.""" + mock_session = AsyncMock() + mock_session.no_autoflush = MagicMock() + + mock_connector = _mock_connector() + monkeypatch.setattr( + _mod, "get_connector_by_id", AsyncMock(return_value=mock_connector), + ) + + notion_client = _mock_notion_client(pages=[_make_page()]) + monkeypatch.setattr( + _mod, "NotionHistoryConnector", MagicMock(return_value=notion_client), + ) + + monkeypatch.setattr( + _mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None), + ) + + monkeypatch.setattr( + _mod, "update_connector_last_indexed", AsyncMock(), + ) + + monkeypatch.setattr( + _mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")), + ) + + monkeypatch.setattr( + _mod, "process_blocks", MagicMock(return_value="Converted markdown content"), + ) + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_progress = AsyncMock() + mock_task_logger.log_task_success = AsyncMock() + mock_task_logger.log_task_failure = AsyncMock() + monkeypatch.setattr( + _mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger), + ) + + batch_mock = AsyncMock(return_value=([], 1, 0)) + pipeline_mock = MagicMock() + pipeline_mock.index_batch_parallel = batch_mock + pipeline_mock.migrate_legacy_docs = AsyncMock() + monkeypatch.setattr( + _mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock), + ) + + return { + "session": mock_session, + "connector": mock_connector, + "notion_client": notion_client, + "task_logger": mock_task_logger, + "pipeline_mock": pipeline_mock, + "batch_mock": batch_mock, + } + + +async def _run_index(mocks, **overrides): + return await index_notion_pages( + session=mocks["session"], + connector_id=overrides.get("connector_id", _CONNECTOR_ID), + search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID), + user_id=overrides.get("user_id", _USER_ID), + start_date=overrides.get("start_date", "2025-01-01"), + end_date=overrides.get("end_date", "2025-12-31"), + update_last_indexed=overrides.get("update_last_indexed", True), + on_retry_callback=overrides.get("on_retry_callback"), + on_heartbeat_callback=overrides.get("on_heartbeat_callback"), + ) + + +# --------------------------------------------------------------------------- +# Slice 2: Full pipeline wiring +# --------------------------------------------------------------------------- + + +async def test_one_page_calls_pipeline_and_returns_indexed_count(notion_mocks): + """One valid page is passed to the pipeline and the indexed count is returned.""" + indexed, skipped, warning = await _run_index(notion_mocks) + + assert indexed == 1 + assert skipped == 0 + assert warning is None + + notion_mocks["batch_mock"].assert_called_once() + call_args = notion_mocks["batch_mock"].call_args + connector_docs = call_args[0][0] + assert len(connector_docs) == 1 + assert connector_docs[0].document_type == DocumentType.NOTION_CONNECTOR + + +async def test_pipeline_called_with_max_concurrency_3(notion_mocks): + """index_batch_parallel is called with max_concurrency=3.""" + await _run_index(notion_mocks) + + call_kwargs = notion_mocks["batch_mock"].call_args[1] + assert call_kwargs.get("max_concurrency") == 3 + + +async def test_migrate_legacy_docs_called_before_indexing(notion_mocks): + """migrate_legacy_docs is called on the pipeline before index_batch_parallel.""" + await _run_index(notion_mocks) + + notion_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once() + + +# --------------------------------------------------------------------------- +# Slice 3: Page skipping (no content / missing ID) +# --------------------------------------------------------------------------- + + +async def test_pages_with_missing_id_are_skipped(notion_mocks, monkeypatch): + """Pages without page_id are skipped and not passed to the pipeline.""" + pages = [ + _make_page(page_id="valid-1"), + {"title": "No ID page", "content": [{"type": "paragraph", "content": "text", "children": []}]}, + ] + notion_mocks["notion_client"].get_all_pages.return_value = pages + + _, skipped, _ = await _run_index(notion_mocks) + + connector_docs = notion_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert connector_docs[0].unique_id == "valid-1" + assert skipped == 1 + + +async def test_pages_with_no_content_are_skipped(notion_mocks, monkeypatch): + """Pages with empty content are skipped.""" + pages = [ + _make_page(page_id="valid-1"), + _make_page(page_id="empty-1", content=[]), + ] + notion_mocks["notion_client"].get_all_pages.return_value = pages + + _, skipped, _ = await _run_index(notion_mocks) + + connector_docs = notion_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert skipped == 1 + + +# --------------------------------------------------------------------------- +# Slice 4: Duplicate content skipping +# --------------------------------------------------------------------------- + + +async def test_duplicate_content_pages_are_skipped(notion_mocks, monkeypatch): + """Pages whose content hash matches an existing document are skipped.""" + pages = [ + _make_page(page_id="new-1"), + _make_page(page_id="dup-1"), + ] + notion_mocks["notion_client"].get_all_pages.return_value = pages + + call_count = 0 + + async def _check_dup(session, content_hash): + nonlocal call_count + call_count += 1 + if call_count == 2: + dup = MagicMock() + dup.id = 99 + dup.document_type = "OTHER" + return dup + return None + + monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup) + + _, skipped, _ = await _run_index(notion_mocks) + + connector_docs = notion_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert skipped == 1 + + +# --------------------------------------------------------------------------- +# Slice 5: Heartbeat callback forwarding +# --------------------------------------------------------------------------- + + +async def test_heartbeat_callback_forwarded_to_pipeline(notion_mocks): + """on_heartbeat_callback is passed through to index_batch_parallel.""" + heartbeat_cb = AsyncMock() + + await _run_index(notion_mocks, on_heartbeat_callback=heartbeat_cb) + + call_kwargs = notion_mocks["batch_mock"].call_args[1] + assert call_kwargs.get("on_heartbeat") is heartbeat_cb + + +# --------------------------------------------------------------------------- +# Slice 6: Notion-specific warning messages +# --------------------------------------------------------------------------- + + +async def test_skipped_ai_content_warning_in_result(notion_mocks): + """When Notion AI content was skipped, the warning message includes it.""" + notion_mocks["notion_client"].get_skipped_content_count.return_value = 3 + + _, _, warning = await _run_index(notion_mocks) + + assert warning is not None + assert "API limitation" in warning + + +async def test_legacy_token_warning_in_result(notion_mocks): + """When using legacy token, the warning message includes a notice.""" + notion_mocks["notion_client"].is_using_legacy_token.return_value = True + + _, _, warning = await _run_index(notion_mocks) + + assert warning is not None + assert "legacy token" in warning.lower() + + +async def test_failed_docs_warning_in_result(notion_mocks): + """When documents fail indexing, the warning includes the count.""" + notion_mocks["batch_mock"].return_value = ([], 0, 2) + + _, _, warning = await _run_index(notion_mocks) + + assert warning is not None + assert "2 failed" in warning + + +# --------------------------------------------------------------------------- +# Slice 7: Empty pages early return +# --------------------------------------------------------------------------- + + +async def test_empty_pages_returns_zero_tuple(notion_mocks): + """When no pages are found, returns (0, 0, None) and updates last_indexed.""" + notion_mocks["notion_client"].get_all_pages.return_value = [] + + indexed, skipped, warning = await _run_index(notion_mocks) + + assert indexed == 0 + assert skipped == 0 + assert warning is None + + notion_mocks["batch_mock"].assert_not_called() From 683a4c17dd485fb445868e4d186e89b28a7cb008 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 11:31:00 +0530 Subject: [PATCH 13/31] feat: implement thread-safe embedding access in document converters - Added a reentrant lock to ensure thread-safe access to the tokenizer and embedding model, preventing runtime errors during concurrent operations. - Updated the `truncate_for_embedding` and `embed_text` functions to utilize the lock, ensuring safe execution in multi-threaded environments. - Enhanced the `embed_texts` function to maintain thread safety while processing multiple texts for embedding. --- .../app/utils/document_converters.py | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/surfsense_backend/app/utils/document_converters.py b/surfsense_backend/app/utils/document_converters.py index 0cacdd1d3..ed52c1b7b 100644 --- a/surfsense_backend/app/utils/document_converters.py +++ b/surfsense_backend/app/utils/document_converters.py @@ -1,5 +1,6 @@ import hashlib import logging +import threading import warnings import numpy as np @@ -11,6 +12,12 @@ from app.prompts import SUMMARY_PROMPT_TEMPLATE logger = logging.getLogger(__name__) +# HuggingFace fast tokenizers (Rust-backed) are not thread-safe — concurrent +# access from multiple threads causes "RuntimeError: Already borrowed". +# This reentrant lock serialises tokenizer + embedding model access so that +# asyncio.to_thread calls from index_batch_parallel don't collide. +_embedding_lock = threading.RLock() + def _get_embedding_max_tokens() -> int: """Get the max token limit for the configured embedding model. @@ -36,23 +43,25 @@ def truncate_for_embedding(text: str) -> str: if len(text) // 3 <= max_tokens: return text - tokenizer = config.embedding_model_instance.get_tokenizer() - tokens = tokenizer.encode(text) - if len(tokens) <= max_tokens: - return text + with _embedding_lock: + tokenizer = config.embedding_model_instance.get_tokenizer() + tokens = tokenizer.encode(text) + if len(tokens) <= max_tokens: + return text - warnings.warn( - f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.", - stacklevel=2, - ) - return tokenizer.decode(tokens[:max_tokens]) + warnings.warn( + f"Truncating text from {len(tokens)} to {max_tokens} tokens for embedding.", + stacklevel=2, + ) + return tokenizer.decode(tokens[:max_tokens]) def embed_text(text: str) -> np.ndarray: """Truncate text to fit and embed it. Drop-in replacement for ``config.embedding_model_instance.embed(text)`` that never exceeds the model's context window.""" - return config.embedding_model_instance.embed(truncate_for_embedding(text)) + with _embedding_lock: + return config.embedding_model_instance.embed(truncate_for_embedding(text)) def embed_texts(texts: list[str]) -> list[np.ndarray]: @@ -66,10 +75,11 @@ def embed_texts(texts: list[str]) -> list[np.ndarray]: """ if not texts: return [] - truncated = [truncate_for_embedding(t) for t in texts] - if config.is_local_embedding_model: - return [config.embedding_model_instance.embed(t) for t in truncated] - return config.embedding_model_instance.embed_batch(truncated) + with _embedding_lock: + truncated = [truncate_for_embedding(t) for t in texts] + if config.is_local_embedding_model: + return [config.embedding_model_instance.embed(t) for t in truncated] + return config.embedding_model_instance.embed_batch(truncated) def get_model_context_window(model_name: str) -> int: From ec79142d52dbd75615c4ff331cbaff8eb8998654 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 12:04:01 +0530 Subject: [PATCH 14/31] refactor: replace document type counts atom with real-time hook - Removed the `documentTypeCountsAtom` and its associated logic from the document query atoms. - Introduced `useZeroDocumentTypeCounts` hook to provide real-time document type counts, enhancing responsiveness as documents are indexed. - Updated components to utilize the new hook for fetching document type counts, ensuring instant updates in the UI. --- .../atoms/documents/document-query.atoms.ts | 38 ------------------- .../assistant-ui/connector-popup.tsx | 8 ++-- surfsense_web/components/onboarding-tour.tsx | 6 +-- .../hooks/use-zero-document-type-counts.ts | 31 +++++++++++++++ 4 files changed, 38 insertions(+), 45 deletions(-) delete mode 100644 surfsense_web/atoms/documents/document-query.atoms.ts create mode 100644 surfsense_web/hooks/use-zero-document-type-counts.ts diff --git a/surfsense_web/atoms/documents/document-query.atoms.ts b/surfsense_web/atoms/documents/document-query.atoms.ts deleted file mode 100644 index 656706a62..000000000 --- a/surfsense_web/atoms/documents/document-query.atoms.ts +++ /dev/null @@ -1,38 +0,0 @@ -import { atomWithQuery } from "jotai-tanstack-query"; -import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; -import type { SearchDocumentsRequest } from "@/contracts/types/document.types"; -import { documentsApiService } from "@/lib/apis/documents-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { globalDocumentsQueryParamsAtom } from "./ui.atoms"; - -export const documentsAtom = atomWithQuery((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - const queryParams = get(globalDocumentsQueryParamsAtom); - - return { - queryKey: cacheKeys.documents.globalQueryParams(queryParams), - enabled: !!searchSpaceId, - queryFn: async () => { - return documentsApiService.getDocuments({ - queryParams: queryParams, - }); - }, - }; -}); - -export const documentTypeCountsAtom = atomWithQuery((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined), - enabled: !!searchSpaceId, - staleTime: 10 * 60 * 1000, // 10 minutes - queryFn: async () => { - return documentsApiService.getDocumentTypeCounts({ - queryParams: { - search_space_id: searchSpaceId ?? undefined, - }, - }); - }, - }; -}); diff --git a/surfsense_web/components/assistant-ui/connector-popup.tsx b/surfsense_web/components/assistant-ui/connector-popup.tsx index 3187d3c33..ae50ed7a4 100644 --- a/surfsense_web/components/assistant-ui/connector-popup.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup.tsx @@ -4,7 +4,7 @@ import { useAtomValue, useSetAtom } from "jotai"; import { AlertTriangle, Cable, Settings } from "lucide-react"; import { forwardRef, useEffect, useImperativeHandle, useMemo, useState } from "react"; import { createPortal } from "react-dom"; -import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms"; +import { useZeroDocumentTypeCounts } from "@/hooks/use-zero-document-type-counts"; import { statusInboxItemsAtom } from "@/atoms/inbox/status-inbox.atom"; import { globalNewLLMConfigsAtom, @@ -72,9 +72,9 @@ export const ConnectorIndicator = forwardRef | undefined { + const numericId = searchSpaceId != null ? Number(searchSpaceId) : null; + + const [zeroDocuments] = useQuery( + queries.documents.bySpace({ searchSpaceId: numericId ?? -1 }) + ); + + return useMemo(() => { + if (!zeroDocuments || numericId == null) return undefined; + + const counts: Record = {}; + for (const doc of zeroDocuments) { + if (doc.id != null && doc.title != null && doc.title !== "") { + counts[doc.documentType] = (counts[doc.documentType] || 0) + 1; + } + } + return counts; + }, [zeroDocuments, numericId]); +} From 7a2467c1ed32c0c8f1ff420d4082b878f6165a9d Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 12:10:26 +0530 Subject: [PATCH 15/31] refactor: remove type counts invalidation from document mutation atoms --- surfsense_web/atoms/documents/document-mutation.atoms.ts | 9 --------- surfsense_web/lib/query-client/cache-keys.ts | 1 - 2 files changed, 10 deletions(-) diff --git a/surfsense_web/atoms/documents/document-mutation.atoms.ts b/surfsense_web/atoms/documents/document-mutation.atoms.ts index 736db896c..608862419 100644 --- a/surfsense_web/atoms/documents/document-mutation.atoms.ts +++ b/surfsense_web/atoms/documents/document-mutation.atoms.ts @@ -29,9 +29,6 @@ export const createDocumentMutationAtom = atomWithMutation((get) => { queryClient.invalidateQueries({ queryKey: cacheKeys.documents.globalQueryParams(documentsQueryParams), }); - queryClient.invalidateQueries({ - queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined), - }); }, }; }); @@ -75,9 +72,6 @@ export const updateDocumentMutationAtom = atomWithMutation((get) => { queryClient.invalidateQueries({ queryKey: cacheKeys.documents.document(String(request.id)), }); - queryClient.invalidateQueries({ - queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined), - }); }, }; }); @@ -109,9 +103,6 @@ export const deleteDocumentMutationAtom = atomWithMutation((get) => { queryClient.invalidateQueries({ queryKey: cacheKeys.documents.document(String(request.id)), }); - queryClient.invalidateQueries({ - queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined), - }); }, }; }); diff --git a/surfsense_web/lib/query-client/cache-keys.ts b/surfsense_web/lib/query-client/cache-keys.ts index 3448b3fe8..883c40a77 100644 --- a/surfsense_web/lib/query-client/cache-keys.ts +++ b/surfsense_web/lib/query-client/cache-keys.ts @@ -17,7 +17,6 @@ export const cacheKeys = { withQueryParams: (queries: GetDocumentsRequest["queryParams"]) => ["documents-with-queries", ...(queries ? Object.values(queries) : [])] as const, document: (documentId: string) => ["document", documentId] as const, - typeCounts: (searchSpaceId?: string) => ["documents", "type-counts", searchSpaceId] as const, byChunk: (chunkId: string) => ["documents", "by-chunk", chunkId] as const, }, logs: { From 22e36d00fc689f2cea3103e7505faaf8bd1669fa Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 12:20:43 +0530 Subject: [PATCH 16/31] refactor: update bulk delete bar positioning and styling in DocumentsTableShell --- .../(manage)/components/DocumentsTableShell.tsx | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx b/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx index 6bef550d4..b32ad0ddf 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx @@ -473,14 +473,14 @@ export function DocumentsTableShell({ }, [deletableSelectedIds, bulkDeleteDocuments, deleteDocument]); const bulkDeleteBar = hasDeletableSelection ? ( -
+
) : null; @@ -526,7 +526,6 @@ export function DocumentsTableShell({ - {bulkDeleteBar} {loading ? (
@@ -594,7 +593,8 @@ export function DocumentsTableShell({ )} ) : ( -
+
+ {bulkDeleteBar}
{sorted.map((doc) => { @@ -788,9 +788,6 @@ export function DocumentsTableShell({ )} - {/* Mobile bulk delete bar */} -
{bulkDeleteBar}
- {/* Mobile Card View */} {loading ? (
@@ -846,8 +843,9 @@ export function DocumentsTableShell({ ) : (
+ {bulkDeleteBar} {sorted.map((doc) => { const isMentioned = mentionedDocIds?.has(doc.id) ?? false; const statusState = doc.status?.state ?? "ready"; From 0bc1c766ff848d844420ffb8f278f9a05a0008ca Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 16:02:09 +0530 Subject: [PATCH 17/31] feat: migrate Confluence and Jira indexers to unified parallel pipeline - Refactored Confluence and Jira indexers to utilize the shared IndexingPipelineService for improved document processing. - Updated the `_build_connector_doc` function in both indexers to create ConnectorDocument instances with enhanced metadata and fallback summaries. - Modified the `index_confluence_pages` and `index_jira_issues` functions to return a tuple of (indexed_count, skipped_count, warning_or_error_message) for better error handling and reporting. - Added unit tests for both indexers to validate the new parallel processing logic and ensure correct document creation and indexing behavior. --- .../connector_indexers/confluence_indexer.py | 356 +++++------------ .../tasks/connector_indexers/jira_indexer.py | 375 +++++------------- .../test_confluence_parallel.py | 373 +++++++++++++++++ .../connector_indexers/test_jira_parallel.py | 372 +++++++++++++++++ 4 files changed, 942 insertions(+), 534 deletions(-) create mode 100644 surfsense_backend/tests/unit/connector_indexers/test_confluence_parallel.py create mode 100644 surfsense_backend/tests/unit/connector_indexers/test_jira_parallel.py diff --git a/surfsense_backend/app/tasks/connector_indexers/confluence_indexer.py b/surfsense_backend/app/tasks/connector_indexers/confluence_indexer.py index 3b46b6437..8f447b842 100644 --- a/surfsense_backend/app/tasks/connector_indexers/confluence_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/confluence_indexer.py @@ -1,49 +1,74 @@ -""" -Confluence connector indexer. - -Provides real-time document status updates during indexing using a two-phase approach: -- Phase 1: Create all documents with PENDING status (visible in UI immediately) -- Phase 2: Process each document one by one (PENDING → PROCESSING → READY/FAILED) -""" +"""Confluence connector indexer using the unified parallel indexing pipeline.""" import contextlib -import time from collections.abc import Awaitable, Callable -from datetime import datetime from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.confluence_history import ConfluenceHistoryConnector -from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType +from app.db import DocumentType, SearchSourceConnectorType +from app.indexing_pipeline.connector_document import ConnectorDocument +from app.indexing_pipeline.document_hashing import compute_content_hash +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.utils.document_converters import ( - create_document_chunks, - embed_text, - generate_content_hash, - generate_document_summary, - generate_unique_identifier_hash, -) from .base import ( calculate_date_range, - check_document_by_unique_identifier, check_duplicate_document_by_hash, get_connector_by_id, - get_current_timestamp, logger, - safe_set_chunks, update_connector_last_indexed, ) -# Type hint for heartbeat callback HeartbeatCallbackType = Callable[[int], Awaitable[None]] - -# Heartbeat interval in seconds HEARTBEAT_INTERVAL_SECONDS = 30 +def _build_connector_doc( + page: dict, + full_content: str, + *, + connector_id: int, + search_space_id: int, + user_id: str, + enable_summary: bool, +) -> ConnectorDocument: + """Map a raw Confluence page dict to a ConnectorDocument.""" + page_id = page.get("id", "") + page_title = page.get("title", "") + space_id = page.get("spaceId", "") + comment_count = len(page.get("comments", [])) + + metadata = { + "page_id": page_id, + "page_title": page_title, + "space_id": space_id, + "comment_count": comment_count, + "connector_id": connector_id, + "document_type": "Confluence Page", + "connector_type": "Confluence", + } + + fallback_summary = ( + f"Confluence Page: {page_title}\n\nSpace ID: {space_id}\n\n{full_content}" + ) + + return ConnectorDocument( + title=page_title, + source_markdown=full_content, + unique_id=page_id, + document_type=DocumentType.CONFLUENCE_CONNECTOR, + search_space_id=search_space_id, + connector_id=connector_id, + created_by_id=user_id, + should_summarize=enable_summary, + fallback_summary=fallback_summary, + metadata=metadata, + ) + + async def index_confluence_pages( session: AsyncSession, connector_id: int, @@ -53,26 +78,9 @@ async def index_confluence_pages( end_date: str | None = None, update_last_indexed: bool = True, on_heartbeat_callback: HeartbeatCallbackType | None = None, -) -> tuple[int, str | None]: - """ - Index Confluence pages and comments. - - Args: - session: Database session - connector_id: ID of the Confluence connector - search_space_id: ID of the search space to store documents in - user_id: User ID - start_date: Start date for indexing (YYYY-MM-DD format) - end_date: End date for indexing (YYYY-MM-DD format) - update_last_indexed: Whether to update the last_indexed_at timestamp (default: True) - on_heartbeat_callback: Optional callback to update notification during long-running indexing. - - Returns: - Tuple containing (number of documents indexed, error message or None) - """ +) -> tuple[int, int, str | None]: + """Index Confluence pages and comments.""" task_logger = TaskLoggingService(session, search_space_id) - - # Log task start log_entry = await task_logger.log_task_start( task_name="confluence_pages_indexing", source="connector_indexing_task", @@ -86,7 +94,6 @@ async def index_confluence_pages( ) try: - # Get the connector from the database connector = await get_connector_by_id( session, connector_id, SearchSourceConnectorType.CONFLUENCE_CONNECTOR ) @@ -98,9 +105,8 @@ async def index_confluence_pages( "Connector not found", {"error_type": "ConnectorNotFound"}, ) - return 0, f"Connector with ID {connector_id} not found" + return 0, 0, f"Connector with ID {connector_id} not found" - # Initialize Confluence OAuth client await task_logger.log_task_progress( log_entry, f"Initializing Confluence OAuth client for connector {connector_id}", @@ -114,7 +120,6 @@ async def index_confluence_pages( ) ) - # Calculate date range start_date_str, end_date_str = calculate_date_range( connector, start_date, end_date, default_days_back=365 ) @@ -129,19 +134,14 @@ async def index_confluence_pages( }, ) - # Get pages within date range try: pages, error = await confluence_client.get_pages_by_date_range( start_date=start_date_str, end_date=end_date_str, include_comments=True ) if error: - # Don't treat "No pages found" as an error that should stop indexing if "No pages found" in error: logger.info(f"No Confluence pages found: {error}") - logger.info( - "No pages found is not a critical error, continuing with update" - ) if update_last_indexed: await update_connector_last_indexed( session, connector, update_last_indexed @@ -156,11 +156,10 @@ async def index_confluence_pages( f"No Confluence pages found in date range {start_date_str} to {end_date_str}", {"pages_found": 0}, ) - # Close client before returning if confluence_client: with contextlib.suppress(Exception): await confluence_client.close() - return 0, None + return 0, 0, None else: logger.error(f"Failed to get Confluence pages: {error}") await task_logger.log_task_failure( @@ -169,36 +168,35 @@ async def index_confluence_pages( "API Error", {"error_type": "APIError"}, ) - # Close client on error if confluence_client: with contextlib.suppress(Exception): await confluence_client.close() - return 0, f"Failed to get Confluence pages: {error}" + return 0, 0, f"Failed to get Confluence pages: {error}" logger.info(f"Retrieved {len(pages)} pages from Confluence API") except Exception as e: logger.error(f"Error fetching Confluence pages: {e!s}", exc_info=True) - # Close client on error if confluence_client: with contextlib.suppress(Exception): await confluence_client.close() - return 0, f"Error fetching Confluence pages: {e!s}" + return 0, 0, f"Error fetching Confluence pages: {e!s}" + + if not pages: + logger.info("No Confluence pages found for the specified date range") + if update_last_indexed: + await update_connector_last_indexed( + session, connector, update_last_indexed + ) + await session.commit() + if confluence_client: + with contextlib.suppress(Exception): + await confluence_client.close() + return 0, 0, None - # ======================================================================= - # PHASE 1: Analyze all pages, create pending documents - # This makes ALL documents visible in the UI immediately with pending status - # ======================================================================= - documents_indexed = 0 documents_skipped = 0 - documents_failed = 0 duplicate_content_count = 0 - - # Heartbeat tracking - update notification periodically to prevent appearing stuck - last_heartbeat_time = time.time() - - pages_to_process = [] # List of dicts with document and page data - new_documents_created = False + connector_docs: list[ConnectorDocument] = [] for page in pages: try: @@ -213,12 +211,10 @@ async def index_confluence_pages( documents_skipped += 1 continue - # Extract page content page_content = "" if page.get("body") and page["body"].get("storage"): page_content = page["body"]["storage"].get("value", "") - # Add comments to content comments = page.get("comments", []) comments_content = "" if comments: @@ -235,61 +231,25 @@ async def index_confluence_pages( comments_content += f"**Comment by {comment_author}** ({comment_date}):\n{comment_body}\n\n" - # Combine page content with comments full_content = f"# {page_title}\n\n{page_content}{comments_content}" - if not full_content.strip(): + if not page_content.strip() and not comments: logger.warning(f"Skipping page with no content: {page_title}") documents_skipped += 1 continue - # Generate unique identifier hash for this Confluence page - unique_identifier_hash = generate_unique_identifier_hash( - DocumentType.CONFLUENCE_CONNECTOR, page_id, search_space_id + doc = _build_connector_doc( + page, + full_content, + connector_id=connector_id, + search_space_id=search_space_id, + user_id=user_id, + enable_summary=connector.enable_summary, ) - # Generate content hash - content_hash = generate_content_hash(full_content, search_space_id) - - # Check if document with this unique identifier already exists - existing_document = await check_document_by_unique_identifier( - session, unique_identifier_hash - ) - - comment_count = len(comments) - - if existing_document: - # Document exists - check if content has changed - if existing_document.content_hash == content_hash: - # Ensure status is ready (might have been stuck in processing/pending) - if not DocumentStatus.is_state( - existing_document.status, DocumentStatus.READY - ): - existing_document.status = DocumentStatus.ready() - documents_skipped += 1 - continue - - # Queue existing document for update (will be set to processing in Phase 2) - pages_to_process.append( - { - "document": existing_document, - "is_new": False, - "full_content": full_content, - "page_content": page_content, - "content_hash": content_hash, - "page_id": page_id, - "page_title": page_title, - "space_id": space_id, - "comment_count": comment_count, - } - ) - continue - - # Document doesn't exist by unique_identifier_hash - # Check if a document with the same content_hash exists (from another connector) with session.no_autoflush: duplicate_by_content = await check_duplicate_document_by_hash( - session, content_hash + session, compute_content_hash(doc) ) if duplicate_by_content: @@ -302,151 +262,29 @@ async def index_confluence_pages( documents_skipped += 1 continue - # Create new document with PENDING status (visible in UI immediately) - document = Document( - search_space_id=search_space_id, - title=page_title, - document_type=DocumentType.CONFLUENCE_CONNECTOR, - document_metadata={ - "page_id": page_id, - "page_title": page_title, - "space_id": space_id, - "comment_count": comment_count, - "connector_id": connector_id, - }, - content="Pending...", # Placeholder until processed - content_hash=unique_identifier_hash, # Temporary unique value - updated when ready - unique_identifier_hash=unique_identifier_hash, - embedding=None, - chunks=[], # Empty at creation - safe for async - status=DocumentStatus.pending(), # Pending until processing starts - updated_at=get_current_timestamp(), - created_by_id=user_id, - connector_id=connector_id, - ) - session.add(document) - new_documents_created = True - - pages_to_process.append( - { - "document": document, - "is_new": True, - "full_content": full_content, - "page_content": page_content, - "content_hash": content_hash, - "page_id": page_id, - "page_title": page_title, - "space_id": space_id, - "comment_count": comment_count, - } - ) + connector_docs.append(doc) except Exception as e: - logger.error(f"Error in Phase 1 for page: {e!s}", exc_info=True) - documents_failed += 1 + logger.error(f"Error building ConnectorDocument for page: {e!s}", exc_info=True) + documents_skipped += 1 continue - # Commit all pending documents - they all appear in UI now - if new_documents_created: - logger.info( - f"Phase 1: Committing {len([p for p in pages_to_process if p['is_new']])} pending documents" - ) - await session.commit() + pipeline = IndexingPipelineService(session) + await pipeline.migrate_legacy_docs(connector_docs) - # ======================================================================= - # PHASE 2: Process each document one by one - # Each document transitions: pending → processing → ready/failed - # ======================================================================= - logger.info(f"Phase 2: Processing {len(pages_to_process)} documents") + async def _get_llm(s: AsyncSession): + return await get_user_long_context_llm(s, user_id, search_space_id) - for item in pages_to_process: - # Send heartbeat periodically - if on_heartbeat_callback: - current_time = time.time() - if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS: - await on_heartbeat_callback(documents_indexed) - last_heartbeat_time = current_time + _, documents_indexed, documents_failed = await pipeline.index_batch_parallel( + connector_docs, + _get_llm, + max_concurrency=3, + on_heartbeat=on_heartbeat_callback, + heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS, + ) - document = item["document"] - try: - # Set to PROCESSING and commit - shows "processing" in UI for THIS document only - document.status = DocumentStatus.processing() - await session.commit() - - # Heavy processing (LLM, embeddings, chunks) - user_llm = await get_user_long_context_llm( - session, user_id, search_space_id - ) - - if user_llm and connector.enable_summary: - document_metadata = { - "page_title": item["page_title"], - "page_id": item["page_id"], - "space_id": item["space_id"], - "comment_count": item["comment_count"], - "document_type": "Confluence Page", - "connector_type": "Confluence", - } - ( - summary_content, - summary_embedding, - ) = await generate_document_summary( - item["full_content"], user_llm, document_metadata - ) - else: - summary_content = f"Confluence Page: {item['page_title']}\n\nSpace ID: {item['space_id']}\n\n{item['full_content']}" - summary_embedding = embed_text(summary_content) - - # Process chunks - using the full page content with comments - chunks = await create_document_chunks(item["full_content"]) - - # Update document to READY with actual content - document.title = item["page_title"] - document.content = summary_content - document.content_hash = item["content_hash"] - document.embedding = summary_embedding - document.document_metadata = { - "page_id": item["page_id"], - "page_title": item["page_title"], - "space_id": item["space_id"], - "comment_count": item["comment_count"], - "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "connector_id": connector_id, - } - await safe_set_chunks(session, document, chunks) - document.updated_at = get_current_timestamp() - document.status = DocumentStatus.ready() - - documents_indexed += 1 - - # Batch commit every 10 documents (for ready status updates) - if documents_indexed % 10 == 0: - logger.info( - f"Committing batch: {documents_indexed} Confluence pages processed so far" - ) - await session.commit() - - except Exception as e: - logger.error( - f"Error processing page {item.get('page_title', 'Unknown')}: {e!s}", - exc_info=True, - ) - # Mark document as failed with reason (visible in UI) - try: - document.status = DocumentStatus.failed(str(e)) - document.updated_at = get_current_timestamp() - except Exception as status_error: - logger.error( - f"Failed to update document status to failed: {status_error}" - ) - documents_failed += 1 - continue # Skip this page and continue with others - - # CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs - # This ensures the UI shows "Last indexed" instead of "Never indexed" await update_connector_last_indexed(session, connector, update_last_indexed) - # Final commit to ensure all documents are persisted (safety net) logger.info( f"Final commit: Total {documents_indexed} Confluence pages processed" ) @@ -456,7 +294,6 @@ async def index_confluence_pages( "Successfully committed all Confluence document changes to database" ) except Exception as e: - # Handle any remaining integrity errors gracefully (race conditions, etc.) if ( "duplicate key value violates unique constraint" in str(e).lower() or "uniqueviolationerror" in str(e).lower() @@ -467,11 +304,9 @@ async def index_confluence_pages( f"Rolling back and continuing. Error: {e!s}" ) await session.rollback() - # Don't fail the entire task - some documents may have been successfully indexed else: raise - # Build warning message if there were issues warning_parts = [] if duplicate_content_count > 0: warning_parts.append(f"{duplicate_content_count} duplicate") @@ -479,7 +314,6 @@ async def index_confluence_pages( warning_parts.append(f"{documents_failed} failed") warning_message = ", ".join(warning_parts) if warning_parts else None - # Log success await task_logger.log_task_success( log_entry, f"Successfully completed Confluence indexing for connector {connector_id}", @@ -490,22 +324,19 @@ async def index_confluence_pages( "duplicate_content_count": duplicate_content_count, }, ) - logger.info( f"Confluence indexing completed: {documents_indexed} ready, " f"{documents_skipped} skipped, {documents_failed} failed " f"({duplicate_content_count} duplicate content)" ) - # Close the client connection if confluence_client: await confluence_client.close() - return documents_indexed, warning_message + return documents_indexed, documents_skipped, warning_message except SQLAlchemyError as db_error: await session.rollback() - # Close client if it exists if confluence_client: with contextlib.suppress(Exception): await confluence_client.close() @@ -516,10 +347,9 @@ async def index_confluence_pages( {"error_type": "SQLAlchemyError"}, ) logger.error(f"Database error: {db_error!s}", exc_info=True) - return 0, f"Database error: {db_error!s}" + return 0, 0, f"Database error: {db_error!s}" except Exception as e: await session.rollback() - # Close client if it exists if confluence_client: with contextlib.suppress(Exception): await confluence_client.close() @@ -530,4 +360,4 @@ async def index_confluence_pages( {"error_type": type(e).__name__}, ) logger.error(f"Failed to index Confluence pages: {e!s}", exc_info=True) - return 0, f"Failed to index Confluence pages: {e!s}" + return 0, 0, f"Failed to index Confluence pages: {e!s}" diff --git a/surfsense_backend/app/tasks/connector_indexers/jira_indexer.py b/surfsense_backend/app/tasks/connector_indexers/jira_indexer.py index 25491a8f6..6e370fc35 100644 --- a/surfsense_backend/app/tasks/connector_indexers/jira_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/jira_indexer.py @@ -1,49 +1,80 @@ -""" -Jira connector indexer. - -Provides real-time document status updates during indexing using a two-phase approach: -- Phase 1: Create all documents with PENDING status (visible in UI immediately) -- Phase 2: Process each document one by one (PENDING → PROCESSING → READY/FAILED) -""" +"""Jira connector indexer using the unified parallel indexing pipeline.""" import contextlib -import time from collections.abc import Awaitable, Callable -from datetime import datetime from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.jira_history import JiraHistoryConnector -from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType +from app.db import DocumentType, SearchSourceConnectorType +from app.indexing_pipeline.connector_document import ConnectorDocument +from app.indexing_pipeline.document_hashing import compute_content_hash +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.utils.document_converters import ( - create_document_chunks, - embed_text, - generate_content_hash, - generate_document_summary, - generate_unique_identifier_hash, -) from .base import ( calculate_date_range, - check_document_by_unique_identifier, check_duplicate_document_by_hash, get_connector_by_id, - get_current_timestamp, logger, - safe_set_chunks, update_connector_last_indexed, ) -# Type hint for heartbeat callback HeartbeatCallbackType = Callable[[int], Awaitable[None]] - -# Heartbeat interval in seconds - update notification every 30 seconds HEARTBEAT_INTERVAL_SECONDS = 30 +def _build_connector_doc( + issue: dict, + formatted_issue: dict, + issue_content: str, + *, + connector_id: int, + search_space_id: int, + user_id: str, + enable_summary: bool, +) -> ConnectorDocument: + """Map a raw Jira issue dict to a ConnectorDocument.""" + issue_id = issue.get("key", "") + issue_identifier = issue.get("key", "") + issue_title = issue.get("id", "") + state = formatted_issue.get("status", "Unknown") + priority = formatted_issue.get("priority", "Unknown") + comment_count = len(formatted_issue.get("comments", [])) + + metadata = { + "issue_id": issue_id, + "issue_identifier": issue_identifier, + "issue_title": issue_title, + "state": state, + "priority": priority, + "comment_count": comment_count, + "connector_id": connector_id, + "document_type": "Jira Issue", + "connector_type": "Jira", + } + + fallback_summary = ( + f"Jira Issue {issue_identifier}: {issue_title}\n\n" + f"Status: {state}\n\n{issue_content}" + ) + + return ConnectorDocument( + title=f"{issue_identifier}: {issue_title}", + source_markdown=issue_content, + unique_id=issue_id, + document_type=DocumentType.JIRA_CONNECTOR, + search_space_id=search_space_id, + connector_id=connector_id, + created_by_id=user_id, + should_summarize=enable_summary, + fallback_summary=fallback_summary, + metadata=metadata, + ) + + async def index_jira_issues( session: AsyncSession, connector_id: int, @@ -53,26 +84,9 @@ async def index_jira_issues( end_date: str | None = None, update_last_indexed: bool = True, on_heartbeat_callback: HeartbeatCallbackType | None = None, -) -> tuple[int, str | None]: - """ - Index Jira issues and comments. - - Args: - session: Database session - connector_id: ID of the Jira connector - search_space_id: ID of the search space to store documents in - user_id: User ID - start_date: Start date for indexing (YYYY-MM-DD format) - end_date: End date for indexing (YYYY-MM-DD format) - update_last_indexed: Whether to update the last_indexed_at timestamp (default: True) - on_heartbeat_callback: Optional callback to update notification during long-running indexing. - - Returns: - Tuple containing (number of documents indexed, error message or None) - """ +) -> tuple[int, int, str | None]: + """Index Jira issues and comments.""" task_logger = TaskLoggingService(session, search_space_id) - - # Log task start log_entry = await task_logger.log_task_start( task_name="jira_issues_indexing", source="connector_indexing_task", @@ -86,7 +100,6 @@ async def index_jira_issues( ) try: - # Get the connector from the database connector = await get_connector_by_id( session, connector_id, SearchSourceConnectorType.JIRA_CONNECTOR ) @@ -98,24 +111,15 @@ async def index_jira_issues( "Connector not found", {"error_type": "ConnectorNotFound"}, ) - return 0, f"Connector with ID {connector_id} not found" + return 0, 0, f"Connector with ID {connector_id} not found" - # Initialize Jira client with internal refresh capability - # Token refresh will happen automatically when needed await task_logger.log_task_progress( log_entry, f"Initializing Jira client for connector {connector_id}", {"stage": "client_initialization"}, ) - - logger.info(f"Initializing Jira client for connector {connector_id}") - - # Create connector with session and connector_id for internal refresh - # Token refresh will happen automatically when needed jira_client = JiraHistoryConnector(session=session, connector_id=connector_id) - # Calculate date range - # Handle "undefined" strings from frontend if start_date == "undefined" or start_date == "": start_date = None if end_date == "undefined" or end_date == "": @@ -135,19 +139,14 @@ async def index_jira_issues( }, ) - # Get issues within date range try: issues, error = await jira_client.get_issues_by_date_range( start_date=start_date_str, end_date=end_date_str, include_comments=True ) if error: - # Don't treat "No issues found" as an error that should stop indexing if "No issues found" in error: logger.info(f"No Jira issues found: {error}") - logger.info( - "No issues found is not a critical error, continuing with update" - ) if update_last_indexed: await update_connector_last_indexed( session, connector, update_last_indexed @@ -162,7 +161,8 @@ async def index_jira_issues( f"No Jira issues found in date range {start_date_str} to {end_date_str}", {"issues_found": 0}, ) - return 0, None + await jira_client.close() + return 0, 0, None else: logger.error(f"Failed to get Jira issues: {error}") await task_logger.log_task_failure( @@ -171,29 +171,30 @@ async def index_jira_issues( "API Error", {"error_type": "APIError"}, ) - return 0, f"Failed to get Jira issues: {error}" + await jira_client.close() + return 0, 0, f"Failed to get Jira issues: {error}" logger.info(f"Retrieved {len(issues)} issues from Jira API") except Exception as e: logger.error(f"Error fetching Jira issues: {e!s}", exc_info=True) - return 0, f"Error fetching Jira issues: {e!s}" + await jira_client.close() + return 0, 0, f"Error fetching Jira issues: {e!s}" - # ======================================================================= - # PHASE 1: Analyze all issues, create pending documents - # This makes ALL documents visible in the UI immediately with pending status - # ======================================================================= - documents_indexed = 0 + if not issues: + logger.info("No Jira issues found for the specified date range") + if update_last_indexed: + await update_connector_last_indexed( + session, connector, update_last_indexed + ) + await session.commit() + await jira_client.close() + return 0, 0, None + + connector_docs: list[ConnectorDocument] = [] documents_skipped = 0 - documents_failed = 0 duplicate_content_count = 0 - # Heartbeat tracking - update notification periodically to prevent appearing stuck - last_heartbeat_time = time.time() - - issues_to_process = [] # List of dicts with document and issue data - new_documents_created = False - for issue in issues: try: issue_id = issue.get("key") @@ -207,10 +208,7 @@ async def index_jira_issues( documents_skipped += 1 continue - # Format the issue for better readability formatted_issue = jira_client.format_issue(issue) - - # Convert to markdown issue_content = jira_client.format_issue_to_markdown(formatted_issue) if not issue_content: @@ -220,53 +218,19 @@ async def index_jira_issues( documents_skipped += 1 continue - # Generate unique identifier hash for this Jira issue - unique_identifier_hash = generate_unique_identifier_hash( - DocumentType.JIRA_CONNECTOR, issue_id, search_space_id + doc = _build_connector_doc( + issue, + formatted_issue, + issue_content, + connector_id=connector_id, + search_space_id=search_space_id, + user_id=user_id, + enable_summary=connector.enable_summary, ) - # Generate content hash - content_hash = generate_content_hash(issue_content, search_space_id) - - # Check if document with this unique identifier already exists - existing_document = await check_document_by_unique_identifier( - session, unique_identifier_hash - ) - - comment_count = len(formatted_issue.get("comments", [])) - - if existing_document: - # Document exists - check if content has changed - if existing_document.content_hash == content_hash: - # Ensure status is ready (might have been stuck in processing/pending) - if not DocumentStatus.is_state( - existing_document.status, DocumentStatus.READY - ): - existing_document.status = DocumentStatus.ready() - documents_skipped += 1 - continue - - # Queue existing document for update (will be set to processing in Phase 2) - issues_to_process.append( - { - "document": existing_document, - "is_new": False, - "issue_content": issue_content, - "content_hash": content_hash, - "issue_id": issue_id, - "issue_identifier": issue_identifier, - "issue_title": issue_title, - "formatted_issue": formatted_issue, - "comment_count": comment_count, - } - ) - continue - - # Document doesn't exist by unique_identifier_hash - # Check if a document with the same content_hash exists (from another connector) with session.no_autoflush: duplicate_by_content = await check_duplicate_document_by_hash( - session, content_hash + session, compute_content_hash(doc) ) if duplicate_by_content: @@ -279,160 +243,37 @@ async def index_jira_issues( documents_skipped += 1 continue - # Create new document with PENDING status (visible in UI immediately) - document = Document( - search_space_id=search_space_id, - title=f"{issue_identifier}: {issue_title}", - document_type=DocumentType.JIRA_CONNECTOR, - document_metadata={ - "issue_id": issue_id, - "issue_identifier": issue_identifier, - "issue_title": issue_title, - "state": formatted_issue.get("status", "Unknown"), - "comment_count": comment_count, - "connector_id": connector_id, - }, - content="Pending...", # Placeholder until processed - content_hash=unique_identifier_hash, # Temporary unique value - updated when ready - unique_identifier_hash=unique_identifier_hash, - embedding=None, - chunks=[], # Empty at creation - safe for async - status=DocumentStatus.pending(), # Pending until processing starts - updated_at=get_current_timestamp(), - created_by_id=user_id, - connector_id=connector_id, - ) - session.add(document) - new_documents_created = True - - issues_to_process.append( - { - "document": document, - "is_new": True, - "issue_content": issue_content, - "content_hash": content_hash, - "issue_id": issue_id, - "issue_identifier": issue_identifier, - "issue_title": issue_title, - "formatted_issue": formatted_issue, - "comment_count": comment_count, - } - ) - - except Exception as e: - logger.error(f"Error in Phase 1 for issue: {e!s}", exc_info=True) - documents_failed += 1 - continue - - # Commit all pending documents - they all appear in UI now - if new_documents_created: - logger.info( - f"Phase 1: Committing {len([i for i in issues_to_process if i['is_new']])} pending documents" - ) - await session.commit() - - # ======================================================================= - # PHASE 2: Process each document one by one - # Each document transitions: pending → processing → ready/failed - # ======================================================================= - logger.info(f"Phase 2: Processing {len(issues_to_process)} documents") - - for item in issues_to_process: - # Send heartbeat periodically - if on_heartbeat_callback: - current_time = time.time() - if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS: - await on_heartbeat_callback(documents_indexed) - last_heartbeat_time = current_time - - document = item["document"] - try: - # Set to PROCESSING and commit - shows "processing" in UI for THIS document only - document.status = DocumentStatus.processing() - await session.commit() - - # Heavy processing (LLM, embeddings, chunks) - user_llm = await get_user_long_context_llm( - session, user_id, search_space_id - ) - - if user_llm and connector.enable_summary: - document_metadata = { - "issue_key": item["issue_identifier"], - "issue_title": item["issue_title"], - "status": item["formatted_issue"].get("status", "Unknown"), - "priority": item["formatted_issue"].get("priority", "Unknown"), - "comment_count": item["comment_count"], - "document_type": "Jira Issue", - "connector_type": "Jira", - } - ( - summary_content, - summary_embedding, - ) = await generate_document_summary( - item["issue_content"], user_llm, document_metadata - ) - else: - summary_content = f"Jira Issue {item['issue_identifier']}: {item['issue_title']}\n\n{item['issue_content']}" - summary_embedding = embed_text(summary_content) - - # Process chunks - using the full issue content with comments - chunks = await create_document_chunks(item["issue_content"]) - - # Update document to READY with actual content - document.title = f"{item['issue_identifier']}: {item['issue_title']}" - document.content = summary_content - document.content_hash = item["content_hash"] - document.embedding = summary_embedding - document.document_metadata = { - "issue_id": item["issue_id"], - "issue_identifier": item["issue_identifier"], - "issue_title": item["issue_title"], - "state": item["formatted_issue"].get("status", "Unknown"), - "comment_count": item["comment_count"], - "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "connector_id": connector_id, - } - await safe_set_chunks(session, document, chunks) - document.updated_at = get_current_timestamp() - document.status = DocumentStatus.ready() - - documents_indexed += 1 - - # Batch commit every 10 documents (for ready status updates) - if documents_indexed % 10 == 0: - logger.info( - f"Committing batch: {documents_indexed} Jira issues processed so far" - ) - await session.commit() + connector_docs.append(doc) except Exception as e: logger.error( - f"Error processing issue {item.get('issue_identifier', 'Unknown')}: {e!s}", + f"Error building ConnectorDocument for issue {issue_identifier}: {e!s}", exc_info=True, ) - # Mark document as failed with reason (visible in UI) - try: - document.status = DocumentStatus.failed(str(e)) - document.updated_at = get_current_timestamp() - except Exception as status_error: - logger.error( - f"Failed to update document status to failed: {status_error}" - ) - documents_failed += 1 - continue # Skip this issue and continue with others + documents_skipped += 1 + continue + + pipeline = IndexingPipelineService(session) + await pipeline.migrate_legacy_docs(connector_docs) + + async def _get_llm(s: AsyncSession): + return await get_user_long_context_llm(s, user_id, search_space_id) + + _, documents_indexed, documents_failed = await pipeline.index_batch_parallel( + connector_docs, + _get_llm, + max_concurrency=3, + on_heartbeat=on_heartbeat_callback, + heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS, + ) - # CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs - # This ensures the UI shows "Last indexed" instead of "Never indexed" await update_connector_last_indexed(session, connector, update_last_indexed) - # Final commit to ensure all documents are persisted (safety net) logger.info(f"Final commit: Total {documents_indexed} Jira issues processed") try: await session.commit() logger.info("Successfully committed all JIRA document changes to database") except Exception as e: - # Handle any remaining integrity errors gracefully (race conditions, etc.) if ( "duplicate key value violates unique constraint" in str(e).lower() or "uniqueviolationerror" in str(e).lower() @@ -447,7 +288,6 @@ async def index_jira_issues( else: raise - # Build warning message if there were issues warning_parts = [] if duplicate_content_count > 0: warning_parts.append(f"{duplicate_content_count} duplicate") @@ -455,7 +295,6 @@ async def index_jira_issues( warning_parts.append(f"{documents_failed} failed") warning_message = ", ".join(warning_parts) if warning_parts else None - # Log success await task_logger.log_task_success( log_entry, f"Successfully completed JIRA indexing for connector {connector_id}", @@ -466,17 +305,13 @@ async def index_jira_issues( "duplicate_content_count": duplicate_content_count, }, ) - logger.info( f"JIRA indexing completed: {documents_indexed} ready, " f"{documents_skipped} skipped, {documents_failed} failed " f"({duplicate_content_count} duplicate content)" ) - - # Clean up the connector await jira_client.close() - - return documents_indexed, warning_message + return documents_indexed, documents_skipped, warning_message except SQLAlchemyError as db_error: await session.rollback() @@ -487,11 +322,10 @@ async def index_jira_issues( {"error_type": "SQLAlchemyError"}, ) logger.error(f"Database error: {db_error!s}", exc_info=True) - # Clean up the connector in case of error if "jira_client" in locals(): with contextlib.suppress(Exception): await jira_client.close() - return 0, f"Database error: {db_error!s}" + return 0, 0, f"Database error: {db_error!s}" except Exception as e: await session.rollback() await task_logger.log_task_failure( @@ -501,8 +335,7 @@ async def index_jira_issues( {"error_type": type(e).__name__}, ) logger.error(f"Failed to index JIRA issues: {e!s}", exc_info=True) - # Clean up the connector in case of error if "jira_client" in locals(): with contextlib.suppress(Exception): await jira_client.close() - return 0, f"Failed to index JIRA issues: {e!s}" + return 0, 0, f"Failed to index JIRA issues: {e!s}" diff --git a/surfsense_backend/tests/unit/connector_indexers/test_confluence_parallel.py b/surfsense_backend/tests/unit/connector_indexers/test_confluence_parallel.py new file mode 100644 index 000000000..d11c88b64 --- /dev/null +++ b/surfsense_backend/tests/unit/connector_indexers/test_confluence_parallel.py @@ -0,0 +1,373 @@ +"""Tests for Confluence indexer migrated to the unified parallel pipeline.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +import app.tasks.connector_indexers.confluence_indexer as _mod +from app.db import DocumentType +from app.tasks.connector_indexers.confluence_indexer import ( + _build_connector_doc, + index_confluence_pages, +) + +pytestmark = pytest.mark.unit + +_USER_ID = "00000000-0000-0000-0000-000000000001" +_CONNECTOR_ID = 42 +_SEARCH_SPACE_ID = 1 + + +def _make_page( + page_id: str = "p1", + title: str = "Home", + space_id: str = "S1", + body: str = "

Hello

", + comments=None, +): + return { + "id": page_id, + "title": title, + "spaceId": space_id, + "body": {"storage": {"value": body}}, + "comments": comments or [], + } + + +def _to_markdown(page: dict) -> str: + page_title = page.get("title", "") + page_content = page.get("body", {}).get("storage", {}).get("value", "") + comments = page.get("comments", []) + comments_content = "" + if comments: + comments_content = "\n\n## Comments\n\n" + for comment in comments: + comment_body = ( + comment.get("body", {}).get("storage", {}).get("value", "") + ) + comment_author = comment.get("version", {}).get("authorId", "Unknown") + comment_date = comment.get("version", {}).get("createdAt", "") + comments_content += ( + f"**Comment by {comment_author}** ({comment_date}):\n" + f"{comment_body}\n\n" + ) + return f"# {page_title}\n\n{page_content}{comments_content}" + + +# --------------------------------------------------------------------------- +# Slice 1: _build_connector_doc tracer bullet +# --------------------------------------------------------------------------- + + +async def test_build_connector_doc_produces_correct_fields(): + page = _make_page( + page_id="abc-123", + title="Engineering Handbook", + space_id="ENG", + comments=[{"id": "c1"}], + ) + markdown = _to_markdown(page) + + doc = _build_connector_doc( + page, + markdown, + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=True, + ) + + assert doc.title == "Engineering Handbook" + assert doc.unique_id == "abc-123" + assert doc.document_type == DocumentType.CONFLUENCE_CONNECTOR + assert doc.source_markdown == markdown + assert doc.search_space_id == _SEARCH_SPACE_ID + assert doc.connector_id == _CONNECTOR_ID + assert doc.created_by_id == _USER_ID + assert doc.should_summarize is True + assert doc.metadata["page_id"] == "abc-123" + assert doc.metadata["page_title"] == "Engineering Handbook" + assert doc.metadata["space_id"] == "ENG" + assert doc.metadata["comment_count"] == 1 + assert doc.metadata["connector_id"] == _CONNECTOR_ID + assert doc.metadata["document_type"] == "Confluence Page" + assert doc.metadata["connector_type"] == "Confluence" + assert doc.fallback_summary is not None + assert "Engineering Handbook" in doc.fallback_summary + assert markdown in doc.fallback_summary + + +async def test_build_connector_doc_summary_disabled(): + doc = _build_connector_doc( + _make_page(), + _to_markdown(_make_page()), + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=False, + ) + assert doc.should_summarize is False + + +# --------------------------------------------------------------------------- +# Shared fixtures for Slices 2-7 +# --------------------------------------------------------------------------- + + +def _mock_connector(enable_summary: bool = True): + c = MagicMock() + c.config = {"access_token": "tok"} + c.enable_summary = enable_summary + c.last_indexed_at = None + return c + + +def _mock_confluence_client(pages=None, error=None): + client = MagicMock() + client.get_pages_by_date_range = AsyncMock( + return_value=(pages if pages is not None else [], error), + ) + client.close = AsyncMock() + return client + + +@pytest.fixture +def confluence_mocks(monkeypatch): + mock_session = AsyncMock() + mock_session.no_autoflush = MagicMock() + + mock_connector = _mock_connector() + monkeypatch.setattr( + _mod, "get_connector_by_id", AsyncMock(return_value=mock_connector), + ) + + confluence_client = _mock_confluence_client(pages=[_make_page()]) + monkeypatch.setattr( + _mod, "ConfluenceHistoryConnector", MagicMock(return_value=confluence_client), + ) + + monkeypatch.setattr( + _mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None), + ) + monkeypatch.setattr( + _mod, "update_connector_last_indexed", AsyncMock(), + ) + monkeypatch.setattr( + _mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")), + ) + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_progress = AsyncMock() + mock_task_logger.log_task_success = AsyncMock() + mock_task_logger.log_task_failure = AsyncMock() + monkeypatch.setattr( + _mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger), + ) + + batch_mock = AsyncMock(return_value=([], 1, 0)) + pipeline_mock = MagicMock() + pipeline_mock.index_batch_parallel = batch_mock + pipeline_mock.migrate_legacy_docs = AsyncMock() + monkeypatch.setattr( + _mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock), + ) + + return { + "session": mock_session, + "connector": mock_connector, + "confluence_client": confluence_client, + "task_logger": mock_task_logger, + "pipeline_mock": pipeline_mock, + "batch_mock": batch_mock, + } + + +async def _run_index(mocks, **overrides): + return await index_confluence_pages( + session=mocks["session"], + connector_id=overrides.get("connector_id", _CONNECTOR_ID), + search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID), + user_id=overrides.get("user_id", _USER_ID), + start_date=overrides.get("start_date", "2025-01-01"), + end_date=overrides.get("end_date", "2025-12-31"), + update_last_indexed=overrides.get("update_last_indexed", True), + on_heartbeat_callback=overrides.get("on_heartbeat_callback"), + ) + + +# --------------------------------------------------------------------------- +# Slice 2: Full pipeline wiring +# --------------------------------------------------------------------------- + + +async def test_one_page_calls_pipeline_and_returns_indexed_count(confluence_mocks): + indexed, skipped, warning = await _run_index(confluence_mocks) + assert indexed == 1 + assert skipped == 0 + assert warning is None + + confluence_mocks["batch_mock"].assert_called_once() + connector_docs = confluence_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert connector_docs[0].document_type == DocumentType.CONFLUENCE_CONNECTOR + + +async def test_pipeline_called_with_max_concurrency_3(confluence_mocks): + await _run_index(confluence_mocks) + call_kwargs = confluence_mocks["batch_mock"].call_args[1] + assert call_kwargs.get("max_concurrency") == 3 + + +async def test_migrate_legacy_docs_called_before_indexing(confluence_mocks): + await _run_index(confluence_mocks) + confluence_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once() + + +# --------------------------------------------------------------------------- +# Slice 3: Page skipping (missing id/title/content) +# --------------------------------------------------------------------------- + + +async def test_pages_with_missing_id_are_skipped(confluence_mocks): + pages = [ + _make_page(page_id="p1", title="Valid"), + _make_page(page_id="", title="Missing id"), + ] + confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = ( + pages, + None, + ) + _, skipped, _ = await _run_index(confluence_mocks) + connector_docs = confluence_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert skipped == 1 + + +async def test_pages_with_missing_title_are_skipped(confluence_mocks): + pages = [ + _make_page(page_id="p1", title="Valid"), + _make_page(page_id="p2", title=""), + ] + confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = ( + pages, + None, + ) + _, skipped, _ = await _run_index(confluence_mocks) + connector_docs = confluence_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert skipped == 1 + + +async def test_pages_with_no_content_are_skipped(confluence_mocks): + pages = [ + _make_page(page_id="p1", title="Valid", body="

ok

"), + _make_page(page_id="p2", title="Empty", body=""), + ] + confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = ( + pages, + None, + ) + _, skipped, _ = await _run_index(confluence_mocks) + connector_docs = confluence_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert skipped == 1 + + +# --------------------------------------------------------------------------- +# Slice 4: Duplicate content skipping +# --------------------------------------------------------------------------- + + +async def test_duplicate_content_pages_are_skipped(confluence_mocks, monkeypatch): + pages = [ + _make_page(page_id="p1", title="One"), + _make_page(page_id="p2", title="Two"), + ] + confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = ( + pages, + None, + ) + + call_count = 0 + + async def _check_dup(session, content_hash): + nonlocal call_count + call_count += 1 + if call_count == 2: + dup = MagicMock() + dup.id = 99 + dup.document_type = "OTHER" + return dup + return None + + monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup) + + _, skipped, _ = await _run_index(confluence_mocks) + connector_docs = confluence_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert skipped == 1 + + +# --------------------------------------------------------------------------- +# Slice 5: Heartbeat callback forwarding +# --------------------------------------------------------------------------- + + +async def test_heartbeat_callback_forwarded_to_pipeline(confluence_mocks): + heartbeat_cb = AsyncMock() + await _run_index(confluence_mocks, on_heartbeat_callback=heartbeat_cb) + call_kwargs = confluence_mocks["batch_mock"].call_args[1] + assert call_kwargs.get("on_heartbeat") is heartbeat_cb + + +# --------------------------------------------------------------------------- +# Slice 6: Empty pages and no-data success tuple +# --------------------------------------------------------------------------- + + +async def test_empty_pages_returns_zero_tuple(confluence_mocks): + confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = ( + [], + None, + ) + indexed, skipped, warning = await _run_index(confluence_mocks) + assert indexed == 0 + assert skipped == 0 + assert warning is None + confluence_mocks["batch_mock"].assert_not_called() + + +async def test_no_pages_error_message_returns_success_tuple(confluence_mocks): + confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = ( + [], + "No pages found in date range", + ) + indexed, skipped, warning = await _run_index(confluence_mocks) + assert indexed == 0 + assert skipped == 0 + assert warning is None + + +async def test_api_error_still_returns_3_tuple(confluence_mocks): + confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = ( + [], + "API exploded", + ) + result = await _run_index(confluence_mocks) + assert len(result) == 3 + assert result[0] == 0 + assert result[1] == 0 + assert "Failed to get Confluence pages" in result[2] + + +# --------------------------------------------------------------------------- +# Slice 7: Failed docs warning +# --------------------------------------------------------------------------- + + +async def test_failed_docs_warning_in_result(confluence_mocks): + confluence_mocks["batch_mock"].return_value = ([], 0, 2) + _, _, warning = await _run_index(confluence_mocks) + assert warning is not None + assert "2 failed" in warning diff --git a/surfsense_backend/tests/unit/connector_indexers/test_jira_parallel.py b/surfsense_backend/tests/unit/connector_indexers/test_jira_parallel.py new file mode 100644 index 000000000..e8699c158 --- /dev/null +++ b/surfsense_backend/tests/unit/connector_indexers/test_jira_parallel.py @@ -0,0 +1,372 @@ +"""Tests for Jira indexer migrated to the unified parallel pipeline.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +import app.tasks.connector_indexers.jira_indexer as _mod +from app.db import DocumentType +from app.tasks.connector_indexers.jira_indexer import ( + _build_connector_doc, + index_jira_issues, +) + +pytestmark = pytest.mark.unit + +_USER_ID = "00000000-0000-0000-0000-000000000001" +_CONNECTOR_ID = 42 +_SEARCH_SPACE_ID = 1 + + +def _make_issue( + issue_key: str = "ENG-1", + issue_id: str = "10001", + title: str = "Fix login", +): + return {"key": issue_key, "id": issue_id, "title": title} + + +def _make_formatted_issue( + issue_key: str = "ENG-1", + issue_id: str = "10001", + title: str = "Fix login", + status: str = "In Progress", + priority: str = "High", + comments=None, +): + return { + "key": issue_key, + "id": issue_id, + "title": title, + "status": status, + "priority": priority, + "comments": comments or [], + } + + +# --------------------------------------------------------------------------- +# Slice 1: _build_connector_doc tracer bullet +# --------------------------------------------------------------------------- + + +async def test_build_connector_doc_produces_correct_fields(): + issue = _make_issue(issue_key="ENG-42", issue_id="4242", title="Fix auth bug") + formatted = _make_formatted_issue( + issue_key="ENG-42", + issue_id="4242", + title="Fix auth bug", + status="Done", + priority="Urgent", + comments=[{"id": "c1"}], + ) + markdown = "# ENG-42: Fix auth bug\n\nBody" + + doc = _build_connector_doc( + issue, + formatted, + markdown, + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=True, + ) + + assert doc.title == "ENG-42: 4242" + assert doc.unique_id == "ENG-42" + assert doc.document_type == DocumentType.JIRA_CONNECTOR + assert doc.source_markdown == markdown + assert doc.search_space_id == _SEARCH_SPACE_ID + assert doc.connector_id == _CONNECTOR_ID + assert doc.created_by_id == _USER_ID + assert doc.should_summarize is True + assert doc.metadata["issue_id"] == "ENG-42" + assert doc.metadata["issue_identifier"] == "ENG-42" + assert doc.metadata["issue_title"] == "4242" + assert doc.metadata["state"] == "Done" + assert doc.metadata["priority"] == "Urgent" + assert doc.metadata["comment_count"] == 1 + assert doc.metadata["connector_id"] == _CONNECTOR_ID + assert doc.metadata["document_type"] == "Jira Issue" + assert doc.metadata["connector_type"] == "Jira" + assert doc.fallback_summary is not None + assert "ENG-42" in doc.fallback_summary + assert markdown in doc.fallback_summary + + +async def test_build_connector_doc_summary_disabled(): + doc = _build_connector_doc( + _make_issue(), + _make_formatted_issue(), + "# content", + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=False, + ) + assert doc.should_summarize is False + + +# --------------------------------------------------------------------------- +# Shared fixtures for Slices 2-7 +# --------------------------------------------------------------------------- + + +def _mock_connector(enable_summary: bool = True): + c = MagicMock() + c.config = {"access_token": "tok"} + c.enable_summary = enable_summary + c.last_indexed_at = None + return c + + +def _mock_jira_client(issues=None, error=None): + client = MagicMock() + client.get_issues_by_date_range = AsyncMock( + return_value=(issues if issues is not None else [], error), + ) + client.format_issue = MagicMock( + side_effect=lambda i: _make_formatted_issue( + issue_key=i.get("key", ""), + issue_id=i.get("id", ""), + title=i.get("title", ""), + ) + ) + client.format_issue_to_markdown = MagicMock( + side_effect=lambda fi: f"# {fi.get('key', '')}: {fi.get('id', '')}\n\nContent" + ) + client.close = AsyncMock() + return client + + +@pytest.fixture +def jira_mocks(monkeypatch): + mock_session = AsyncMock() + mock_session.no_autoflush = MagicMock() + + mock_connector = _mock_connector() + monkeypatch.setattr( + _mod, "get_connector_by_id", AsyncMock(return_value=mock_connector), + ) + + jira_client = _mock_jira_client(issues=[_make_issue()]) + monkeypatch.setattr( + _mod, "JiraHistoryConnector", MagicMock(return_value=jira_client), + ) + + monkeypatch.setattr( + _mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None), + ) + monkeypatch.setattr( + _mod, "update_connector_last_indexed", AsyncMock(), + ) + monkeypatch.setattr( + _mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")), + ) + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_progress = AsyncMock() + mock_task_logger.log_task_success = AsyncMock() + mock_task_logger.log_task_failure = AsyncMock() + monkeypatch.setattr( + _mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger), + ) + + batch_mock = AsyncMock(return_value=([], 1, 0)) + pipeline_mock = MagicMock() + pipeline_mock.index_batch_parallel = batch_mock + pipeline_mock.migrate_legacy_docs = AsyncMock() + monkeypatch.setattr( + _mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock), + ) + + return { + "session": mock_session, + "connector": mock_connector, + "jira_client": jira_client, + "task_logger": mock_task_logger, + "pipeline_mock": pipeline_mock, + "batch_mock": batch_mock, + } + + +async def _run_index(mocks, **overrides): + return await index_jira_issues( + session=mocks["session"], + connector_id=overrides.get("connector_id", _CONNECTOR_ID), + search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID), + user_id=overrides.get("user_id", _USER_ID), + start_date=overrides.get("start_date", "2025-01-01"), + end_date=overrides.get("end_date", "2025-12-31"), + update_last_indexed=overrides.get("update_last_indexed", True), + on_heartbeat_callback=overrides.get("on_heartbeat_callback"), + ) + + +# --------------------------------------------------------------------------- +# Slice 2: Full pipeline wiring +# --------------------------------------------------------------------------- + + +async def test_one_issue_calls_pipeline_and_returns_indexed_count(jira_mocks): + indexed, skipped, warning = await _run_index(jira_mocks) + assert indexed == 1 + assert skipped == 0 + assert warning is None + + jira_mocks["batch_mock"].assert_called_once() + connector_docs = jira_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert connector_docs[0].document_type == DocumentType.JIRA_CONNECTOR + + +async def test_pipeline_called_with_max_concurrency_3(jira_mocks): + await _run_index(jira_mocks) + call_kwargs = jira_mocks["batch_mock"].call_args[1] + assert call_kwargs.get("max_concurrency") == 3 + + +async def test_migrate_legacy_docs_called_before_indexing(jira_mocks): + await _run_index(jira_mocks) + jira_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once() + + +# --------------------------------------------------------------------------- +# Slice 3: Issue skipping (missing key/title/content) +# --------------------------------------------------------------------------- + + +async def test_issues_with_missing_key_are_skipped(jira_mocks): + issues = [ + _make_issue(issue_key="ENG-1", issue_id="10001"), + {"key": "", "id": "10002", "title": "No key"}, + ] + jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None) + + _, skipped, _ = await _run_index(jira_mocks) + connector_docs = jira_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert skipped == 1 + + +async def test_issues_with_missing_title_are_skipped(jira_mocks): + issues = [ + _make_issue(issue_key="ENG-1", issue_id="10001"), + {"key": "ENG-2", "id": "", "title": "Missing id used as title"}, + ] + jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None) + + _, skipped, _ = await _run_index(jira_mocks) + connector_docs = jira_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert skipped == 1 + + +async def test_issues_with_no_content_are_skipped(jira_mocks): + issues = [ + _make_issue(issue_key="ENG-1", issue_id="10001"), + _make_issue(issue_key="ENG-2", issue_id="10002"), + ] + jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None) + + jira_mocks["jira_client"].format_issue_to_markdown.side_effect = [ + "# ENG-1: 10001\n\nContent", + "", + ] + _, skipped, _ = await _run_index(jira_mocks) + connector_docs = jira_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert skipped == 1 + + +# --------------------------------------------------------------------------- +# Slice 4: Duplicate content skipping +# --------------------------------------------------------------------------- + + +async def test_duplicate_content_issues_are_skipped(jira_mocks, monkeypatch): + issues = [ + _make_issue(issue_key="ENG-1", issue_id="10001"), + _make_issue(issue_key="ENG-2", issue_id="10002"), + ] + jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None) + + call_count = 0 + + async def _check_dup(session, content_hash): + nonlocal call_count + call_count += 1 + if call_count == 2: + dup = MagicMock() + dup.id = 99 + dup.document_type = "OTHER" + return dup + return None + + monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup) + + _, skipped, _ = await _run_index(jira_mocks) + connector_docs = jira_mocks["batch_mock"].call_args[0][0] + assert len(connector_docs) == 1 + assert skipped == 1 + + +# --------------------------------------------------------------------------- +# Slice 5: Heartbeat callback forwarding +# --------------------------------------------------------------------------- + + +async def test_heartbeat_callback_forwarded_to_pipeline(jira_mocks): + heartbeat_cb = AsyncMock() + await _run_index(jira_mocks, on_heartbeat_callback=heartbeat_cb) + call_kwargs = jira_mocks["batch_mock"].call_args[1] + assert call_kwargs.get("on_heartbeat") is heartbeat_cb + + +# --------------------------------------------------------------------------- +# Slice 6: Empty issues and no-data success tuple +# --------------------------------------------------------------------------- + + +async def test_empty_issues_returns_zero_tuple(jira_mocks): + jira_mocks["jira_client"].get_issues_by_date_range.return_value = ([], None) + indexed, skipped, warning = await _run_index(jira_mocks) + assert indexed == 0 + assert skipped == 0 + assert warning is None + jira_mocks["batch_mock"].assert_not_called() + + +async def test_no_issues_error_message_returns_success_tuple(jira_mocks): + jira_mocks["jira_client"].get_issues_by_date_range.return_value = ( + [], + "No issues found in date range", + ) + indexed, skipped, warning = await _run_index(jira_mocks) + assert indexed == 0 + assert skipped == 0 + assert warning is None + + +async def test_api_error_still_returns_3_tuple(jira_mocks): + jira_mocks["jira_client"].get_issues_by_date_range.return_value = ( + [], + "API exploded", + ) + result = await _run_index(jira_mocks) + assert len(result) == 3 + assert result[0] == 0 + assert result[1] == 0 + assert "Failed to get Jira issues" in result[2] + + +# --------------------------------------------------------------------------- +# Slice 7: Failed docs warning +# --------------------------------------------------------------------------- + + +async def test_failed_docs_warning_in_result(jira_mocks): + jira_mocks["batch_mock"].return_value = ([], 0, 2) + _, _, warning = await _run_index(jira_mocks) + assert warning is not None + assert "2 failed" in warning From d2a4b238d7a675f868b01ab7251d8f4e9b9ddd6f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 19:25:03 +0530 Subject: [PATCH 18/31] feat: enhance Google Drive client with thread-safe download and export methods - Implemented per-thread HTTP transport for concurrent downloads to ensure thread safety. - Refactored `download_file` and `download_file_to_disk` methods to utilize blocking calls on separate threads, improving performance during file operations. - Added logging to track the start and end of download and export processes, providing better visibility into execution time. - Updated unit tests to verify parallel execution of download and export operations, ensuring efficiency in handling multiple requests. --- .../app/connectors/google_drive/client.py | 104 ++++++++++-------- .../google_drive/content_extractor.py | 5 +- .../test_google_drive_parallel.py | 83 +++++++++++++- 3 files changed, 142 insertions(+), 50 deletions(-) diff --git a/surfsense_backend/app/connectors/google_drive/client.py b/surfsense_backend/app/connectors/google_drive/client.py index 4e4240e91..fdbacfd69 100644 --- a/surfsense_backend/app/connectors/google_drive/client.py +++ b/surfsense_backend/app/connectors/google_drive/client.py @@ -140,6 +140,24 @@ class GoogleDriveClient: except Exception as e: return None, f"Error getting file metadata: {e!s}" + @staticmethod + def _sync_download_file(service, file_id: str) -> tuple[bytes | None, str | None]: + """Blocking download — runs on a worker thread via ``to_thread``.""" + try: + from googleapiclient.http import MediaIoBaseDownload + + request = service.files().get_media(fileId=file_id) + fh = io.BytesIO() + downloader = MediaIoBaseDownload(fh, request) + done = False + while not done: + _, done = downloader.next_chunk() + return fh.getvalue(), None + except HttpError as e: + return None, f"HTTP error downloading file: {e.resp.status}" + except Exception as e: + return None, f"Error downloading file: {e!s}" + async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]: """ Download binary file content. @@ -150,27 +168,28 @@ class GoogleDriveClient: Returns: Tuple of (file content bytes, error message) """ + service = await self.get_service() + return await asyncio.to_thread(self._sync_download_file, service, file_id) + + @staticmethod + def _sync_download_file_to_disk( + service, file_id: str, dest_path: str, chunksize: int, + ) -> str | None: + """Blocking download-to-disk — runs on a worker thread via ``to_thread``.""" try: - service = await self.get_service() - request = service.files().get_media(fileId=file_id) - - import io - - fh = io.BytesIO() from googleapiclient.http import MediaIoBaseDownload - downloader = MediaIoBaseDownload(fh, request) - - done = False - while not done: - _, done = downloader.next_chunk() - - return fh.getvalue(), None - + request = service.files().get_media(fileId=file_id) + with open(dest_path, "wb") as fh: + downloader = MediaIoBaseDownload(fh, request, chunksize=chunksize) + done = False + while not done: + _, done = downloader.next_chunk() + return None except HttpError as e: - return None, f"HTTP error downloading file: {e.resp.status}" + return f"HTTP error downloading file: {e.resp.status}" except Exception as e: - return None, f"Error downloading file: {e!s}" + return f"Error downloading file: {e!s}" async def download_file_to_disk( self, file_id: str, dest_path: str, chunksize: int = 5 * 1024 * 1024, @@ -179,23 +198,27 @@ class GoogleDriveClient: Returns error message on failure, None on success. """ + service = await self.get_service() + return await asyncio.to_thread( + self._sync_download_file_to_disk, service, file_id, dest_path, chunksize, + ) + + @staticmethod + def _sync_export_google_file( + service, file_id: str, mime_type: str, + ) -> tuple[bytes | None, str | None]: + """Blocking export — runs on a worker thread via ``to_thread``.""" try: - service = await self.get_service() - request = service.files().get_media(fileId=file_id) - from googleapiclient.http import MediaIoBaseDownload - - with open(dest_path, "wb") as fh: - downloader = MediaIoBaseDownload(fh, request, chunksize=chunksize) - done = False - while not done: - _, done = downloader.next_chunk() - - return None - + content = ( + service.files().export(fileId=file_id, mimeType=mime_type).execute() + ) + if not isinstance(content, bytes): + content = content.encode("utf-8") + return content, None except HttpError as e: - return f"HTTP error downloading file: {e.resp.status}" + return None, f"HTTP error exporting file: {e.resp.status}" except Exception as e: - return f"Error downloading file: {e!s}" + return None, f"Error exporting file: {e!s}" async def export_google_file( self, file_id: str, mime_type: str @@ -210,23 +233,10 @@ class GoogleDriveClient: Returns: Tuple of (exported content as bytes, error message) """ - try: - service = await self.get_service() - content = ( - service.files().export(fileId=file_id, mimeType=mime_type).execute() - ) - - # Content is already bytes from the API - # Keep as bytes to support both text and binary formats (like PDF) - if not isinstance(content, bytes): - content = content.encode("utf-8") - - return content, None - - except HttpError as e: - return None, f"HTTP error exporting file: {e.resp.status}" - except Exception as e: - return None, f"Error exporting file: {e!s}" + service = await self.get_service() + return await asyncio.to_thread( + self._sync_export_google_file, service, file_id, mime_type, + ) async def create_file( self, diff --git a/surfsense_backend/app/connectors/google_drive/content_extractor.py b/surfsense_backend/app/connectors/google_drive/content_extractor.py index 69f64d9ae..29c7e85a7 100644 --- a/surfsense_backend/app/connectors/google_drive/content_extractor.py +++ b/surfsense_backend/app/connectors/google_drive/content_extractor.py @@ -1,5 +1,6 @@ """Content extraction for Google Drive files.""" +import asyncio import logging import os import tempfile @@ -118,7 +119,7 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str: ) if stt_service_type == "local": from app.services.stt_service import stt_service - result = stt_service.transcribe_file(file_path) + result = await asyncio.to_thread(stt_service.transcribe_file, file_path) text = result.get("text", "") else: with open(file_path, "rb") as audio_file: @@ -170,7 +171,7 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str: from docling.document_converter import DocumentConverter converter = DocumentConverter() - result = converter.convert(file_path) + result = await asyncio.to_thread(converter.convert, file_path) return result.document.export_to_markdown() raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}") diff --git a/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py b/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py index 1183efa9f..9737ca3d2 100644 --- a/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py +++ b/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py @@ -1,7 +1,8 @@ """Tests for parallel download + indexing in the Google Drive indexer.""" import asyncio -from unittest.mock import AsyncMock, MagicMock +import time +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -586,3 +587,83 @@ async def test_selected_files_skip_rename_counting(selected_files_mocks): call_files = mock.call_args[1].get("files") if "files" in (mock.call_args[1] or {}) else mock.call_args[0][2] assert len(call_files) == 2 assert {f["id"] for f in call_files} == {"n1", "n2"} + + +# --------------------------------------------------------------------------- +# asyncio.to_thread verification — prove blocking calls run in parallel +# --------------------------------------------------------------------------- + +async def test_client_download_file_runs_in_thread_parallel(): + """Calling download_file concurrently via asyncio.gather should overlap + blocking work on separate threads, proving to_thread is effective. + + Strategy: patch _sync_download_file with a blocking time.sleep(0.2). + Launch 3 concurrent calls. Serial would take >=0.6s; parallel < 0.4s. + """ + from app.connectors.google_drive.client import GoogleDriveClient + + BLOCK_SECONDS = 0.2 + NUM_CALLS = 3 + + def _blocking_download(service, file_id): + time.sleep(BLOCK_SECONDS) + return b"fake-content", None + + client = GoogleDriveClient.__new__(GoogleDriveClient) + client.service = MagicMock() + client._service_lock = asyncio.Lock() + + with patch.object( + GoogleDriveClient, "_sync_download_file", staticmethod(_blocking_download), + ): + start = time.monotonic() + results = await asyncio.gather( + *(client.download_file(f"file-{i}") for i in range(NUM_CALLS)) + ) + elapsed = time.monotonic() - start + + for content, error in results: + assert content == b"fake-content" + assert error is None + + serial_minimum = BLOCK_SECONDS * NUM_CALLS + assert elapsed < serial_minimum, ( + f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — " + f"downloads are not running in parallel" + ) + + +async def test_client_export_google_file_runs_in_thread_parallel(): + """Same strategy for export_google_file — verify to_thread parallelism.""" + from app.connectors.google_drive.client import GoogleDriveClient + + BLOCK_SECONDS = 0.2 + NUM_CALLS = 3 + + def _blocking_export(service, file_id, mime_type): + time.sleep(BLOCK_SECONDS) + return b"exported", None + + client = GoogleDriveClient.__new__(GoogleDriveClient) + client.service = MagicMock() + client._service_lock = asyncio.Lock() + + with patch.object( + GoogleDriveClient, "_sync_export_google_file", staticmethod(_blocking_export), + ): + start = time.monotonic() + results = await asyncio.gather( + *(client.export_google_file(f"file-{i}", "application/pdf") + for i in range(NUM_CALLS)) + ) + elapsed = time.monotonic() - start + + for content, error in results: + assert content == b"exported" + assert error is None + + serial_minimum = BLOCK_SECONDS * NUM_CALLS + assert elapsed < serial_minimum, ( + f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — " + f"exports are not running in parallel" + ) From 00934ff4629f65855ef6f81db0f8086a99bde94d Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 19:25:45 +0530 Subject: [PATCH 19/31] feat: enhance Google Drive client with improved logging and thread-safe operations - Added logging to track the start and end of file download and export processes, improving visibility into execution time. - Implemented per-thread HTTP transport for concurrent downloads and exports, ensuring thread safety. - Refactored download and export methods to utilize resolved credentials, enhancing functionality. - Updated unit tests to validate the new threading and logging features, ensuring robust parallel execution. --- .../app/connectors/google_drive/client.py | 54 +++++++++++++++++-- .../google_drive/content_extractor.py | 8 +++ .../google_drive_indexer.py | 5 +- .../test_google_drive_parallel.py | 6 ++- 4 files changed, 65 insertions(+), 8 deletions(-) diff --git a/surfsense_backend/app/connectors/google_drive/client.py b/surfsense_backend/app/connectors/google_drive/client.py index fdbacfd69..4879d01bd 100644 --- a/surfsense_backend/app/connectors/google_drive/client.py +++ b/surfsense_backend/app/connectors/google_drive/client.py @@ -2,9 +2,14 @@ import asyncio import io +import logging +import threading +import time from typing import Any +import httplib2 from google.oauth2.credentials import Credentials +from google_auth_httplib2 import AuthorizedHttp from googleapiclient.discovery import build from googleapiclient.errors import HttpError from googleapiclient.http import MediaIoBaseUpload @@ -13,6 +18,14 @@ from sqlalchemy.ext.asyncio import AsyncSession from .credentials import get_valid_credentials from .file_types import GOOGLE_DOC, GOOGLE_SHEET +logger = logging.getLogger(__name__) + + +def _build_thread_http(credentials: Credentials) -> AuthorizedHttp: + """Create a per-thread HTTP transport so concurrent downloads don't share + the same ``httplib2.Http`` (which is not thread-safe).""" + return AuthorizedHttp(credentials, http=httplib2.Http()) + class GoogleDriveClient: """Client for Google Drive API operations.""" @@ -35,6 +48,7 @@ class GoogleDriveClient: self.session = session self.connector_id = connector_id self._credentials = credentials + self._resolved_credentials: Credentials | None = None self.service = None self._service_lock = asyncio.Lock() @@ -62,6 +76,7 @@ class GoogleDriveClient: credentials = await get_valid_credentials( self.session, self.connector_id ) + self._resolved_credentials = credentials self.service = build("drive", "v3", credentials=credentials) return self.service except Exception as e: @@ -141,12 +156,19 @@ class GoogleDriveClient: return None, f"Error getting file metadata: {e!s}" @staticmethod - def _sync_download_file(service, file_id: str) -> tuple[bytes | None, str | None]: + def _sync_download_file( + service, file_id: str, credentials: Credentials, + ) -> tuple[bytes | None, str | None]: """Blocking download — runs on a worker thread via ``to_thread``.""" + thread = threading.current_thread().name + t0 = time.monotonic() + logger.info(f"[download] START file={file_id} thread={thread}") try: from googleapiclient.http import MediaIoBaseDownload + http = _build_thread_http(credentials) request = service.files().get_media(fileId=file_id) + request.http = http fh = io.BytesIO() downloader = MediaIoBaseDownload(fh, request) done = False @@ -157,6 +179,8 @@ class GoogleDriveClient: return None, f"HTTP error downloading file: {e.resp.status}" except Exception as e: return None, f"Error downloading file: {e!s}" + finally: + logger.info(f"[download] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s") async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]: """ @@ -169,17 +193,25 @@ class GoogleDriveClient: Tuple of (file content bytes, error message) """ service = await self.get_service() - return await asyncio.to_thread(self._sync_download_file, service, file_id) + return await asyncio.to_thread( + self._sync_download_file, service, file_id, self._resolved_credentials, + ) @staticmethod def _sync_download_file_to_disk( service, file_id: str, dest_path: str, chunksize: int, + credentials: Credentials, ) -> str | None: """Blocking download-to-disk — runs on a worker thread via ``to_thread``.""" + thread = threading.current_thread().name + t0 = time.monotonic() + logger.info(f"[download-to-disk] START file={file_id} thread={thread}") try: from googleapiclient.http import MediaIoBaseDownload + http = _build_thread_http(credentials) request = service.files().get_media(fileId=file_id) + request.http = http with open(dest_path, "wb") as fh: downloader = MediaIoBaseDownload(fh, request, chunksize=chunksize) done = False @@ -190,6 +222,8 @@ class GoogleDriveClient: return f"HTTP error downloading file: {e.resp.status}" except Exception as e: return f"Error downloading file: {e!s}" + finally: + logger.info(f"[download-to-disk] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s") async def download_file_to_disk( self, file_id: str, dest_path: str, chunksize: int = 5 * 1024 * 1024, @@ -200,17 +234,24 @@ class GoogleDriveClient: """ service = await self.get_service() return await asyncio.to_thread( - self._sync_download_file_to_disk, service, file_id, dest_path, chunksize, + self._sync_download_file_to_disk, + service, file_id, dest_path, chunksize, self._resolved_credentials, ) @staticmethod def _sync_export_google_file( - service, file_id: str, mime_type: str, + service, file_id: str, mime_type: str, credentials: Credentials, ) -> tuple[bytes | None, str | None]: """Blocking export — runs on a worker thread via ``to_thread``.""" + thread = threading.current_thread().name + t0 = time.monotonic() + logger.info(f"[export] START file={file_id} thread={thread}") try: + http = _build_thread_http(credentials) content = ( - service.files().export(fileId=file_id, mimeType=mime_type).execute() + service.files() + .export(fileId=file_id, mimeType=mime_type) + .execute(http=http) ) if not isinstance(content, bytes): content = content.encode("utf-8") @@ -219,6 +260,8 @@ class GoogleDriveClient: return None, f"HTTP error exporting file: {e.resp.status}" except Exception as e: return None, f"Error exporting file: {e!s}" + finally: + logger.info(f"[export] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s") async def export_google_file( self, file_id: str, mime_type: str @@ -236,6 +279,7 @@ class GoogleDriveClient: service = await self.get_service() return await asyncio.to_thread( self._sync_export_google_file, service, file_id, mime_type, + self._resolved_credentials, ) async def create_file( diff --git a/surfsense_backend/app/connectors/google_drive/content_extractor.py b/surfsense_backend/app/connectors/google_drive/content_extractor.py index 29c7e85a7..0903aea9f 100644 --- a/surfsense_backend/app/connectors/google_drive/content_extractor.py +++ b/surfsense_backend/app/connectors/google_drive/content_extractor.py @@ -4,6 +4,8 @@ import asyncio import logging import os import tempfile +import threading +import time from pathlib import Path from typing import Any @@ -119,7 +121,10 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str: ) if stt_service_type == "local": from app.services.stt_service import stt_service + t0 = time.monotonic() + logger.info(f"[local-stt] START file={filename} thread={threading.current_thread().name}") result = await asyncio.to_thread(stt_service.transcribe_file, file_path) + logger.info(f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s") text = result.get("text", "") else: with open(file_path, "rb") as audio_file: @@ -171,7 +176,10 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str: from docling.document_converter import DocumentConverter converter = DocumentConverter() + t0 = time.monotonic() + logger.info(f"[docling] START file={filename} thread={threading.current_thread().name}") result = await asyncio.to_thread(converter.convert, file_path) + logger.info(f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s") return result.document.export_to_markdown() raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}") diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py index 2d3139343..d67665d99 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py @@ -199,7 +199,7 @@ async def _download_files_parallel( search_space_id: int, user_id: str, enable_summary: bool, - max_concurrency: int = 5, + max_concurrency: int = 3, on_heartbeat: HeartbeatCallbackType | None = None, ) -> tuple[list[ConnectorDocument], int]: """Download and ETL files in parallel, returning ConnectorDocuments. @@ -219,6 +219,9 @@ async def _download_files_parallel( drive_client, file ) if error or not markdown: + file_name = file.get("name", "Unknown") + reason = error or "empty content" + logger.warning(f"Download/ETL failed for {file_name}: {reason}") return None doc = _build_connector_doc( file, diff --git a/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py b/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py index 9737ca3d2..1523da0ed 100644 --- a/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py +++ b/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py @@ -605,12 +605,13 @@ async def test_client_download_file_runs_in_thread_parallel(): BLOCK_SECONDS = 0.2 NUM_CALLS = 3 - def _blocking_download(service, file_id): + def _blocking_download(service, file_id, credentials): time.sleep(BLOCK_SECONDS) return b"fake-content", None client = GoogleDriveClient.__new__(GoogleDriveClient) client.service = MagicMock() + client._resolved_credentials = MagicMock() client._service_lock = asyncio.Lock() with patch.object( @@ -640,12 +641,13 @@ async def test_client_export_google_file_runs_in_thread_parallel(): BLOCK_SECONDS = 0.2 NUM_CALLS = 3 - def _blocking_export(service, file_id, mime_type): + def _blocking_export(service, file_id, mime_type, credentials): time.sleep(BLOCK_SECONDS) return b"exported", None client = GoogleDriveClient.__new__(GoogleDriveClient) client.service = MagicMock() + client._resolved_credentials = MagicMock() client._service_lock = asyncio.Lock() with patch.object( From 3ce831d01d640226b90868430b31314d20a1a9e3 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 19:28:34 +0530 Subject: [PATCH 20/31] feat: reset indexing configurations in connector dialog --- .../connector-popup/hooks/use-connector-dialog.ts | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts index 7b4aa29fb..03d8a8fb4 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts @@ -867,6 +867,9 @@ export const useConnectorDialog = () => { setIsOpen(false); setIsFromOAuth(false); + setIndexingConfig(null); + setIndexingConnector(null); + setIndexingConnectorConfig(null); refreshConnectors(); queryClient.invalidateQueries({ @@ -898,6 +901,9 @@ export const useConnectorDialog = () => { const handleSkipIndexing = useCallback(() => { setIsOpen(false); setIsFromOAuth(false); + setIndexingConfig(null); + setIndexingConnector(null); + setIndexingConnectorConfig(null); }, [setIsOpen]); // Handle starting edit mode From 4e0749f9070972a37067407a90954dca50ab0249 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 20:01:08 +0530 Subject: [PATCH 21/31] fix: update file skipping logic for failed documents in Google Drive indexer - Modified the `_should_skip_file` function to skip previously failed documents during processing, improving error handling. - Updated the corresponding test to reflect the new behavior, ensuring that failed documents are correctly identified and skipped during automatic sync. --- .../app/tasks/connector_indexers/google_drive_indexer.py | 2 +- .../integration/indexing_pipeline/test_drive_pipeline.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py index d67665d99..74101ed74 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py @@ -150,7 +150,7 @@ async def _should_skip_file( return True, f"File renamed: '{old_name}' → '{file_name}'" if not DocumentStatus.is_state(existing.status, DocumentStatus.READY): - return False, None + return True, "skipped (previously failed)" return True, "unchanged" diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_drive_pipeline.py b/surfsense_backend/tests/integration/indexing_pipeline/test_drive_pipeline.py index 77128ebd9..2fffa9053 100644 --- a/surfsense_backend/tests/integration/indexing_pipeline/test_drive_pipeline.py +++ b/surfsense_backend/tests/integration/indexing_pipeline/test_drive_pipeline.py @@ -110,10 +110,10 @@ async def test_drive_legacy_doc_migrated( assert row.unique_identifier_hash == native_hash -async def test_should_skip_file_does_not_skip_failed_document( +async def test_should_skip_file_skips_failed_document( db_session, db_search_space, db_user, ): - """A FAILED document with unchanged md5 must NOT be skipped — it needs reprocessing.""" + """A FAILED document with unchanged md5 must be skipped — user can manually retry via Quick Index.""" import importlib import sys import types @@ -164,6 +164,7 @@ async def test_should_skip_file_does_not_skip_failed_document( incoming_file = {"id": file_id, "name": "Failed File.pdf", "mimeType": "application/pdf", "md5Checksum": md5} - should_skip, _msg = await _should_skip_file(db_session, incoming_file, space_id) + should_skip, msg = await _should_skip_file(db_session, incoming_file, space_id) - assert not should_skip, "FAILED documents must not be skipped even when content is unchanged" + assert should_skip, "FAILED documents must be skipped during automatic sync" + assert "failed" in msg.lower() From 3da0ffd68334235421c9b9696f66f20dedd4b717 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 21:47:14 +0530 Subject: [PATCH 22/31] feat: add native Excel parsing and improve Google Drive content extraction - Introduced a new utility for parsing .xlsx files into markdown format, enhancing the ability to process Excel documents natively. - Updated the Google Drive content extractor to utilize the new Excel parsing functionality, allowing for better handling of spreadsheet files. - Enhanced file type detection and export logic to support various document formats, improving overall content extraction accuracy. - Added unit tests to ensure the correctness of the new Excel parsing feature and its integration with existing content extraction workflows. --- .../google_drive/content_extractor.py | 45 ++++-- .../app/connectors/google_drive/file_types.py | 25 +++- .../document_processors/file_processors.py | 78 +++++++++++ surfsense_backend/app/utils/office_parsers.py | 72 ++++++++++ surfsense_backend/pyproject.toml | 1 + .../tests/unit/test_office_parsers.py | 129 ++++++++++++++++++ surfsense_backend/uv.lock | 101 +++++++------- 7 files changed, 390 insertions(+), 61 deletions(-) create mode 100644 surfsense_backend/app/utils/office_parsers.py create mode 100644 surfsense_backend/tests/unit/test_office_parsers.py diff --git a/surfsense_backend/app/connectors/google_drive/content_extractor.py b/surfsense_backend/app/connectors/google_drive/content_extractor.py index 0903aea9f..272a71403 100644 --- a/surfsense_backend/app/connectors/google_drive/content_extractor.py +++ b/surfsense_backend/app/connectors/google_drive/content_extractor.py @@ -14,8 +14,15 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.db import Log from app.services.task_logging_service import TaskLoggingService +from app.utils.office_parsers import EXCEL_EXTENSIONS + from .client import GoogleDriveClient -from .file_types import get_export_mime_type, is_google_workspace_file, should_skip_file +from .file_types import ( + get_export_mime_type, + get_extension_from_mime, + is_google_workspace_file, + should_skip_file, +) logger = logging.getLogger(__name__) @@ -58,29 +65,30 @@ async def download_and_extract_content( if "md5Checksum" in file: drive_metadata["md5_checksum"] = file["md5Checksum"] if is_google_workspace_file(mime_type): - drive_metadata["exported_as"] = "pdf" + export_ext = get_extension_from_mime(get_export_mime_type(mime_type) or "") + drive_metadata["exported_as"] = export_ext.lstrip(".") if export_ext else "pdf" drive_metadata["original_workspace_type"] = mime_type.split(".")[-1] temp_file_path = None try: if is_google_workspace_file(mime_type): - # Workspace files (Docs/Sheets/Slides) use export -- returns bytes - # in one shot. These are typically small (a few MB as PDF/text). export_mime = get_export_mime_type(mime_type) if not export_mime: return None, drive_metadata, f"Cannot export Google Workspace type: {mime_type}" content_bytes, error = await client.export_google_file(file_id, export_mime) if error: return None, drive_metadata, error - extension = ".pdf" if export_mime == "application/pdf" else ".txt" + extension = get_extension_from_mime(export_mime) or ".pdf" with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp: tmp.write(content_bytes) temp_file_path = tmp.name else: - # Binary files -- stream directly to disk in chunks to avoid - # loading the entire file into memory. - extension = Path(file_name).suffix or ".bin" + extension = ( + Path(file_name).suffix + or get_extension_from_mime(mime_type) + or ".bin" + ) with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp: temp_file_path = tmp.name @@ -142,6 +150,11 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str: raise ValueError("Transcription returned empty text") return f"# Transcription of {filename}\n\n{text}" + if lower.endswith(EXCEL_EXTENSIONS): + from app.utils.office_parsers import parse_excel_to_markdown + + return await parse_excel_to_markdown(file_path, filename) + # Document files -- use configured ETL service from app.config import config as app_config @@ -236,14 +249,17 @@ async def download_and_process_file( if error: return None, error - extension = ".pdf" if export_mime == "application/pdf" else ".txt" + extension = get_extension_from_mime(export_mime) or ".pdf" else: content_bytes, error = await client.download_file(file_id) if error: return None, error - # Preserve original file extension - extension = Path(file_name).suffix or ".bin" + extension = ( + Path(file_name).suffix + or get_extension_from_mime(mime_type) + or ".bin" + ) with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp_file: tmp_file.write(content_bytes) @@ -281,7 +297,12 @@ async def download_and_process_file( connector_info["metadata"]["md5_checksum"] = file["md5Checksum"] if is_google_workspace_file(mime_type): - connector_info["metadata"]["exported_as"] = "pdf" + export_ext = get_extension_from_mime( + get_export_mime_type(mime_type) or "" + ) + connector_info["metadata"]["exported_as"] = ( + export_ext.lstrip(".") if export_ext else "pdf" + ) connector_info["metadata"]["original_workspace_type"] = mime_type.split( "." )[-1] diff --git a/surfsense_backend/app/connectors/google_drive/file_types.py b/surfsense_backend/app/connectors/google_drive/file_types.py index a66463208..dd3456901 100644 --- a/surfsense_backend/app/connectors/google_drive/file_types.py +++ b/surfsense_backend/app/connectors/google_drive/file_types.py @@ -8,10 +8,33 @@ GOOGLE_SHORTCUT = "application/vnd.google-apps.shortcut" EXPORT_FORMATS = { GOOGLE_DOC: "application/pdf", - GOOGLE_SHEET: "application/pdf", + GOOGLE_SHEET: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", GOOGLE_SLIDE: "application/pdf", } +MIME_TO_EXTENSION: dict[str, str] = { + "application/pdf": ".pdf", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", + "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx", + "application/vnd.ms-excel": ".xls", + "application/msword": ".doc", + "application/vnd.ms-powerpoint": ".ppt", + "text/plain": ".txt", + "text/csv": ".csv", + "text/html": ".html", + "text/markdown": ".md", + "application/json": ".json", + "application/xml": ".xml", + "image/png": ".png", + "image/jpeg": ".jpg", +} + + +def get_extension_from_mime(mime_type: str) -> str | None: + """Return a file extension (with leading dot) for a MIME type, or None.""" + return MIME_TO_EXTENSION.get(mime_type) + def is_google_workspace_file(mime_type: str) -> bool: """Check if file is a Google Workspace file that needs export.""" diff --git a/surfsense_backend/app/tasks/document_processors/file_processors.py b/surfsense_backend/app/tasks/document_processors/file_processors.py index 6c0ae1870..c69c6fa95 100644 --- a/surfsense_backend/app/tasks/document_processors/file_processors.py +++ b/surfsense_backend/app/tasks/document_processors/file_processors.py @@ -1134,6 +1134,59 @@ async def process_file_in_background( ) return None + elif filename.lower().endswith((".xlsx",)): + from app.utils.office_parsers import parse_excel_to_markdown + + if notification: + await ( + NotificationService.document_processing.notify_processing_progress( + session, + notification, + stage="parsing", + stage_message="Parsing spreadsheet", + ) + ) + + await task_logger.log_task_progress( + log_entry, + f"Processing Excel file natively: {filename}", + {"file_type": "excel", "processing_stage": "native_parse"}, + ) + + excel_markdown = await parse_excel_to_markdown(file_path, filename) + + try: + os.unlink(file_path) + except Exception as e: + print("Error deleting temp file", e) + + result = await add_received_markdown_file_document( + session, filename, excel_markdown, search_space_id, user_id, connector + ) + + if connector: + await _update_document_from_connector(result, connector, session) + + if result: + await task_logger.log_task_success( + log_entry, + f"Successfully parsed and processed Excel file: {filename}", + { + "document_id": result.id, + "content_hash": result.content_hash, + "file_type": "excel", + "etl_service": "NATIVE_EXCEL", + }, + ) + return result + else: + await task_logger.log_task_success( + log_entry, + f"Excel file already exists (duplicate): {filename}", + {"duplicate_detected": True, "file_type": "excel"}, + ) + return None + else: # Import page limit service from app.services.page_limit_service import ( @@ -1797,6 +1850,31 @@ async def process_file_in_background_with_document( with contextlib.suppress(Exception): os.unlink(file_path) + elif filename.lower().endswith((".xlsx",)): + from app.utils.office_parsers import parse_excel_to_markdown + + if notification: + await ( + NotificationService.document_processing.notify_processing_progress( + session, + notification, + stage="parsing", + stage_message="Parsing spreadsheet", + ) + ) + + await task_logger.log_task_progress( + log_entry, + f"Processing Excel file natively: {filename}", + {"file_type": "excel", "processing_stage": "native_parse"}, + ) + + markdown_content = await parse_excel_to_markdown(file_path, filename) + etl_service = "NATIVE_EXCEL" + + with contextlib.suppress(Exception): + os.unlink(file_path) + else: # Document files - use ETL service from app.services.page_limit_service import ( diff --git a/surfsense_backend/app/utils/office_parsers.py b/surfsense_backend/app/utils/office_parsers.py new file mode 100644 index 000000000..a1550e110 --- /dev/null +++ b/surfsense_backend/app/utils/office_parsers.py @@ -0,0 +1,72 @@ +"""Native parsers for Office file formats.""" + +import asyncio +import logging +import threading +import time +from pathlib import Path + +logger = logging.getLogger(__name__) + +EXCEL_EXTENSIONS = (".xlsx",) + + +def _parse_excel_sync(file_path: str) -> str: + """Parse an .xlsx file into markdown tables (synchronous).""" + from openpyxl import load_workbook + + wb = load_workbook(file_path, read_only=True, data_only=True) + markdown_parts: list[str] = [] + + for sheet_name in wb.sheetnames: + ws = wb[sheet_name] + rows = list(ws.iter_rows(values_only=True)) + non_empty_rows = [r for r in rows if any(c is not None for c in r)] + if not non_empty_rows: + continue + + markdown_parts.append(f"## {sheet_name}\n") + max_cols = max(len(row) for row in non_empty_rows) + + header = non_empty_rows[0] + hdr = [str(c if c is not None else "") for c in header] + hdr.extend([""] * (max_cols - len(hdr))) + markdown_parts.append("| " + " | ".join(hdr) + " |") + markdown_parts.append("| " + " | ".join("---" for _ in range(max_cols)) + " |") + + for row in non_empty_rows[1:]: + cells = [str(c if c is not None else "") for c in row] + cells.extend([""] * (max_cols - len(cells))) + markdown_parts.append("| " + " | ".join(cells) + " |") + + markdown_parts.append("") + + wb.close() + return "\n".join(markdown_parts) + + +async def parse_excel_to_markdown(file_path: str, filename: str = "") -> str: + """Parse an .xlsx file into markdown tables (async wrapper). + + Raises ``ValueError`` if no data is found in the workbook. + """ + t0 = time.monotonic() + logger.info( + "[excel-parse] START file=%s thread=%s", + filename, + threading.current_thread().name, + ) + + result = await asyncio.to_thread(_parse_excel_sync, file_path) + + logger.info( + "[excel-parse] END file=%s elapsed=%.2fs", + filename, + time.monotonic() - t0, + ) + + if not result.strip(): + raise ValueError(f"No data found in Excel file: {filename or file_path}") + + title = f"# {filename}\n\n" if filename else "" + return title + result diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index 017994c75..724e6db4c 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -73,6 +73,7 @@ dependencies = [ "langchain-daytona>=0.0.2", "pypandoc>=1.16.2", "notion-markdown>=0.7.0", + "openpyxl>=3.1.5", ] [dependency-groups] diff --git a/surfsense_backend/tests/unit/test_office_parsers.py b/surfsense_backend/tests/unit/test_office_parsers.py new file mode 100644 index 000000000..11429a71d --- /dev/null +++ b/surfsense_backend/tests/unit/test_office_parsers.py @@ -0,0 +1,129 @@ +"""Unit tests for native Office file parsers (no DB, no external services).""" + +import tempfile + +import pytest +from openpyxl import Workbook + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _create_xlsx(sheets: dict[str, list[list]]) -> str: + """Create a real .xlsx file on disk and return its path. + + ``sheets`` maps sheet name -> list of rows, where each row is a list of + cell values. + """ + wb = Workbook() + first = True + for name, rows in sheets.items(): + ws = wb.active if first else wb.create_sheet(title=name) + if first: + ws.title = name + first = False + for row in rows: + ws.append(row) + tmp = tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) + wb.save(tmp.name) + wb.close() + tmp.close() + return tmp.name + + +# --------------------------------------------------------------------------- +# Tracer bullet: cell values appear in markdown +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_parse_excel_produces_markdown_with_cell_values(): + """A single-sheet .xlsx with known data produces markdown containing those values.""" + from app.utils.office_parsers import parse_excel_to_markdown + + path = _create_xlsx( + {"Sales": [["Product", "Revenue"], ["Widget", 1500], ["Gadget", 3200]]} + ) + + md = await parse_excel_to_markdown(path, filename="report.xlsx") + + assert "Product" in md + assert "Revenue" in md + assert "Widget" in md + assert "1500" in md + assert "Gadget" in md + assert "3200" in md + assert "report.xlsx" in md + assert "|" in md + + +# --------------------------------------------------------------------------- +# Multi-sheet workbooks include all sheets +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_parse_excel_includes_all_sheets(): + """Both sheet names and their data appear in the output.""" + from app.utils.office_parsers import parse_excel_to_markdown + + path = _create_xlsx( + { + "Inventory": [["Item", "Qty"], ["Bolts", 200]], + "Pricing": [["Item", "Price"], ["Bolts", 4.50]], + } + ) + + md = await parse_excel_to_markdown(path, filename="multi.xlsx") + + assert "Inventory" in md + assert "Pricing" in md + assert "Bolts" in md + assert "200" in md + assert "4.5" in md + + +# --------------------------------------------------------------------------- +# Empty spreadsheet raises ValueError +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_parse_excel_raises_on_empty_file(): + """An .xlsx with no data raises ValueError.""" + from app.utils.office_parsers import parse_excel_to_markdown + + wb = Workbook() + tmp = tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) + wb.save(tmp.name) + wb.close() + tmp.close() + + with pytest.raises(ValueError, match="No data found"): + await parse_excel_to_markdown(tmp.name, filename="empty.xlsx") + + +# --------------------------------------------------------------------------- +# _parse_file_to_markdown routes .xlsx natively (no ETL call) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_parse_file_to_markdown_routes_xlsx_natively(): + """content_extractor._parse_file_to_markdown uses native parser for .xlsx.""" + from app.connectors.google_drive.content_extractor import _parse_file_to_markdown + + path = _create_xlsx( + {"Data": [["Name", "Score"], ["Alice", 95], ["Bob", 82]]} + ) + + md = await _parse_file_to_markdown(path, "grades.xlsx") + + assert "Alice" in md + assert "95" in md + assert "Bob" in md + assert "82" in md diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock index 2770c659a..e4d148b50 100644 --- a/surfsense_backend/uv.lock +++ b/surfsense_backend/uv.lock @@ -1171,7 +1171,7 @@ name = "contourpy" version = "1.3.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy" }, + { name = "numpy", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/58/01/1253e6698a07380cd31a736d248a3f2a50a7c88779a1813da27503cadc2a/contourpy-1.3.3.tar.gz", hash = "sha256:083e12155b210502d0bca491432bb04d56dc3432f95a979b429f2848c3dbe880", size = 13466174, upload-time = "2025-07-26T12:03:12.549Z" } wheels = [ @@ -2596,6 +2596,7 @@ dependencies = [ { name = "griffecli" }, { name = "griffelib" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/04/56/28a0accac339c164b52a92c6cfc45a903acc0c174caa5c1713803467b533/griffe-2.0.0.tar.gz", hash = "sha256:c68979cd8395422083a51ea7cf02f9c119d889646d99b7b656ee43725de1b80f", size = 293906, upload-time = "2026-03-23T21:06:53.402Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/8b/94/ee21d41e7eb4f823b94603b9d40f86d3c7fde80eacc2c3c71845476dddaa/griffe-2.0.0-py3-none-any.whl", hash = "sha256:5418081135a391c3e6e757a7f3f156f1a1a746cc7b4023868ff7d5e2f9a980aa", size = 5214, upload-time = "2026-02-09T19:09:44.105Z" }, ] @@ -2608,6 +2609,7 @@ dependencies = [ { name = "colorama" }, { name = "griffelib" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/a4/f8/2e129fd4a86e52e58eefe664de05e7d502decf766e7316cc9e70fdec3e18/griffecli-2.0.0.tar.gz", hash = "sha256:312fa5ebb4ce6afc786356e2d0ce85b06c1c20d45abc42d74f0cda65e159f6ef", size = 56213, upload-time = "2026-03-23T21:06:54.8Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/e6/ed/d93f7a447bbf7a935d8868e9617cbe1cadf9ee9ee6bd275d3040fbf93d60/griffecli-2.0.0-py3-none-any.whl", hash = "sha256:9f7cd9ee9b21d55e91689358978d2385ae65c22f307a63fb3269acf3f21e643d", size = 9345, upload-time = "2026-02-09T19:09:42.554Z" }, ] @@ -2616,6 +2618,7 @@ wheels = [ name = "griffelib" version = "2.0.0" source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ad/06/eccbd311c9e2b3ca45dbc063b93134c57a1ccc7607c5e545264ad092c4a9/griffelib-2.0.0.tar.gz", hash = "sha256:e504d637a089f5cab9b5daf18f7645970509bf4f53eda8d79ed71cce8bd97934", size = 166312, upload-time = "2026-03-23T21:06:55.954Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/4d/51/c936033e16d12b627ea334aaaaf42229c37620d0f15593456ab69ab48161/griffelib-2.0.0-py3-none-any.whl", hash = "sha256:01284878c966508b6d6f1dbff9b6fa607bc062d8261c5c7253cb285b06422a7f", size = 142004, upload-time = "2026-02-09T19:09:40.561Z" }, ] @@ -4082,15 +4085,15 @@ name = "matplotlib" version = "3.10.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "contourpy" }, - { name = "cycler" }, - { name = "fonttools" }, - { name = "kiwisolver" }, - { name = "numpy" }, - { name = "packaging" }, - { name = "pillow" }, - { name = "pyparsing" }, - { name = "python-dateutil" }, + { name = "contourpy", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "cycler", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "fonttools", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "kiwisolver", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "numpy", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "packaging", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "pillow", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "pyparsing", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "python-dateutil", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/8a/76/d3c6e3a13fe484ebe7718d14e269c9569c4eb0020a968a327acb3b9a8fe6/matplotlib-3.10.8.tar.gz", hash = "sha256:2299372c19d56bcd35cf05a2738308758d32b9eaed2371898d8f5bd33f084aa3", size = 34806269, upload-time = "2025-12-10T22:56:51.155Z" } wheels = [ @@ -4201,7 +4204,7 @@ name = "ml-dtypes" version = "0.5.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy" }, + { name = "numpy", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0e/4a/c27b42ed9b1c7d13d9ba8b6905dece787d6259152f2309338aed29b2447b/ml_dtypes-0.5.4.tar.gz", hash = "sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453", size = 692314, upload-time = "2025-11-17T22:32:31.031Z" } wheels = [ @@ -4967,9 +4970,9 @@ name = "ocrmac" version = "1.0.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "click" }, - { name = "pillow" }, - { name = "pyobjc-framework-vision" }, + { name = "click", marker = "(python_full_version < '3.13' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, + { name = "pillow", marker = "(python_full_version < '3.13' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, + { name = "pyobjc-framework-vision", marker = "(python_full_version < '3.13' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/5e/07/3e15ab404f75875c5e48c47163300eb90b7409044d8711fc3aaf52503f2e/ocrmac-1.0.1.tar.gz", hash = "sha256:507fe5e4cbd67b2d03f6729a52bbc11f9d0b58241134eb958a5daafd4b9d93d9", size = 1454317, upload-time = "2026-01-08T16:44:26.412Z" } wheels = [ @@ -5003,10 +5006,10 @@ name = "onnx" version = "1.20.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "ml-dtypes" }, - { name = "numpy" }, - { name = "protobuf" }, - { name = "typing-extensions" }, + { name = "ml-dtypes", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "numpy", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "protobuf", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/3b/8a/335c03a8683a88a32f9a6bb98899ea6df241a41df64b37b9696772414794/onnx-1.20.1.tar.gz", hash = "sha256:ded16de1df563d51fbc1ad885f2a426f814039d8b5f4feb77febe09c0295ad67", size = 12048980, upload-time = "2026-01-10T01:40:03.043Z" } wheels = [ @@ -6493,7 +6496,7 @@ name = "pyobjc-framework-cocoa" version = "12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core" }, + { name = "pyobjc-core", marker = "(python_full_version < '3.13' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/02/a3/16ca9a15e77c061a9250afbae2eae26f2e1579eb8ca9462ae2d2c71e1169/pyobjc_framework_cocoa-12.1.tar.gz", hash = "sha256:5556c87db95711b985d5efdaaf01c917ddd41d148b1e52a0c66b1a2e2c5c1640", size = 2772191, upload-time = "2025-11-14T10:13:02.069Z" } wheels = [ @@ -6509,8 +6512,8 @@ name = "pyobjc-framework-coreml" version = "12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core" }, - { name = "pyobjc-framework-cocoa" }, + { name = "pyobjc-core", marker = "(python_full_version < '3.13' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, + { name = "pyobjc-framework-cocoa", marker = "(python_full_version < '3.13' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/30/2d/baa9ea02cbb1c200683cb7273b69b4bee5070e86f2060b77e6a27c2a9d7e/pyobjc_framework_coreml-12.1.tar.gz", hash = "sha256:0d1a4216891a18775c9e0170d908714c18e4f53f9dc79fb0f5263b2aa81609ba", size = 40465, upload-time = "2025-11-14T10:14:02.265Z" } wheels = [ @@ -6526,8 +6529,8 @@ name = "pyobjc-framework-quartz" version = "12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core" }, - { name = "pyobjc-framework-cocoa" }, + { name = "pyobjc-core", marker = "(python_full_version < '3.13' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, + { name = "pyobjc-framework-cocoa", marker = "(python_full_version < '3.13' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/94/18/cc59f3d4355c9456fc945eae7fe8797003c4da99212dd531ad1b0de8a0c6/pyobjc_framework_quartz-12.1.tar.gz", hash = "sha256:27f782f3513ac88ec9b6c82d9767eef95a5cf4175ce88a1e5a65875fee799608", size = 3159099, upload-time = "2025-11-14T10:21:24.31Z" } wheels = [ @@ -6543,10 +6546,10 @@ name = "pyobjc-framework-vision" version = "12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core" }, - { name = "pyobjc-framework-cocoa" }, - { name = "pyobjc-framework-coreml" }, - { name = "pyobjc-framework-quartz" }, + { name = "pyobjc-core", marker = "(python_full_version < '3.13' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, + { name = "pyobjc-framework-cocoa", marker = "(python_full_version < '3.13' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, + { name = "pyobjc-framework-coreml", marker = "(python_full_version < '3.13' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, + { name = "pyobjc-framework-quartz", marker = "(python_full_version < '3.13' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/c2/5a/08bb3e278f870443d226c141af14205ff41c0274da1e053b72b11dfc9fb2/pyobjc_framework_vision-12.1.tar.gz", hash = "sha256:a30959100e85dcede3a786c544e621ad6eb65ff6abf85721f805822b8c5fe9b0", size = 59538, upload-time = "2025-11-14T10:23:21.979Z" } wheels = [ @@ -7916,6 +7919,7 @@ dependencies = [ { name = "notion-client" }, { name = "notion-markdown" }, { name = "numpy" }, + { name = "openpyxl" }, { name = "pgvector" }, { name = "playwright" }, { name = "psycopg", extra = ["binary", "pool"] }, @@ -7998,6 +8002,7 @@ requires-dist = [ { name = "notion-client", specifier = ">=2.3.0" }, { name = "notion-markdown", specifier = ">=0.7.0" }, { name = "numpy", specifier = ">=1.24.0" }, + { name = "openpyxl", specifier = ">=3.1.5" }, { name = "pgvector", specifier = ">=0.3.6" }, { name = "playwright", specifier = ">=1.50.0" }, { name = "psycopg", extras = ["binary", "pool"], specifier = ">=3.3.2" }, @@ -8188,11 +8193,11 @@ name = "timm" version = "1.0.25" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "huggingface-hub" }, - { name = "pyyaml" }, - { name = "safetensors" }, - { name = "torch" }, - { name = "torchvision" }, + { name = "huggingface-hub", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "pyyaml", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "safetensors", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "torch", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "torchvision", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d7/2c/593109822fe735e637382aca6640c1102c19797f7791f1fd1dab2d6c3cb1/timm-1.0.25.tar.gz", hash = "sha256:47f59fc2754725735cc81bb83bcbfce5bec4ebd5d4bb9e69da57daa92fcfa768", size = 2414743, upload-time = "2026-02-23T16:49:00.137Z" } wheels = [ @@ -8819,22 +8824,22 @@ name = "unstructured-inference" version = "1.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "accelerate" }, - { name = "huggingface-hub" }, - { name = "matplotlib" }, - { name = "numpy" }, - { name = "onnx" }, - { name = "onnxruntime" }, - { name = "opencv-python" }, - { name = "pandas" }, - { name = "pdfminer-six" }, - { name = "pypdfium2" }, - { name = "python-multipart" }, - { name = "rapidfuzz" }, - { name = "scipy" }, - { name = "timm" }, - { name = "torch" }, - { name = "transformers" }, + { name = "accelerate", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "huggingface-hub", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "matplotlib", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "numpy", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "onnx", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "onnxruntime", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "opencv-python", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "pandas", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "pdfminer-six", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "pypdfium2", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "python-multipart", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "rapidfuzz", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "scipy", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "timm", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "torch", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, + { name = "transformers", marker = "python_full_version < '3.13' or sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ce/10/8f3bccfa9f1e0101a402ae1f529e07876541c6b18004747f0e793ed41f9e/unstructured_inference-1.2.0.tar.gz", hash = "sha256:19ca28512f3649c70a759cf2a4e98663e942a1b83c1acdb9506b0445f4862f23", size = 45732, upload-time = "2026-01-30T20:57:58.019Z" } wheels = [ From dff8a1df37c38d6cba8240b95408a6c9af2cfed2 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 22:00:31 +0530 Subject: [PATCH 23/31] feat: add descendant checking for folder filtering in Google Drive changes --- .../connectors/google_drive/change_tracker.py | 56 +++++++++++++------ 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/surfsense_backend/app/connectors/google_drive/change_tracker.py b/surfsense_backend/app/connectors/google_drive/change_tracker.py index 3b7f804ac..f9a67e9df 100644 --- a/surfsense_backend/app/connectors/google_drive/change_tracker.py +++ b/surfsense_backend/app/connectors/google_drive/change_tracker.py @@ -84,22 +84,50 @@ async def get_changes( return [], None, f"Error getting changes: {e!s}" +async def _is_descendant_of( + client: GoogleDriveClient, + parent_ids: list[str], + target_folder_id: str, + max_depth: int = 20, +) -> bool: + """Walk up the parent chain to check if any ancestor is *target_folder_id*.""" + visited: set[str] = set() + to_check = list(parent_ids) + + for _ in range(max_depth): + if not to_check: + return False + + current = to_check.pop(0) + if current in visited: + continue + visited.add(current) + + if current == target_folder_id: + return True + + try: + service = await client.get_service() + meta = ( + service.files() + .get(fileId=current, fields="parents", supportsAllDrives=True) + .execute() + ) + grandparents = meta.get("parents", []) + to_check.extend(grandparents) + except Exception: + continue + + return False + + async def _filter_changes_by_folder( client: GoogleDriveClient, changes: list[dict[str, Any]], folder_id: str, ) -> list[dict[str, Any]]: - """ - Filter changes to only include files within the specified folder. - - Args: - client: GoogleDriveClient instance - changes: List of changes from API - folder_id: Folder ID to filter by - - Returns: - Filtered list of changes - """ + """Filter changes to only include files within the specified folder + (direct children or nested descendants).""" filtered = [] for change in changes: @@ -108,14 +136,10 @@ async def _filter_changes_by_folder( filtered.append(change) continue - # Check if file is in the folder (or subfolder) parents = file.get("parents", []) if folder_id in parents: filtered.append(change) - else: - # Check if any parent is a descendant of folder_id - # This is a simplified check - full implementation would traverse hierarchy - # For now, we'll include it and let indexer validate + elif await _is_descendant_of(client, parents, folder_id): filtered.append(change) return filtered From 489e48644fe88d1035a2e99d8d7c08e09193e419 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 22:15:24 +0530 Subject: [PATCH 24/31] fix: revert native excel parsing --- .../google_drive/content_extractor.py | 7 - .../document_processors/file_processors.py | 78 ----------- surfsense_backend/app/utils/office_parsers.py | 72 ---------- surfsense_backend/pyproject.toml | 1 - .../tests/unit/test_office_parsers.py | 129 ------------------ surfsense_backend/uv.lock | 2 - 6 files changed, 289 deletions(-) delete mode 100644 surfsense_backend/app/utils/office_parsers.py delete mode 100644 surfsense_backend/tests/unit/test_office_parsers.py diff --git a/surfsense_backend/app/connectors/google_drive/content_extractor.py b/surfsense_backend/app/connectors/google_drive/content_extractor.py index 272a71403..de8e16156 100644 --- a/surfsense_backend/app/connectors/google_drive/content_extractor.py +++ b/surfsense_backend/app/connectors/google_drive/content_extractor.py @@ -14,8 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.db import Log from app.services.task_logging_service import TaskLoggingService -from app.utils.office_parsers import EXCEL_EXTENSIONS - from .client import GoogleDriveClient from .file_types import ( get_export_mime_type, @@ -150,11 +148,6 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str: raise ValueError("Transcription returned empty text") return f"# Transcription of {filename}\n\n{text}" - if lower.endswith(EXCEL_EXTENSIONS): - from app.utils.office_parsers import parse_excel_to_markdown - - return await parse_excel_to_markdown(file_path, filename) - # Document files -- use configured ETL service from app.config import config as app_config diff --git a/surfsense_backend/app/tasks/document_processors/file_processors.py b/surfsense_backend/app/tasks/document_processors/file_processors.py index c69c6fa95..6c0ae1870 100644 --- a/surfsense_backend/app/tasks/document_processors/file_processors.py +++ b/surfsense_backend/app/tasks/document_processors/file_processors.py @@ -1134,59 +1134,6 @@ async def process_file_in_background( ) return None - elif filename.lower().endswith((".xlsx",)): - from app.utils.office_parsers import parse_excel_to_markdown - - if notification: - await ( - NotificationService.document_processing.notify_processing_progress( - session, - notification, - stage="parsing", - stage_message="Parsing spreadsheet", - ) - ) - - await task_logger.log_task_progress( - log_entry, - f"Processing Excel file natively: {filename}", - {"file_type": "excel", "processing_stage": "native_parse"}, - ) - - excel_markdown = await parse_excel_to_markdown(file_path, filename) - - try: - os.unlink(file_path) - except Exception as e: - print("Error deleting temp file", e) - - result = await add_received_markdown_file_document( - session, filename, excel_markdown, search_space_id, user_id, connector - ) - - if connector: - await _update_document_from_connector(result, connector, session) - - if result: - await task_logger.log_task_success( - log_entry, - f"Successfully parsed and processed Excel file: {filename}", - { - "document_id": result.id, - "content_hash": result.content_hash, - "file_type": "excel", - "etl_service": "NATIVE_EXCEL", - }, - ) - return result - else: - await task_logger.log_task_success( - log_entry, - f"Excel file already exists (duplicate): {filename}", - {"duplicate_detected": True, "file_type": "excel"}, - ) - return None - else: # Import page limit service from app.services.page_limit_service import ( @@ -1850,31 +1797,6 @@ async def process_file_in_background_with_document( with contextlib.suppress(Exception): os.unlink(file_path) - elif filename.lower().endswith((".xlsx",)): - from app.utils.office_parsers import parse_excel_to_markdown - - if notification: - await ( - NotificationService.document_processing.notify_processing_progress( - session, - notification, - stage="parsing", - stage_message="Parsing spreadsheet", - ) - ) - - await task_logger.log_task_progress( - log_entry, - f"Processing Excel file natively: {filename}", - {"file_type": "excel", "processing_stage": "native_parse"}, - ) - - markdown_content = await parse_excel_to_markdown(file_path, filename) - etl_service = "NATIVE_EXCEL" - - with contextlib.suppress(Exception): - os.unlink(file_path) - else: # Document files - use ETL service from app.services.page_limit_service import ( diff --git a/surfsense_backend/app/utils/office_parsers.py b/surfsense_backend/app/utils/office_parsers.py deleted file mode 100644 index a1550e110..000000000 --- a/surfsense_backend/app/utils/office_parsers.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Native parsers for Office file formats.""" - -import asyncio -import logging -import threading -import time -from pathlib import Path - -logger = logging.getLogger(__name__) - -EXCEL_EXTENSIONS = (".xlsx",) - - -def _parse_excel_sync(file_path: str) -> str: - """Parse an .xlsx file into markdown tables (synchronous).""" - from openpyxl import load_workbook - - wb = load_workbook(file_path, read_only=True, data_only=True) - markdown_parts: list[str] = [] - - for sheet_name in wb.sheetnames: - ws = wb[sheet_name] - rows = list(ws.iter_rows(values_only=True)) - non_empty_rows = [r for r in rows if any(c is not None for c in r)] - if not non_empty_rows: - continue - - markdown_parts.append(f"## {sheet_name}\n") - max_cols = max(len(row) for row in non_empty_rows) - - header = non_empty_rows[0] - hdr = [str(c if c is not None else "") for c in header] - hdr.extend([""] * (max_cols - len(hdr))) - markdown_parts.append("| " + " | ".join(hdr) + " |") - markdown_parts.append("| " + " | ".join("---" for _ in range(max_cols)) + " |") - - for row in non_empty_rows[1:]: - cells = [str(c if c is not None else "") for c in row] - cells.extend([""] * (max_cols - len(cells))) - markdown_parts.append("| " + " | ".join(cells) + " |") - - markdown_parts.append("") - - wb.close() - return "\n".join(markdown_parts) - - -async def parse_excel_to_markdown(file_path: str, filename: str = "") -> str: - """Parse an .xlsx file into markdown tables (async wrapper). - - Raises ``ValueError`` if no data is found in the workbook. - """ - t0 = time.monotonic() - logger.info( - "[excel-parse] START file=%s thread=%s", - filename, - threading.current_thread().name, - ) - - result = await asyncio.to_thread(_parse_excel_sync, file_path) - - logger.info( - "[excel-parse] END file=%s elapsed=%.2fs", - filename, - time.monotonic() - t0, - ) - - if not result.strip(): - raise ValueError(f"No data found in Excel file: {filename or file_path}") - - title = f"# {filename}\n\n" if filename else "" - return title + result diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index 724e6db4c..017994c75 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -73,7 +73,6 @@ dependencies = [ "langchain-daytona>=0.0.2", "pypandoc>=1.16.2", "notion-markdown>=0.7.0", - "openpyxl>=3.1.5", ] [dependency-groups] diff --git a/surfsense_backend/tests/unit/test_office_parsers.py b/surfsense_backend/tests/unit/test_office_parsers.py deleted file mode 100644 index 11429a71d..000000000 --- a/surfsense_backend/tests/unit/test_office_parsers.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Unit tests for native Office file parsers (no DB, no external services).""" - -import tempfile - -import pytest -from openpyxl import Workbook - -pytestmark = pytest.mark.unit - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _create_xlsx(sheets: dict[str, list[list]]) -> str: - """Create a real .xlsx file on disk and return its path. - - ``sheets`` maps sheet name -> list of rows, where each row is a list of - cell values. - """ - wb = Workbook() - first = True - for name, rows in sheets.items(): - ws = wb.active if first else wb.create_sheet(title=name) - if first: - ws.title = name - first = False - for row in rows: - ws.append(row) - tmp = tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) - wb.save(tmp.name) - wb.close() - tmp.close() - return tmp.name - - -# --------------------------------------------------------------------------- -# Tracer bullet: cell values appear in markdown -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_parse_excel_produces_markdown_with_cell_values(): - """A single-sheet .xlsx with known data produces markdown containing those values.""" - from app.utils.office_parsers import parse_excel_to_markdown - - path = _create_xlsx( - {"Sales": [["Product", "Revenue"], ["Widget", 1500], ["Gadget", 3200]]} - ) - - md = await parse_excel_to_markdown(path, filename="report.xlsx") - - assert "Product" in md - assert "Revenue" in md - assert "Widget" in md - assert "1500" in md - assert "Gadget" in md - assert "3200" in md - assert "report.xlsx" in md - assert "|" in md - - -# --------------------------------------------------------------------------- -# Multi-sheet workbooks include all sheets -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_parse_excel_includes_all_sheets(): - """Both sheet names and their data appear in the output.""" - from app.utils.office_parsers import parse_excel_to_markdown - - path = _create_xlsx( - { - "Inventory": [["Item", "Qty"], ["Bolts", 200]], - "Pricing": [["Item", "Price"], ["Bolts", 4.50]], - } - ) - - md = await parse_excel_to_markdown(path, filename="multi.xlsx") - - assert "Inventory" in md - assert "Pricing" in md - assert "Bolts" in md - assert "200" in md - assert "4.5" in md - - -# --------------------------------------------------------------------------- -# Empty spreadsheet raises ValueError -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_parse_excel_raises_on_empty_file(): - """An .xlsx with no data raises ValueError.""" - from app.utils.office_parsers import parse_excel_to_markdown - - wb = Workbook() - tmp = tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) - wb.save(tmp.name) - wb.close() - tmp.close() - - with pytest.raises(ValueError, match="No data found"): - await parse_excel_to_markdown(tmp.name, filename="empty.xlsx") - - -# --------------------------------------------------------------------------- -# _parse_file_to_markdown routes .xlsx natively (no ETL call) -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_parse_file_to_markdown_routes_xlsx_natively(): - """content_extractor._parse_file_to_markdown uses native parser for .xlsx.""" - from app.connectors.google_drive.content_extractor import _parse_file_to_markdown - - path = _create_xlsx( - {"Data": [["Name", "Score"], ["Alice", 95], ["Bob", 82]]} - ) - - md = await _parse_file_to_markdown(path, "grades.xlsx") - - assert "Alice" in md - assert "95" in md - assert "Bob" in md - assert "82" in md diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock index e4d148b50..82ae4cc16 100644 --- a/surfsense_backend/uv.lock +++ b/surfsense_backend/uv.lock @@ -7919,7 +7919,6 @@ dependencies = [ { name = "notion-client" }, { name = "notion-markdown" }, { name = "numpy" }, - { name = "openpyxl" }, { name = "pgvector" }, { name = "playwright" }, { name = "psycopg", extra = ["binary", "pool"] }, @@ -8002,7 +8001,6 @@ requires-dist = [ { name = "notion-client", specifier = ">=2.3.0" }, { name = "notion-markdown", specifier = ">=0.7.0" }, { name = "numpy", specifier = ">=1.24.0" }, - { name = "openpyxl", specifier = ">=3.1.5" }, { name = "pgvector", specifier = ">=0.3.6" }, { name = "playwright", specifier = ">=1.50.0" }, { name = "psycopg", extras = ["binary", "pool"], specifier = ">=3.3.2" }, From 6d4eb323451491ae2d97cb7c23960256ffbbada0 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 22:20:32 +0530 Subject: [PATCH 25/31] fix: update export format for Google Docs to use correct MIME type --- surfsense_backend/app/connectors/google_drive/file_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/surfsense_backend/app/connectors/google_drive/file_types.py b/surfsense_backend/app/connectors/google_drive/file_types.py index dd3456901..dd6aff4d7 100644 --- a/surfsense_backend/app/connectors/google_drive/file_types.py +++ b/surfsense_backend/app/connectors/google_drive/file_types.py @@ -7,7 +7,7 @@ GOOGLE_FOLDER = "application/vnd.google-apps.folder" GOOGLE_SHORTCUT = "application/vnd.google-apps.shortcut" EXPORT_FORMATS = { - GOOGLE_DOC: "application/pdf", + GOOGLE_DOC: "application/vnd.openxmlformats-officedocument.wordprocessingml.document", GOOGLE_SHEET: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", GOOGLE_SLIDE: "application/pdf", } From ddccba0df8bd6c59b0cae209c527bc16dec1d0e2 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 23:11:13 +0530 Subject: [PATCH 26/31] refactor: improve UI components for folder and document management --- .../documents/CreateFolderDialog.tsx | 46 ++++++++++++------- .../components/documents/DocumentNode.tsx | 6 +-- .../components/documents/FolderNode.tsx | 4 +- .../documents/FolderPickerDialog.tsx | 28 ++++++++--- .../layout/ui/sidebar/DocumentsSidebar.tsx | 2 +- 5 files changed, 56 insertions(+), 30 deletions(-) diff --git a/surfsense_web/components/documents/CreateFolderDialog.tsx b/surfsense_web/components/documents/CreateFolderDialog.tsx index 992c5d24c..8c2643828 100644 --- a/surfsense_web/components/documents/CreateFolderDialog.tsx +++ b/surfsense_web/components/documents/CreateFolderDialog.tsx @@ -1,6 +1,5 @@ "use client"; -import { FolderPlus } from "lucide-react"; import { useCallback, useEffect, useRef, useState } from "react"; import { Button } from "@/components/ui/button"; import { @@ -52,22 +51,25 @@ export function CreateFolderDialog({ return ( - - - - - {isSubfolder ? "New subfolder" : "New folder"} - - - {isSubfolder - ? `Create a new folder inside "${parentFolderName}".` - : "Create a new folder at the root level."} - + + +
+
+ + {isSubfolder ? "New subfolder" : "New folder"} + + + {isSubfolder + ? `Create a new folder inside "${parentFolderName}".` + : "Create a new folder at the root level."} + +
+
-
+
- + setName(e.target.value)} maxLength={255} autoComplete="off" + className="text-sm h-9 sm:h-10" />
- - - diff --git a/surfsense_web/components/documents/DocumentNode.tsx b/surfsense_web/components/documents/DocumentNode.tsx index 4c156d830..20a40cc34 100644 --- a/surfsense_web/components/documents/DocumentNode.tsx +++ b/surfsense_web/components/documents/DocumentNode.tsx @@ -84,7 +84,7 @@ export const DocumentNode = React.memo(function DocumentNode({ role="button" tabIndex={0} className={cn( - "group flex h-8 items-center gap-1.5 rounded-md px-1 text-sm hover:bg-accent/50 cursor-pointer select-none", + "group flex h-8 items-center gap-2.5 rounded-md px-1 text-sm hover:bg-accent/50 cursor-pointer select-none", isMentioned && "bg-accent/30", isDragging && "opacity-40" )} @@ -137,7 +137,7 @@ export const DocumentNode = React.memo(function DocumentNode({ - + onPreview(doc)}> Open @@ -166,7 +166,7 @@ export const DocumentNode = React.memo(function DocumentNode({
- + onPreview(doc)}> Open diff --git a/surfsense_web/components/documents/FolderNode.tsx b/surfsense_web/components/documents/FolderNode.tsx index 03dad83e7..9dc3b4034 100644 --- a/surfsense_web/components/documents/FolderNode.tsx +++ b/surfsense_web/components/documents/FolderNode.tsx @@ -285,7 +285,7 @@ export const FolderNode = React.memo(function FolderNode({ - + { e.stopPropagation(); @@ -331,7 +331,7 @@ export const FolderNode = React.memo(function FolderNode({ {!isRenaming && ( - + onCreateSubfolder(folder.id)}> New subfolder diff --git a/surfsense_web/components/documents/FolderPickerDialog.tsx b/surfsense_web/components/documents/FolderPickerDialog.tsx index 366db1eb9..3c866e04a 100644 --- a/surfsense_web/components/documents/FolderPickerDialog.tsx +++ b/surfsense_web/components/documents/FolderPickerDialog.tsx @@ -124,10 +124,18 @@ export function FolderPickerDialog({ return ( - - - {title} - {description && {description}} + + +
+
+ {title} + {description && ( + + {description} + + )} +
+
@@ -147,11 +155,17 @@ export function FolderPickerDialog({ {renderPickerLevel(null, 1)}
- - - +
diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index a433b4e3c..832a03942 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -617,7 +617,7 @@ export function DocumentsSidebar({ open={folderPickerOpen} onOpenChange={setFolderPickerOpen} folders={treeFolders} - title={folderPickerTarget?.type === "folder" ? "Move folder to..." : "Move document to..."} + title={folderPickerTarget?.type === "folder" ? "Move folder to" : "Move document to"} description="Select a destination folder, or choose Root to move to the top level." disabledFolderIds={folderPickerTarget?.disabledIds} onSelect={handleFolderPickerSelect} From 13f4b175a68c4bd03a99eaee342617f2b4072a34 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 23:14:10 +0530 Subject: [PATCH 27/31] feat: enhance context menu functionality in DocumentNode and FolderNode components --- .../components/documents/DocumentNode.tsx | 60 ++++++++++--------- .../components/documents/FolderNode.tsx | 12 ++-- .../components/documents/FolderTreeView.tsx | 8 ++- 3 files changed, 45 insertions(+), 35 deletions(-) diff --git a/surfsense_web/components/documents/DocumentNode.tsx b/surfsense_web/components/documents/DocumentNode.tsx index 20a40cc34..e56bcdbb7 100644 --- a/surfsense_web/components/documents/DocumentNode.tsx +++ b/surfsense_web/components/documents/DocumentNode.tsx @@ -10,14 +10,12 @@ import { ContextMenu, ContextMenuContent, ContextMenuItem, - ContextMenuSeparator, ContextMenuTrigger, } from "@/components/ui/context-menu"; import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, - DropdownMenuSeparator, DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; import type { DocumentTypeEnum } from "@/contracts/types/document.types"; @@ -41,6 +39,8 @@ interface DocumentNodeProps { onEdit: (doc: DocumentNodeDoc) => void; onDelete: (doc: DocumentNodeDoc) => void; onMove: (doc: DocumentNodeDoc) => void; + contextMenuOpen?: boolean; + onContextMenuOpenChange?: (open: boolean) => void; } export const DocumentNode = React.memo(function DocumentNode({ @@ -52,6 +52,8 @@ export const DocumentNode = React.memo(function DocumentNode({ onEdit, onDelete, onMove, + contextMenuOpen, + onContextMenuOpenChange, }: DocumentNodeProps) { const statusState = doc.status?.state ?? "ready"; const isSelectable = statusState !== "pending" && statusState !== "processing"; @@ -76,7 +78,7 @@ export const DocumentNode = React.memo(function DocumentNode({ const isProcessing = statusState === "pending" || statusState === "processing"; return ( - + {/* biome-ignore lint/a11y/useSemanticElements: div required for drag ref */}
e.stopPropagation()} > @@ -152,7 +154,6 @@ export const DocumentNode = React.memo(function DocumentNode({ Move to... - - - onPreview(doc)}> - - Open - - {isEditable && ( - onEdit(doc)}> - - Edit + {contextMenuOpen && ( + + onPreview(doc)}> + + Open - )} - onMove(doc)}> - - Move to... - - - onDelete(doc)} - > - - Delete - - + {isEditable && ( + onEdit(doc)}> + + Edit + + )} + onMove(doc)}> + + Move to... + + onDelete(doc)} + > + + Delete + + + )} ); }); diff --git a/surfsense_web/components/documents/FolderNode.tsx b/surfsense_web/components/documents/FolderNode.tsx index 9dc3b4034..edc685d99 100644 --- a/surfsense_web/components/documents/FolderNode.tsx +++ b/surfsense_web/components/documents/FolderNode.tsx @@ -66,6 +66,8 @@ interface FolderNodeProps { onReorderFolder?: (folderId: number, beforePos: string | null, afterPos: string | null) => void; siblingPositions?: { before: string | null; after: string | null }; disabledDropIds?: Set; + contextMenuOpen?: boolean; + onContextMenuOpenChange?: (open: boolean) => void; } function getDropZone( @@ -99,6 +101,8 @@ export const FolderNode = React.memo(function FolderNode({ onReorderFolder, siblingPositions, disabledDropIds, + contextMenuOpen, + onContextMenuOpenChange, }: FolderNodeProps) { const [renameValue, setRenameValue] = useState(folder.name); const inputRef = useRef(null); @@ -213,7 +217,7 @@ export const FolderNode = React.memo(function FolderNode({ const FolderIcon = isExpanded ? FolderOpen : Folder; return ( - + {/* biome-ignore lint/a11y/useSemanticElements: div required for drag/drop refs */}
e.stopPropagation()} > @@ -313,7 +317,6 @@ export const FolderNode = React.memo(function FolderNode({ Move to... - { @@ -330,7 +333,7 @@ export const FolderNode = React.memo(function FolderNode({
- {!isRenaming && ( + {!isRenaming && contextMenuOpen && ( onCreateSubfolder(folder.id)}> @@ -344,7 +347,6 @@ export const FolderNode = React.memo(function FolderNode({ Move to... - onDelete(folder)} diff --git a/surfsense_web/components/documents/FolderTreeView.tsx b/surfsense_web/components/documents/FolderTreeView.tsx index ca64ab1e0..4c41db0ed 100644 --- a/surfsense_web/components/documents/FolderTreeView.tsx +++ b/surfsense_web/components/documents/FolderTreeView.tsx @@ -2,7 +2,7 @@ import { useAtom } from "jotai"; import { TreePine } from "lucide-react"; -import { useCallback, useMemo } from "react"; +import { useCallback, useMemo, useState } from "react"; import { DndProvider } from "react-dnd"; import { HTML5Backend } from "react-dnd-html5-backend"; import { renamingFolderIdAtom } from "@/atoms/documents/folder.atoms"; @@ -80,6 +80,8 @@ export function FolderTreeView({ return counts; }, [folders, foldersByParent, docsByFolder]); + const [openContextMenuId, setOpenContextMenuId] = useState(null); + // Single subscription for rename state — derived boolean passed to each FolderNode const [renamingFolderId, setRenamingFolderId] = useAtom(renamingFolderIdAtom); const handleStartRename = useCallback( @@ -157,6 +159,8 @@ export function FolderTreeView({ onDropIntoFolder={onDropIntoFolder} onReorderFolder={onReorderFolder} siblingPositions={siblingPositions} + contextMenuOpen={openContextMenuId === `folder-${f.id}`} + onContextMenuOpenChange={(open) => setOpenContextMenuId(open ? `folder-${f.id}` : null)} /> ); @@ -177,6 +181,8 @@ export function FolderTreeView({ onEdit={onEditDocument} onDelete={onDeleteDocument} onMove={onMoveDocument} + contextMenuOpen={openContextMenuId === `doc-${d.id}`} + onContextMenuOpenChange={(open) => setOpenContextMenuId(open ? `doc-${d.id}` : null)} /> ); } From 0204ed53635c6dee2e7a418dc2e16787209f570e Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 27 Mar 2026 23:26:12 +0530 Subject: [PATCH 28/31] refactor: replace Pencil icon with PenLine and update TreePine to CirclePlus in document components --- surfsense_web/components/documents/DocumentNode.tsx | 6 +++--- surfsense_web/components/documents/FolderNode.tsx | 11 +++++------ surfsense_web/components/documents/FolderTreeView.tsx | 6 +++--- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/surfsense_web/components/documents/DocumentNode.tsx b/surfsense_web/components/documents/DocumentNode.tsx index e56bcdbb7..315d4b29a 100644 --- a/surfsense_web/components/documents/DocumentNode.tsx +++ b/surfsense_web/components/documents/DocumentNode.tsx @@ -1,6 +1,6 @@ "use client"; -import { Eye, MoreHorizontal, Move, Pencil, Trash2 } from "lucide-react"; +import { Eye, MoreHorizontal, Move, PenLine, Trash2 } from "lucide-react"; import React, { useCallback } from "react"; import { useDrag } from "react-dnd"; import { getDocumentTypeIcon } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon"; @@ -146,7 +146,7 @@ export const DocumentNode = React.memo(function DocumentNode({
{isEditable && ( onEdit(doc)}> - + Edit )} @@ -175,7 +175,7 @@ export const DocumentNode = React.memo(function DocumentNode({ {isEditable && ( onEdit(doc)}> - + Edit )} diff --git a/surfsense_web/components/documents/FolderNode.tsx b/surfsense_web/components/documents/FolderNode.tsx index edc685d99..cb314effb 100644 --- a/surfsense_web/components/documents/FolderNode.tsx +++ b/surfsense_web/components/documents/FolderNode.tsx @@ -8,7 +8,7 @@ import { FolderPlus, MoreHorizontal, Move, - Pencil, + PenLine, Trash2, } from "lucide-react"; import React, { useCallback, useEffect, useRef, useState } from "react"; @@ -18,14 +18,12 @@ import { ContextMenu, ContextMenuContent, ContextMenuItem, - ContextMenuSeparator, ContextMenuTrigger, } from "@/components/ui/context-menu"; import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, - DropdownMenuSeparator, DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; import { cn } from "@/lib/utils"; @@ -265,7 +263,8 @@ export const FolderNode = React.memo(function FolderNode({ onBlur={handleRenameSubmit} onKeyDown={handleRenameKeyDown} onClick={(e) => e.stopPropagation()} - className="flex-1 min-w-0 rounded border border-primary bg-background px-1 py-0.5 text-sm outline-none" + placeholder="Enter folder name" + className="flex-1 min-w-0 bg-transparent px-1 py-0.5 text-sm outline-none caret-primary placeholder:text-muted-foreground/50" /> ) : ( {folder.name} @@ -305,7 +304,7 @@ export const FolderNode = React.memo(function FolderNode({ startRename(); }} > - + Rename startRename()}> - + Rename onMove(folder)}> diff --git a/surfsense_web/components/documents/FolderTreeView.tsx b/surfsense_web/components/documents/FolderTreeView.tsx index 4c41db0ed..287afa612 100644 --- a/surfsense_web/components/documents/FolderTreeView.tsx +++ b/surfsense_web/components/documents/FolderTreeView.tsx @@ -1,7 +1,7 @@ "use client"; import { useAtom } from "jotai"; -import { TreePine } from "lucide-react"; +import { CirclePlus } from "lucide-react"; import { useCallback, useMemo, useState } from "react"; import { DndProvider } from "react-dnd"; import { HTML5Backend } from "react-dnd-html5-backend"; @@ -195,7 +195,7 @@ export function FolderTreeView({ if (treeNodes.length === 0 && folders.length === 0 && documents.length === 0) { return (
- +

No documents yet

); @@ -204,7 +204,7 @@ export function FolderTreeView({ if (treeNodes.length === 0 && activeTypes.length > 0) { return (
- +

No matching documents

); From 96549791e68fd068e8b8ee2e67d3d74d98f285b9 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 28 Mar 2026 00:11:32 +0530 Subject: [PATCH 29/31] feat: enhance DocumentNode component with loading and error indicators --- .../components/documents/DocumentNode.tsx | 88 +++++++++++++------ 1 file changed, 60 insertions(+), 28 deletions(-) diff --git a/surfsense_web/components/documents/DocumentNode.tsx b/surfsense_web/components/documents/DocumentNode.tsx index 315d4b29a..e55512e96 100644 --- a/surfsense_web/components/documents/DocumentNode.tsx +++ b/surfsense_web/components/documents/DocumentNode.tsx @@ -1,7 +1,7 @@ "use client"; -import { Eye, MoreHorizontal, Move, PenLine, Trash2 } from "lucide-react"; -import React, { useCallback } from "react"; +import { AlertCircle, Clock, Eye, MoreHorizontal, Move, PenLine, Trash2 } from "lucide-react"; +import React, { useCallback, useRef } from "react"; import { useDrag } from "react-dnd"; import { getDocumentTypeIcon } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon"; import { Button } from "@/components/ui/button"; @@ -18,6 +18,8 @@ import { DropdownMenuItem, DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; +import { Spinner } from "@/components/ui/spinner"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import type { DocumentTypeEnum } from "@/contracts/types/document.types"; import { cn } from "@/lib/utils"; import { DND_TYPES } from "./FolderNode"; @@ -76,48 +78,78 @@ export const DocumentNode = React.memo(function DocumentNode({ ); const isProcessing = statusState === "pending" || statusState === "processing"; + const rowRef = useRef(null); + + const attachRef = useCallback( + (node: HTMLButtonElement | null) => { + (rowRef as React.MutableRefObject).current = node; + drag(node); + }, + [drag] + ); return ( - {/* biome-ignore lint/a11y/useSemanticElements: div required for drag ref */} -
{ - if (e.key === "Enter" || e.key === " ") { - e.preventDefault(); - handleCheckChange(); - } - }} > - {isSelectable ? ( + {(() => { + if (statusState === "pending") { + return ( + + + + + + + Pending - waiting to be synced + + ); + } + if (statusState === "processing") { + return ( + + + + + + + Syncing + + ); + } + if (statusState === "failed") { + return ( + + + + + + + + {doc.status?.reason || "Processing failed"} + + + ); + } + return ( e.stopPropagation()} className="h-3.5 w-3.5 shrink-0" /> - ) : ( - - - - )} + ); + })()} {doc.title} @@ -164,7 +196,7 @@ export const DocumentNode = React.memo(function DocumentNode({ -
+
{contextMenuOpen && ( From 0aa9cd6dfc3f556d74c4a81b77acc6133eb0fc33 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 28 Mar 2026 02:32:03 +0530 Subject: [PATCH 30/31] feat: implement dropdown menu state management in DocumentNode and sidebar components --- .../components/documents/DocumentNode.tsx | 28 +++++++++++-------- .../ui/sidebar/AllPrivateChatsSidebar.tsx | 7 +++-- .../ui/sidebar/AllSharedChatsSidebar.tsx | 7 +++-- .../layout/ui/sidebar/ChatListItem.tsx | 11 ++++++-- 4 files changed, 35 insertions(+), 18 deletions(-) diff --git a/surfsense_web/components/documents/DocumentNode.tsx b/surfsense_web/components/documents/DocumentNode.tsx index e55512e96..4acdcf662 100644 --- a/surfsense_web/components/documents/DocumentNode.tsx +++ b/surfsense_web/components/documents/DocumentNode.tsx @@ -1,7 +1,7 @@ "use client"; import { AlertCircle, Clock, Eye, MoreHorizontal, Move, PenLine, Trash2 } from "lucide-react"; -import React, { useCallback, useRef } from "react"; +import React, { useCallback, useRef, useState } from "react"; import { useDrag } from "react-dnd"; import { getDocumentTypeIcon } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon"; import { Button } from "@/components/ui/button"; @@ -78,6 +78,7 @@ export const DocumentNode = React.memo(function DocumentNode({ ); const isProcessing = statusState === "pending" || statusState === "processing"; + const [dropdownOpen, setDropdownOpen] = useState(false); const rowRef = useRef(null); const attachRef = useCallback( @@ -95,9 +96,9 @@ export const DocumentNode = React.memo(function DocumentNode({ type="button" ref={attachRef} className={cn( - "group flex h-8 w-full items-center gap-2.5 rounded-md px-1 text-sm hover:bg-accent/50 cursor-pointer select-none text-left", - isMentioned && "bg-accent/30", - isDragging && "opacity-40" + "group flex h-8 w-full items-center gap-2.5 rounded-md px-1 text-sm hover:bg-accent/50 cursor-pointer select-none text-left", + isMentioned && "bg-accent/30", + isDragging && "opacity-40" )} style={{ paddingLeft: `${depth * 16 + 4}px` }} onClick={handleCheckChange} @@ -160,14 +161,17 @@ export const DocumentNode = React.memo(function DocumentNode({ )} - - - diff --git a/surfsense_web/components/layout/ui/sidebar/AllPrivateChatsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/AllPrivateChatsSidebar.tsx index fc5fb3873..65a24208b 100644 --- a/surfsense_web/components/layout/ui/sidebar/AllPrivateChatsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/AllPrivateChatsSidebar.tsx @@ -396,10 +396,13 @@ export function AllPrivateChatsSidebarContent({ variant="ghost" size="icon" className={cn( - "h-6 w-6 shrink-0", + "h-6 w-6 shrink-0 hover:bg-transparent", isMobile ? "opacity-0 pointer-events-none absolute" - : "md:opacity-0 md:group-hover:opacity-100 md:focus:opacity-100", + : openDropdownId === thread.id + ? "opacity-100" + : "md:opacity-0 md:group-hover:opacity-100 md:focus:opacity-100", + openDropdownId === thread.id && "bg-accent hover:bg-accent", "transition-opacity" )} disabled={isBusy} diff --git a/surfsense_web/components/layout/ui/sidebar/AllSharedChatsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/AllSharedChatsSidebar.tsx index 03673955a..ce0e45e81 100644 --- a/surfsense_web/components/layout/ui/sidebar/AllSharedChatsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/AllSharedChatsSidebar.tsx @@ -396,10 +396,13 @@ export function AllSharedChatsSidebarContent({ variant="ghost" size="icon" className={cn( - "h-6 w-6 shrink-0", + "h-6 w-6 shrink-0 hover:bg-transparent", isMobile ? "opacity-0 pointer-events-none absolute" - : "md:opacity-0 md:group-hover:opacity-100 md:focus:opacity-100", + : openDropdownId === thread.id + ? "opacity-100" + : "md:opacity-0 md:group-hover:opacity-100 md:focus:opacity-100", + openDropdownId === thread.id && "bg-accent hover:bg-accent", "transition-opacity" )} disabled={isBusy} diff --git a/surfsense_web/components/layout/ui/sidebar/ChatListItem.tsx b/surfsense_web/components/layout/ui/sidebar/ChatListItem.tsx index fc0520745..7f3089a89 100644 --- a/surfsense_web/components/layout/ui/sidebar/ChatListItem.tsx +++ b/surfsense_web/components/layout/ui/sidebar/ChatListItem.tsx @@ -79,14 +79,21 @@ export function ChatListItem({ : "bg-gradient-to-l from-sidebar from-60% to-transparent group-hover/item:from-accent", isMobile ? "opacity-0" - : isActive + : isActive || dropdownOpen ? "opacity-100" : "opacity-0 group-hover/item:opacity-100" )} > - From b5ef7afb1cbeee7500c9119efcd90a1abe9758a2 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 28 Mar 2026 02:58:38 +0530 Subject: [PATCH 31/31] feat: add multi-format document export functionality to editor routes and UI components - Implemented a new export endpoint in the backend to support exporting documents in various formats (PDF, DOCX, HTML, LaTeX, EPUB, ODT, plain text). - Enhanced DocumentNode and FolderTreeView components to include export options in context and dropdown menus. - Created shared ExportMenuItems component for consistent export options across the application. - Integrated loading indicators for export actions to improve user experience. --- surfsense_backend/app/routes/editor_routes.py | 176 +++++++++++++++++- .../components/documents/DocumentNode.tsx | 46 ++++- .../components/documents/FolderTreeView.tsx | 3 + .../layout/ui/sidebar/DocumentsSidebar.tsx | 40 ++++ .../components/report-panel/report-panel.tsx | 85 +-------- .../components/shared/ExportMenuItems.tsx | 142 ++++++++++++++ 6 files changed, 411 insertions(+), 81 deletions(-) create mode 100644 surfsense_web/components/shared/ExportMenuItems.tsx diff --git a/surfsense_backend/app/routes/editor_routes.py b/surfsense_backend/app/routes/editor_routes.py index 84846ef38..b7bbd5abb 100644 --- a/surfsense_backend/app/routes/editor_routes.py +++ b/surfsense_backend/app/routes/editor_routes.py @@ -1,19 +1,43 @@ """ Editor routes for document editing with markdown (Plate.js frontend). +Includes multi-format export (PDF, DOCX, HTML, LaTeX, EPUB, ODT, plain text). """ +import asyncio +import io +import logging +import os +import re +import tempfile from datetime import UTC, datetime from typing import Any -from fastapi import APIRouter, Depends, HTTPException +import pypandoc +import typst +from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi.responses import StreamingResponse from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from app.db import Document, DocumentType, Permission, User, get_async_session +from app.routes.reports_routes import ( + ExportFormat, + _FILE_EXTENSIONS, + _MEDIA_TYPES, + _normalize_latex_delimiters, + _strip_wrapping_code_fences, +) +from app.templates.export_helpers import ( + get_html_css_path, + get_reference_docx_path, + get_typst_template_path, +) from app.users import current_active_user from app.utils.rbac import check_permission +logger = logging.getLogger(__name__) + router = APIRouter() @@ -212,3 +236,153 @@ async def save_document( "message": "Document saved and will be reindexed in the background", "updated_at": document.updated_at.isoformat(), } + + +@router.get( + "/search-spaces/{search_space_id}/documents/{document_id}/export" +) +async def export_document( + search_space_id: int, + document_id: int, + format: ExportFormat = Query( + ExportFormat.PDF, + description="Export format: pdf, docx, html, latex, epub, odt, or plain", + ), + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Export a document in the requested format (reuses the report export pipeline).""" + await check_permission( + session, + user, + search_space_id, + Permission.DOCUMENTS_READ.value, + "You don't have permission to read documents in this search space", + ) + + result = await session.execute( + select(Document) + .options(selectinload(Document.chunks)) + .filter( + Document.id == document_id, + Document.search_space_id == search_space_id, + ) + ) + document = result.scalars().first() + if not document: + raise HTTPException(status_code=404, detail="Document not found") + + # Resolve markdown content (same priority as editor-content endpoint) + markdown_content: str | None = document.source_markdown + if markdown_content is None and document.blocknote_document: + from app.utils.blocknote_to_markdown import blocknote_to_markdown + + markdown_content = blocknote_to_markdown(document.blocknote_document) + if markdown_content is None: + chunks = sorted(document.chunks, key=lambda c: c.id) + if chunks: + markdown_content = "\n\n".join(chunk.content for chunk in chunks) + + if not markdown_content or not markdown_content.strip(): + raise HTTPException( + status_code=400, detail="Document has no content to export" + ) + + markdown_content = _strip_wrapping_code_fences(markdown_content) + markdown_content = _normalize_latex_delimiters(markdown_content) + + doc_title = document.title or "Document" + formatted_date = ( + document.created_at.strftime("%B %d, %Y") if document.created_at else "" + ) + input_fmt = "gfm+tex_math_dollars" + meta_args = ["-M", f"title:{doc_title}", "-M", f"date:{formatted_date}"] + + def _convert_and_read() -> bytes: + if format == ExportFormat.PDF: + typst_template = str(get_typst_template_path()) + typst_markup: str = pypandoc.convert_text( + markdown_content, + "typst", + format=input_fmt, + extra_args=[ + "--standalone", + f"--template={typst_template}", + "-V", "mainfont:Libertinus Serif", + "-V", "codefont:DejaVu Sans Mono", + *meta_args, + ], + ) + return typst.compile(typst_markup.encode("utf-8")) + + if format == ExportFormat.DOCX: + return _pandoc_to_tempfile( + format.value, + ["--standalone", f"--reference-doc={get_reference_docx_path()}", *meta_args], + ) + + if format == ExportFormat.HTML: + html_str: str = pypandoc.convert_text( + markdown_content, + "html5", + format=input_fmt, + extra_args=[ + "--standalone", "--embed-resources", + f"--css={get_html_css_path()}", + "--syntax-highlighting=pygments", + *meta_args, + ], + ) + return html_str.encode("utf-8") + + if format == ExportFormat.EPUB: + return _pandoc_to_tempfile("epub3", ["--standalone", *meta_args]) + + if format == ExportFormat.ODT: + return _pandoc_to_tempfile("odt", ["--standalone", *meta_args]) + + if format == ExportFormat.LATEX: + tex_str: str = pypandoc.convert_text( + markdown_content, "latex", format=input_fmt, + extra_args=["--standalone", *meta_args], + ) + return tex_str.encode("utf-8") + + plain_str: str = pypandoc.convert_text( + markdown_content, "plain", format=input_fmt, + extra_args=["--wrap=auto", "--columns=80"], + ) + return plain_str.encode("utf-8") + + def _pandoc_to_tempfile(output_format: str, extra_args: list[str]) -> bytes: + fd, tmp_path = tempfile.mkstemp(suffix=f".{output_format}") + os.close(fd) + try: + pypandoc.convert_text( + markdown_content, output_format, format=input_fmt, + extra_args=extra_args, outputfile=tmp_path, + ) + with open(tmp_path, "rb") as f: + return f.read() + finally: + os.unlink(tmp_path) + + try: + loop = asyncio.get_running_loop() + output = await loop.run_in_executor(None, _convert_and_read) + except Exception as e: + logger.exception("Document export failed") + raise HTTPException(status_code=500, detail=f"Export failed: {e!s}") from e + + safe_title = ( + "".join(c if c.isalnum() or c in " -_" else "_" for c in doc_title) + .strip()[:80] + or "document" + ) + ext = _FILE_EXTENSIONS[format] + + return StreamingResponse( + io.BytesIO(output), + media_type=_MEDIA_TYPES[format], + headers={"Content-Disposition": f'attachment; filename="{safe_title}.{ext}"'}, + ) diff --git a/surfsense_web/components/documents/DocumentNode.tsx b/surfsense_web/components/documents/DocumentNode.tsx index 4acdcf662..57a12ab3a 100644 --- a/surfsense_web/components/documents/DocumentNode.tsx +++ b/surfsense_web/components/documents/DocumentNode.tsx @@ -1,21 +1,28 @@ "use client"; -import { AlertCircle, Clock, Eye, MoreHorizontal, Move, PenLine, Trash2 } from "lucide-react"; +import { AlertCircle, Clock, Download, Eye, MoreHorizontal, Move, PenLine, Trash2 } from "lucide-react"; import React, { useCallback, useRef, useState } from "react"; import { useDrag } from "react-dnd"; import { getDocumentTypeIcon } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon"; +import { ExportContextItems, ExportDropdownItems } from "@/components/shared/ExportMenuItems"; import { Button } from "@/components/ui/button"; import { Checkbox } from "@/components/ui/checkbox"; import { ContextMenu, ContextMenuContent, ContextMenuItem, + ContextMenuSub, + ContextMenuSubContent, + ContextMenuSubTrigger, ContextMenuTrigger, } from "@/components/ui/context-menu"; import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, + DropdownMenuSub, + DropdownMenuSubContent, + DropdownMenuSubTrigger, DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; import { Spinner } from "@/components/ui/spinner"; @@ -41,6 +48,7 @@ interface DocumentNodeProps { onEdit: (doc: DocumentNodeDoc) => void; onDelete: (doc: DocumentNodeDoc) => void; onMove: (doc: DocumentNodeDoc) => void; + onExport?: (doc: DocumentNodeDoc, format: string) => void; contextMenuOpen?: boolean; onContextMenuOpenChange?: (open: boolean) => void; } @@ -54,6 +62,7 @@ export const DocumentNode = React.memo(function DocumentNode({ onEdit, onDelete, onMove, + onExport, contextMenuOpen, onContextMenuOpenChange, }: DocumentNodeProps) { @@ -79,8 +88,19 @@ export const DocumentNode = React.memo(function DocumentNode({ const isProcessing = statusState === "pending" || statusState === "processing"; const [dropdownOpen, setDropdownOpen] = useState(false); + const [exporting, setExporting] = useState(null); const rowRef = useRef(null); + const handleExport = useCallback( + (format: string) => { + if (!onExport) return; + setExporting(format); + onExport(doc, format); + setTimeout(() => setExporting(null), 2000); + }, + [doc, onExport] + ); + const attachRef = useCallback( (node: HTMLButtonElement | null) => { (rowRef as React.MutableRefObject).current = node; @@ -167,7 +187,7 @@ export const DocumentNode = React.memo(function DocumentNode({ variant="ghost" size="icon" className={cn( - "hidden sm:inline-flex h-6 w-6 shrink-0 transition-opacity hover:bg-transparent", + "hidden sm:inline-flex h-6 w-6 shrink-0 hover:bg-transparent", dropdownOpen ? "opacity-100 bg-accent hover:bg-accent" : "opacity-0 group-hover:opacity-100" )} onClick={(e) => e.stopPropagation()} @@ -190,6 +210,17 @@ export const DocumentNode = React.memo(function DocumentNode({ Move to...
+ {onExport && ( + + + + Export + + + + + + )} Move to... + {onExport && ( + + + + Export + + + + + + )} void; onDeleteDocument: (doc: DocumentNodeDoc) => void; onMoveDocument: (doc: DocumentNodeDoc) => void; + onExportDocument?: (doc: DocumentNodeDoc, format: string) => void; activeTypes: DocumentTypeEnum[]; onDropIntoFolder?: ( itemType: "folder" | "document", @@ -62,6 +63,7 @@ export function FolderTreeView({ onEditDocument, onDeleteDocument, onMoveDocument, + onExportDocument, activeTypes, onDropIntoFolder, onReorderFolder, @@ -181,6 +183,7 @@ export function FolderTreeView({ onEdit={onEditDocument} onDelete={onDeleteDocument} onMove={onMoveDocument} + onExport={onExportDocument} contextMenuOpen={openContextMenuId === `doc-${d.id}`} onContextMenuOpenChange={(open) => setOpenContextMenuId(open ? `doc-${d.id}` : null)} /> diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index 832a03942..1e1e8f982 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -7,6 +7,7 @@ import { useParams } from "next/navigation"; import { useTranslations } from "next-intl"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; +import { EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems"; import { DocumentsFilters } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsFilters"; import { DocumentsTableShell, @@ -33,6 +34,7 @@ import { useDocumentSearch } from "@/hooks/use-document-search"; import { useDocuments } from "@/hooks/use-documents"; import { useMediaQuery } from "@/hooks/use-media-query"; import { foldersApiService } from "@/lib/apis/folders-api.service"; +import { authenticatedFetch } from "@/lib/auth-utils"; import { queries } from "@/zero/queries/index"; import { SidebarSlideOutPanel } from "./SidebarSlideOutPanel"; @@ -234,6 +236,43 @@ export function DocumentsSidebar({ setFolderPickerOpen(true); }, []); + const handleExportDocument = useCallback( + async (doc: DocumentNodeDoc, format: string) => { + const safeTitle = + doc.title + .replace(/[^a-zA-Z0-9 _-]/g, "_") + .trim() + .slice(0, 80) || "document"; + const ext = EXPORT_FILE_EXTENSIONS[format] ?? format; + + try { + const response = await authenticatedFetch( + `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${doc.id}/export?format=${format}`, + { method: "GET" } + ); + + if (!response.ok) { + const errorData = await response.json().catch(() => ({ detail: "Export failed" })); + throw new Error(errorData.detail || "Export failed"); + } + + const blob = await response.blob(); + const url = URL.createObjectURL(blob); + const a = document.createElement("a"); + a.href = url; + a.download = `${safeTitle}.${ext}`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + } catch (err) { + console.error(`Export ${format} failed:`, err); + toast.error(err instanceof Error ? err.message : `Export failed`); + } + }, + [searchSpaceId] + ); + const handleFolderPickerSelect = useCallback( async (targetFolderId: number | null) => { if (!folderPickerTarget) return; @@ -606,6 +645,7 @@ export function DocumentsSidebar({ }} onDeleteDocument={(doc) => handleDeleteDocument(doc.id)} onMoveDocument={handleMoveDocument} + onExportDocument={handleExportDocument} activeTypes={activeTypes} onDropIntoFolder={handleDropIntoFolder} onReorderFolder={handleReorderFolder} diff --git a/surfsense_web/components/report-panel/report-panel.tsx b/surfsense_web/components/report-panel/report-panel.tsx index 62c89831d..f7d256a95 100644 --- a/surfsense_web/components/report-panel/report-panel.tsx +++ b/surfsense_web/components/report-panel/report-panel.tsx @@ -15,10 +15,9 @@ import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, - DropdownMenuLabel, - DropdownMenuSeparator, DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; +import { ExportDropdownItems, EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems"; import { useMediaQuery } from "@/hooks/use-media-query"; import { baseApiService } from "@/lib/apis/base-api.service"; import { authenticatedFetch } from "@/lib/auth-utils"; @@ -198,19 +197,6 @@ export function ReportPanelContent({ } }, [currentMarkdown]); - // Maps backend format values to download file extensions - const FILE_EXTENSIONS: Record = { - pdf: "pdf", - docx: "docx", - html: "html", - latex: "tex", - epub: "epub", - odt: "odt", - plain: "txt", - md: "md", - }; - - // Export report const handleExport = useCallback( async (format: string) => { setExporting(format); @@ -219,7 +205,7 @@ export function ReportPanelContent({ .replace(/[^a-zA-Z0-9 _-]/g, "_") .trim() .slice(0, 80) || "report"; - const ext = FILE_EXTENSIONS[format] ?? format; + const ext = EXPORT_FILE_EXTENSIONS[format] ?? format; try { if (format === "md") { if (!currentMarkdown) return; @@ -329,68 +315,11 @@ export function ReportPanelContent({ align="start" className={`min-w-[200px] select-none${insideDrawer ? " z-[100]" : ""}`} > - {!shareToken && ( - <> - - Documents - - handleExport("pdf")} - disabled={exporting !== null} - > - PDF (.pdf) - - handleExport("docx")} - disabled={exporting !== null} - > - Word (.docx) - - handleExport("odt")} - disabled={exporting !== null} - > - OpenDocument (.odt) - - - - Web & E-Book - - handleExport("html")} - disabled={exporting !== null} - > - HTML (.html) - - handleExport("epub")} - disabled={exporting !== null} - > - EPUB (.epub) - - - - Source & Plain - - handleExport("latex")} - disabled={exporting !== null} - > - LaTeX (.tex) - - - )} - handleExport("md")} disabled={exporting !== null}> - Markdown (.md) - - {!shareToken && ( - handleExport("plain")} - disabled={exporting !== null} - > - Plain Text (.txt) - - )} + diff --git a/surfsense_web/components/shared/ExportMenuItems.tsx b/surfsense_web/components/shared/ExportMenuItems.tsx new file mode 100644 index 000000000..69833a195 --- /dev/null +++ b/surfsense_web/components/shared/ExportMenuItems.tsx @@ -0,0 +1,142 @@ +"use client"; + +import { Loader2 } from "lucide-react"; +import { DropdownMenuItem, DropdownMenuLabel, DropdownMenuSeparator } from "@/components/ui/dropdown-menu"; +import { ContextMenuItem } from "@/components/ui/context-menu"; + +export const EXPORT_FILE_EXTENSIONS: Record = { + pdf: "pdf", + docx: "docx", + html: "html", + latex: "tex", + epub: "epub", + odt: "odt", + plain: "txt", + md: "md", +}; + +interface ExportMenuItemsProps { + onExport: (format: string) => void; + exporting: string | null; + /** Hide server-side formats (PDF, DOCX, etc.) — only show md */ + showAllFormats?: boolean; +} + +export function ExportDropdownItems({ + onExport, + exporting, + showAllFormats = true, +}: ExportMenuItemsProps) { + const handle = (format: string) => (e: React.MouseEvent) => { + e.stopPropagation(); + onExport(format); + }; + + return ( + <> + {showAllFormats && ( + <> + + Documents + + + {exporting === "pdf" && } + PDF (.pdf) + + + {exporting === "docx" && } + Word (.docx) + + + {exporting === "odt" && } + OpenDocument (.odt) + + + + Web & E-Book + + + {exporting === "html" && } + HTML (.html) + + + {exporting === "epub" && } + EPUB (.epub) + + + + Source & Plain + + + {exporting === "latex" && } + LaTeX (.tex) + + + )} + + {exporting === "md" && } + Markdown (.md) + + {showAllFormats && ( + + {exporting === "plain" && } + Plain Text (.txt) + + )} + + ); +} + +export function ExportContextItems({ + onExport, + exporting, + showAllFormats = true, +}: ExportMenuItemsProps) { + const handle = (format: string) => (e: React.MouseEvent) => { + e.stopPropagation(); + onExport(format); + }; + + return ( + <> + {showAllFormats && ( + <> + + {exporting === "pdf" && } + PDF (.pdf) + + + {exporting === "docx" && } + Word (.docx) + + + {exporting === "odt" && } + OpenDocument (.odt) + + + {exporting === "html" && } + HTML (.html) + + + {exporting === "epub" && } + EPUB (.epub) + + + {exporting === "latex" && } + LaTeX (.tex) + + + )} + + {exporting === "md" && } + Markdown (.md) + + {showAllFormats && ( + + {exporting === "plain" && } + Plain Text (.txt) + + )} + + ); +}