From ce40da80ea9f2a2afd02bd44aeb5390ac565fcff Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 4 Apr 2026 02:51:28 +0530 Subject: [PATCH 1/6] feat: implement page limit estimation and enforcement in file based connector indexers - Added a static method `estimate_pages_from_metadata` to `PageLimitService` for estimating page counts based on file metadata. - Integrated page limit checks in Google Drive, Dropbox, and OneDrive indexers to prevent exceeding user quotas during file indexing. - Updated relevant indexing methods to utilize the new page estimation logic and enforce limits accordingly. - Enhanced tests for page limit functionality, ensuring accurate estimation and enforcement across different file types. --- .../app/services/page_limit_service.py | 239 +++---- .../app/tasks/connector_indexers/base.py | 1 - .../connector_indexers/dropbox_indexer.py | 52 ++ .../google_drive_indexer.py | 88 +++ .../connector_indexers/onedrive_indexer.py | 80 +++ .../integration/document_upload/conftest.py | 63 ++ .../test_google_drive_parallel.py | 27 +- .../connector_indexers/test_page_limits.py | 648 ++++++++++++++++++ 8 files changed, 1041 insertions(+), 157 deletions(-) create mode 100644 surfsense_backend/tests/unit/connector_indexers/test_page_limits.py diff --git a/surfsense_backend/app/services/page_limit_service.py b/surfsense_backend/app/services/page_limit_service.py index 080d05b5d..ea22067be 100644 --- a/surfsense_backend/app/services/page_limit_service.py +++ b/surfsense_backend/app/services/page_limit_service.py @@ -3,7 +3,7 @@ Service for managing user page limits for ETL services. 
""" import os -from pathlib import Path +from pathlib import Path, PurePosixPath from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -223,10 +223,91 @@ class PageLimitService: # Estimate ~2000 characters per page return max(1, content_length // 2000) + @staticmethod + def estimate_pages_from_metadata( + file_name_or_ext: str, file_size: int | str | None = None + ) -> int: + """Size-based page estimation from file name/extension and byte size. + + Pure function — no file I/O, no database access. Used by cloud + connectors (which only have API metadata) and as the internal + fallback for :meth:`estimate_pages_before_processing`. + + ``file_name_or_ext`` can be a full filename (``"report.pdf"``) or + a bare extension (``".pdf"``). ``file_size`` may be an int, a + stringified int from a cloud API, or *None*. + """ + if file_size is not None: + try: + file_size = int(file_size) + except (ValueError, TypeError): + file_size = 0 + else: + file_size = 0 + + if file_size <= 0: + return 1 + + ext = PurePosixPath(file_name_or_ext).suffix.lower() if file_name_or_ext else "" + if not ext and file_name_or_ext.startswith("."): + ext = file_name_or_ext.lower() + file_ext = ext + + if file_ext == ".pdf": + return max(1, file_size // (100 * 1024)) + + if file_ext in { + ".doc", ".docx", ".docm", ".dot", ".dotm", + ".odt", ".ott", ".sxw", ".stw", ".uot", + ".rtf", ".pages", ".wpd", ".wps", + ".abw", ".zabw", ".cwk", ".hwp", ".lwp", + ".mcw", ".mw", ".sdw", ".vor", + }: + return max(1, file_size // (50 * 1024)) + + if file_ext in { + ".ppt", ".pptx", ".pptm", ".pot", ".potx", + ".odp", ".otp", ".sxi", ".sti", ".uop", + ".key", ".sda", ".sdd", ".sdp", + }: + return max(1, file_size // (200 * 1024)) + + if file_ext in { + ".xls", ".xlsx", ".xlsm", ".xlsb", ".xlw", ".xlr", + ".ods", ".ots", ".fods", ".numbers", + ".123", ".wk1", ".wk2", ".wk3", ".wk4", ".wks", + ".wb1", ".wb2", ".wb3", ".wq1", ".wq2", + ".csv", ".tsv", ".slk", ".sylk", ".dif", ".dbf", + 
".prn", ".qpw", ".602", ".et", ".eth", + }: + return max(1, file_size // (100 * 1024)) + + if file_ext in {".epub"}: + return max(1, file_size // (50 * 1024)) + + if file_ext in {".txt", ".log", ".md", ".markdown", ".htm", ".html", ".xml"}: + return max(1, file_size // 3000) + + if file_ext in { + ".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", + ".webp", ".svg", ".cgm", ".odg", ".pbd", + }: + return 1 + + if file_ext in {".mp3", ".m4a", ".wav", ".mpga"}: + return max(1, file_size // (1024 * 1024)) + + if file_ext in {".mp4", ".mpeg", ".webm"}: + return max(1, file_size // (5 * 1024 * 1024)) + + return max(1, file_size // (80 * 1024)) + def estimate_pages_before_processing(self, file_path: str) -> int: """ - Estimate page count from file before processing (to avoid unnecessary API calls). - This is called BEFORE sending to ETL services to prevent cost on rejected files. + Estimate page count from a local file before processing. + + For PDFs, attempts to read the actual page count via pypdf. + For everything else, delegates to :meth:`estimate_pages_from_metadata`. Args: file_path: Path to the file @@ -240,7 +321,6 @@ class PageLimitService: file_ext = Path(file_path).suffix.lower() file_size = os.path.getsize(file_path) - # PDF files - try to get actual page count if file_ext == ".pdf": try: import pypdf @@ -249,153 +329,6 @@ class PageLimitService: pdf_reader = pypdf.PdfReader(f) return len(pdf_reader.pages) except Exception: - # If PDF reading fails, fall back to size estimation - # Typical PDF: ~100KB per page (conservative estimate) - return max(1, file_size // (100 * 1024)) + pass # fall through to size-based estimation - # Word Processing Documents - # Microsoft Word, LibreOffice Writer, WordPerfect, Pages, etc. 
- elif file_ext in [ - ".doc", - ".docx", - ".docm", - ".dot", - ".dotm", # Microsoft Word - ".odt", - ".ott", - ".sxw", - ".stw", - ".uot", # OpenDocument/StarOffice Writer - ".rtf", # Rich Text Format - ".pages", # Apple Pages - ".wpd", - ".wps", # WordPerfect, Microsoft Works - ".abw", - ".zabw", # AbiWord - ".cwk", - ".hwp", - ".lwp", - ".mcw", - ".mw", - ".sdw", - ".vor", # Other word processors - ]: - # Typical word document: ~50KB per page (conservative) - return max(1, file_size // (50 * 1024)) - - # Presentation Documents - # PowerPoint, Impress, Keynote, etc. - elif file_ext in [ - ".ppt", - ".pptx", - ".pptm", - ".pot", - ".potx", # Microsoft PowerPoint - ".odp", - ".otp", - ".sxi", - ".sti", - ".uop", # OpenDocument/StarOffice Impress - ".key", # Apple Keynote - ".sda", - ".sdd", - ".sdp", # StarOffice Draw/Impress - ]: - # Typical presentation: ~200KB per slide (conservative) - return max(1, file_size // (200 * 1024)) - - # Spreadsheet Documents - # Excel, Calc, Numbers, Lotus, etc. 
- elif file_ext in [ - ".xls", - ".xlsx", - ".xlsm", - ".xlsb", - ".xlw", - ".xlr", # Microsoft Excel - ".ods", - ".ots", - ".fods", # OpenDocument Spreadsheet - ".numbers", # Apple Numbers - ".123", - ".wk1", - ".wk2", - ".wk3", - ".wk4", - ".wks", # Lotus 1-2-3 - ".wb1", - ".wb2", - ".wb3", - ".wq1", - ".wq2", # Quattro Pro - ".csv", - ".tsv", - ".slk", - ".sylk", - ".dif", - ".dbf", - ".prn", - ".qpw", # Data formats - ".602", - ".et", - ".eth", # Other spreadsheets - ]: - # Spreadsheets typically have 1 sheet = 1 page for ETL - # Conservative: ~100KB per sheet - return max(1, file_size // (100 * 1024)) - - # E-books - elif file_ext in [".epub"]: - # E-books vary widely, estimate by size - # Typical e-book: ~50KB per page - return max(1, file_size // (50 * 1024)) - - # Plain Text and Markup Files - elif file_ext in [ - ".txt", - ".log", # Plain text - ".md", - ".markdown", # Markdown - ".htm", - ".html", - ".xml", # Markup - ]: - # Plain text: ~3000 bytes per page - return max(1, file_size // 3000) - - # Image Files - # Each image is typically processed as 1 page - elif file_ext in [ - ".jpg", - ".jpeg", # JPEG - ".png", # PNG - ".gif", # GIF - ".bmp", # Bitmap - ".tiff", # TIFF - ".webp", # WebP - ".svg", # SVG - ".cgm", # Computer Graphics Metafile - ".odg", - ".pbd", # OpenDocument Graphics - ]: - # Each image = 1 page - return 1 - - # Audio Files (transcription = typically 1 page per minute) - # Note: These should be handled by audio transcription flow, not ETL - elif file_ext in [".mp3", ".m4a", ".wav", ".mpga"]: - # Audio files: estimate based on duration - # Fallback: ~1MB per minute of audio, 1 page per minute transcript - return max(1, file_size // (1024 * 1024)) - - # Video Files (typically not processed for pages, but just in case) - elif file_ext in [".mp4", ".mpeg", ".webm"]: - # Video files: very rough estimate - # Typically wouldn't be page-based, but use conservative estimate - return max(1, file_size // (5 * 1024 * 1024)) - - # Other/Unknown 
Document Types - else: - # Conservative estimate: ~80KB per page - # This catches: .sgl, .sxg, .uof, .uos1, .uos2, .web, and any future formats - return max(1, file_size // (80 * 1024)) + return self.estimate_pages_from_metadata(file_ext, file_size) diff --git a/surfsense_backend/app/tasks/connector_indexers/base.py b/surfsense_backend/app/tasks/connector_indexers/base.py index ffc8ab72e..6b4bed4b5 100644 --- a/surfsense_backend/app/tasks/connector_indexers/base.py +++ b/surfsense_backend/app/tasks/connector_indexers/base.py @@ -4,7 +4,6 @@ Base functionality and shared imports for connector indexers. import logging from datetime import UTC, datetime, timedelta - from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select diff --git a/surfsense_backend/app/tasks/connector_indexers/dropbox_indexer.py b/surfsense_backend/app/tasks/connector_indexers/dropbox_indexer.py index 1b039add7..87b3c55df 100644 --- a/surfsense_backend/app/tasks/connector_indexers/dropbox_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/dropbox_indexer.py @@ -28,6 +28,7 @@ from app.indexing_pipeline.connector_document import ConnectorDocument from app.indexing_pipeline.document_hashing import compute_identifier_hash from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService from app.services.llm_service import get_user_long_context_llm +from app.services.page_limit_service import PageLimitService from app.services.task_logging_service import TaskLoggingService from app.tasks.connector_indexers.base import ( check_document_by_unique_identifier, @@ -278,6 +279,12 @@ async def _index_full_scan( }, ) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used + batch_estimated_pages = 0 + page_limit_reached = False + renamed_count = 0 skipped = 0 files_to_download: list[dict] = [] @@ -307,6 +314,21 @@ async def 
_index_full_scan( elif skip_item(file): skipped += 1 continue + + file_pages = PageLimitService.estimate_pages_from_metadata( + file.get("name", ""), file.get("size") + ) + if batch_estimated_pages + file_pages > remaining_quota: + if not page_limit_reached: + logger.warning( + "Page limit reached during Dropbox full scan, " + "skipping remaining files" + ) + page_limit_reached = True + skipped += 1 + continue + + batch_estimated_pages += file_pages files_to_download.append(file) batch_indexed, failed = await _download_and_index( @@ -320,6 +342,14 @@ async def _index_full_scan( on_heartbeat=on_heartbeat_callback, ) + if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0: + pages_to_deduct = max( + 1, batch_estimated_pages * batch_indexed // len(files_to_download) + ) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) + indexed = renamed_count + batch_indexed logger.info( f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed" @@ -340,6 +370,11 @@ async def _index_selected_files( on_heartbeat: HeartbeatCallbackType | None = None, ) -> tuple[int, int, list[str]]: """Index user-selected files using the parallel pipeline.""" + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used + batch_estimated_pages = 0 + files_to_download: list[dict] = [] errors: list[str] = [] renamed_count = 0 @@ -364,6 +399,15 @@ async def _index_selected_files( skipped += 1 continue + file_pages = PageLimitService.estimate_pages_from_metadata( + file.get("name", ""), file.get("size") + ) + if batch_estimated_pages + file_pages > remaining_quota: + display = file_name or file_path + errors.append(f"File '{display}': page limit would be exceeded") + continue + + batch_estimated_pages += file_pages files_to_download.append(file) batch_indexed, _failed = await _download_and_index( @@ -377,6 +421,14 @@ 
async def _index_selected_files( on_heartbeat=on_heartbeat, ) + if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0: + pages_to_deduct = max( + 1, batch_estimated_pages * batch_indexed // len(files_to_download) + ) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) + return renamed_count + batch_indexed, skipped, errors diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py index b03d305f7..5e9e0f62f 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py @@ -34,6 +34,7 @@ from app.indexing_pipeline.indexing_pipeline_service import ( PlaceholderInfo, ) from app.services.llm_service import get_user_long_context_llm +from app.services.page_limit_service import PageLimitService from app.services.task_logging_service import TaskLoggingService from app.tasks.connector_indexers.base import ( check_document_by_unique_identifier, @@ -327,6 +328,12 @@ async def _process_single_file( return 1, 0, 0 return 0, 1, 0 + page_limit_service = PageLimitService(session) + estimated_pages = PageLimitService.estimate_pages_from_metadata( + file_name, file.get("size") + ) + await page_limit_service.check_page_limit(user_id, estimated_pages) + markdown, drive_metadata, error = await download_and_extract_content( drive_client, file ) @@ -363,6 +370,9 @@ async def _process_single_file( ) await pipeline.index(document, connector_doc, user_llm) + await page_limit_service.update_page_usage( + user_id, estimated_pages, allow_exceed=True + ) logger.info(f"Successfully indexed Google Drive file: {file_name}") return 1, 0, 0 @@ -466,6 +476,11 @@ async def _index_selected_files( Returns (indexed_count, skipped_count, errors). 
""" + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used + batch_estimated_pages = 0 + files_to_download: list[dict] = [] errors: list[str] = [] renamed_count = 0 @@ -486,6 +501,15 @@ async def _index_selected_files( skipped += 1 continue + file_pages = PageLimitService.estimate_pages_from_metadata( + file.get("name", ""), file.get("size") + ) + if batch_estimated_pages + file_pages > remaining_quota: + display = file_name or file_id + errors.append(f"File '{display}': page limit would be exceeded") + continue + + batch_estimated_pages += file_pages files_to_download.append(file) await _create_drive_placeholders( @@ -507,6 +531,14 @@ async def _index_selected_files( on_heartbeat=on_heartbeat, ) + if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0: + pages_to_deduct = max( + 1, batch_estimated_pages * batch_indexed // len(files_to_download) + ) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) + return renamed_count + batch_indexed, skipped, errors @@ -545,6 +577,12 @@ async def _index_full_scan( # ------------------------------------------------------------------ # Phase 1 (serial): collect files, run skip checks, track renames # ------------------------------------------------------------------ + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used + batch_estimated_pages = 0 + page_limit_reached = False + renamed_count = 0 skipped = 0 files_processed = 0 @@ -593,6 +631,20 @@ async def _index_full_scan( skipped += 1 continue + file_pages = PageLimitService.estimate_pages_from_metadata( + file.get("name", ""), file.get("size") + ) + if batch_estimated_pages + file_pages > remaining_quota: + if not page_limit_reached: + logger.warning( + "Page limit reached during 
Google Drive full scan, " + "skipping remaining files" + ) + page_limit_reached = True + skipped += 1 + continue + + batch_estimated_pages += file_pages files_to_download.append(file) page_token = next_token @@ -636,6 +688,14 @@ async def _index_full_scan( on_heartbeat=on_heartbeat_callback, ) + if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0: + pages_to_deduct = max( + 1, batch_estimated_pages * batch_indexed // len(files_to_download) + ) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) + indexed = renamed_count + batch_indexed logger.info( f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed" @@ -686,6 +746,12 @@ async def _index_with_delta_sync( # ------------------------------------------------------------------ # Phase 1 (serial): handle removals, collect files for download # ------------------------------------------------------------------ + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used + batch_estimated_pages = 0 + page_limit_reached = False + renamed_count = 0 skipped = 0 files_to_download: list[dict] = [] @@ -715,6 +781,20 @@ async def _index_with_delta_sync( skipped += 1 continue + file_pages = PageLimitService.estimate_pages_from_metadata( + file.get("name", ""), file.get("size") + ) + if batch_estimated_pages + file_pages > remaining_quota: + if not page_limit_reached: + logger.warning( + "Page limit reached during Google Drive delta sync, " + "skipping remaining files" + ) + page_limit_reached = True + skipped += 1 + continue + + batch_estimated_pages += file_pages files_to_download.append(file) # ------------------------------------------------------------------ @@ -742,6 +822,14 @@ async def _index_with_delta_sync( on_heartbeat=on_heartbeat_callback, ) + if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0: + 
pages_to_deduct = max( + 1, batch_estimated_pages * batch_indexed // len(files_to_download) + ) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) + indexed = renamed_count + batch_indexed logger.info( f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed" diff --git a/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py index 748cb0988..2301b6260 100644 --- a/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py @@ -28,6 +28,7 @@ from app.indexing_pipeline.connector_document import ConnectorDocument from app.indexing_pipeline.document_hashing import compute_identifier_hash from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService from app.services.llm_service import get_user_long_context_llm +from app.services.page_limit_service import PageLimitService from app.services.task_logging_service import TaskLoggingService from app.tasks.connector_indexers.base import ( check_document_by_unique_identifier, @@ -291,6 +292,11 @@ async def _index_selected_files( on_heartbeat: HeartbeatCallbackType | None = None, ) -> tuple[int, int, list[str]]: """Index user-selected files using the parallel pipeline.""" + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used + batch_estimated_pages = 0 + files_to_download: list[dict] = [] errors: list[str] = [] renamed_count = 0 @@ -311,6 +317,15 @@ async def _index_selected_files( skipped += 1 continue + file_pages = PageLimitService.estimate_pages_from_metadata( + file.get("name", ""), file.get("size") + ) + if batch_estimated_pages + file_pages > remaining_quota: + display = file_name or file_id + errors.append(f"File '{display}': page limit would be exceeded") + 
continue + + batch_estimated_pages += file_pages files_to_download.append(file) batch_indexed, _failed = await _download_and_index( @@ -324,6 +339,14 @@ async def _index_selected_files( on_heartbeat=on_heartbeat, ) + if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0: + pages_to_deduct = max( + 1, batch_estimated_pages * batch_indexed // len(files_to_download) + ) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) + return renamed_count + batch_indexed, skipped, errors @@ -358,6 +381,12 @@ async def _index_full_scan( }, ) + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used + batch_estimated_pages = 0 + page_limit_reached = False + renamed_count = 0 skipped = 0 files_to_download: list[dict] = [] @@ -383,6 +412,21 @@ async def _index_full_scan( else: skipped += 1 continue + + file_pages = PageLimitService.estimate_pages_from_metadata( + file.get("name", ""), file.get("size") + ) + if batch_estimated_pages + file_pages > remaining_quota: + if not page_limit_reached: + logger.warning( + "Page limit reached during OneDrive full scan, " + "skipping remaining files" + ) + page_limit_reached = True + skipped += 1 + continue + + batch_estimated_pages += file_pages files_to_download.append(file) batch_indexed, failed = await _download_and_index( @@ -396,6 +440,14 @@ async def _index_full_scan( on_heartbeat=on_heartbeat_callback, ) + if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0: + pages_to_deduct = max( + 1, batch_estimated_pages * batch_indexed // len(files_to_download) + ) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) + indexed = renamed_count + batch_indexed logger.info( f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed" @@ -441,6 +493,12 @@ async def _index_with_delta_sync( 
logger.info(f"Processing {len(changes)} delta changes") + page_limit_service = PageLimitService(session) + pages_used, pages_limit = await page_limit_service.get_page_usage(user_id) + remaining_quota = pages_limit - pages_used + batch_estimated_pages = 0 + page_limit_reached = False + renamed_count = 0 skipped = 0 files_to_download: list[dict] = [] @@ -471,6 +529,20 @@ async def _index_with_delta_sync( skipped += 1 continue + file_pages = PageLimitService.estimate_pages_from_metadata( + change.get("name", ""), change.get("size") + ) + if batch_estimated_pages + file_pages > remaining_quota: + if not page_limit_reached: + logger.warning( + "Page limit reached during OneDrive delta sync, " + "skipping remaining files" + ) + page_limit_reached = True + skipped += 1 + continue + + batch_estimated_pages += file_pages files_to_download.append(change) batch_indexed, failed = await _download_and_index( @@ -484,6 +556,14 @@ async def _index_with_delta_sync( on_heartbeat=on_heartbeat_callback, ) + if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0: + pages_to_deduct = max( + 1, batch_estimated_pages * batch_indexed // len(files_to_download) + ) + await page_limit_service.update_page_usage( + user_id, pages_to_deduct, allow_exceed=True + ) + indexed = renamed_count + batch_indexed logger.info( f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed" diff --git a/surfsense_backend/tests/integration/document_upload/conftest.py b/surfsense_backend/tests/integration/document_upload/conftest.py index 1f1c7df59..6640fefdf 100644 --- a/surfsense_backend/tests/integration/document_upload/conftest.py +++ b/surfsense_backend/tests/integration/document_upload/conftest.py @@ -3,6 +3,7 @@ Prerequisites: PostgreSQL + pgvector only. 
External system boundaries are mocked: + - ETL parsing — LlamaParse (external API) and Docling (heavy library) - LLM summarization, text embedding, text chunking (external APIs) - Redis heartbeat (external infrastructure) - Task dispatch is swapped via DI (InlineTaskDispatcher) @@ -11,6 +12,7 @@ External system boundaries are mocked: from __future__ import annotations import contextlib +import os from collections.abc import AsyncGenerator from unittest.mock import AsyncMock, MagicMock @@ -298,3 +300,64 @@ def _mock_redis_heartbeat(monkeypatch): "app.tasks.celery_tasks.document_tasks._run_heartbeat_loop", AsyncMock(), ) + + +@pytest.fixture(autouse=True) +def _mock_etl_parsing(monkeypatch): + """Mock ETL parsing services — LlamaParse and Docling are external boundaries. + + Preserves the real contract: empty/corrupt files raise an error just like + the actual services would, so tests covering failure paths keep working. + """ + + _MOCK_MARKDOWN = "# Mocked Document\n\nThis is mocked ETL content." 
+ + def _reject_empty(file_path: str) -> None: + if os.path.getsize(file_path) == 0: + raise RuntimeError(f"Cannot parse empty file: {file_path}") + + # -- LlamaParse mock (external API) -------------------------------- + + class _FakeMarkdownDoc: + def __init__(self, text: str): + self.text = text + + class _FakeLlamaParseResult: + async def aget_markdown_documents(self, *, split_by_page=False): + return [_FakeMarkdownDoc(_MOCK_MARKDOWN)] + + async def _fake_llamacloud_parse(**kwargs): + _reject_empty(kwargs["file_path"]) + return _FakeLlamaParseResult() + + monkeypatch.setattr( + "app.tasks.document_processors.file_processors.parse_with_llamacloud_retry", + _fake_llamacloud_parse, + ) + + # -- Docling mock (heavy library boundary) ------------------------- + + async def _fake_docling_parse(file_path: str, filename: str): + _reject_empty(file_path) + return _MOCK_MARKDOWN + + monkeypatch.setattr( + "app.tasks.document_processors.file_processors.parse_with_docling", + _fake_docling_parse, + ) + + class _FakeDoclingResult: + class document: + @staticmethod + def export_to_markdown(): + return _MOCK_MARKDOWN + + class _FakeDocumentConverter: + def convert(self, file_path): + _reject_empty(file_path) + return _FakeDoclingResult() + + monkeypatch.setattr( + "docling.document_converter.DocumentConverter", + _FakeDocumentConverter, + ) diff --git a/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py b/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py index 3fe8a183d..20bd3f3d6 100644 --- a/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py +++ b/surfsense_backend/tests/unit/connector_indexers/test_google_drive_parallel.py @@ -248,12 +248,33 @@ def _folder_dict(file_id: str, name: str) -> dict: } +def _make_page_limit_session(pages_used=0, pages_limit=999_999): + """Build a mock DB session that real PageLimitService can operate against.""" + + class _FakeUser: + def __init__(self, pu, pl): + 
self.pages_used = pu + self.pages_limit = pl + + fake_user = _FakeUser(pages_used, pages_limit) + session = AsyncMock() + + def _make_result(*_a, **_kw): + r = MagicMock() + r.first.return_value = (fake_user.pages_used, fake_user.pages_limit) + r.unique.return_value.scalar_one_or_none.return_value = fake_user + return r + + session.execute = AsyncMock(side_effect=_make_result) + return session, fake_user + + @pytest.fixture def full_scan_mocks(mock_drive_client, monkeypatch): """Wire up all mocks needed to call _index_full_scan in isolation.""" import app.tasks.connector_indexers.google_drive_indexer as _mod - mock_session = AsyncMock() + mock_session, _ = _make_page_limit_session() mock_connector = MagicMock() mock_task_logger = MagicMock() mock_task_logger.log_task_progress = AsyncMock() @@ -472,7 +493,7 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch): AsyncMock(return_value=MagicMock()), ) - mock_session = AsyncMock() + mock_session, _ = _make_page_limit_session() mock_task_logger = MagicMock() mock_task_logger.log_task_progress = AsyncMock() @@ -512,7 +533,7 @@ def selected_files_mocks(mock_drive_client, monkeypatch): """Wire up mocks for _index_selected_files tests.""" import app.tasks.connector_indexers.google_drive_indexer as _mod - mock_session = AsyncMock() + mock_session, _ = _make_page_limit_session() get_file_results: dict[str, tuple[dict | None, str | None]] = {} diff --git a/surfsense_backend/tests/unit/connector_indexers/test_page_limits.py b/surfsense_backend/tests/unit/connector_indexers/test_page_limits.py new file mode 100644 index 000000000..1c93965f3 --- /dev/null +++ b/surfsense_backend/tests/unit/connector_indexers/test_page_limits.py @@ -0,0 +1,648 @@ +"""Tests for page limit enforcement in connector indexers. 
+ +Covers: + A) PageLimitService.estimate_pages_from_metadata — pure function (no mocks) + B) Page-limit quota gating in _index_selected_files tested through the + real PageLimitService with a mock DB session (system boundary). + Google Drive is the primary, with OneDrive/Dropbox smoke tests. +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.services.page_limit_service import PageLimitService + +pytestmark = pytest.mark.unit + +_USER_ID = "00000000-0000-0000-0000-000000000001" +_CONNECTOR_ID = 42 +_SEARCH_SPACE_ID = 1 + + +# =================================================================== +# A) PageLimitService.estimate_pages_from_metadata — pure function +# No mocks: it's a staticmethod with no I/O. +# =================================================================== + + +class TestEstimatePagesFromMetadata: + """Vertical slices for the page estimation staticmethod.""" + + def test_pdf_100kb_returns_1(self): + assert PageLimitService.estimate_pages_from_metadata(".pdf", 100 * 1024) == 1 + + def test_pdf_500kb_returns_5(self): + assert PageLimitService.estimate_pages_from_metadata(".pdf", 500 * 1024) == 5 + + def test_pdf_1mb(self): + assert PageLimitService.estimate_pages_from_metadata(".pdf", 1024 * 1024) == 10 + + def test_docx_50kb_returns_1(self): + assert PageLimitService.estimate_pages_from_metadata(".docx", 50 * 1024) == 1 + + def test_docx_200kb(self): + assert PageLimitService.estimate_pages_from_metadata(".docx", 200 * 1024) == 4 + + def test_pptx_uses_200kb_per_page(self): + assert PageLimitService.estimate_pages_from_metadata(".pptx", 600 * 1024) == 3 + + def test_xlsx_uses_100kb_per_page(self): + assert PageLimitService.estimate_pages_from_metadata(".xlsx", 300 * 1024) == 3 + + def test_txt_uses_3000_bytes_per_page(self): + assert PageLimitService.estimate_pages_from_metadata(".txt", 9000) == 3 + + def test_image_always_returns_1(self): + for ext in (".jpg", ".png", ".gif", ".webp"): + assert 
PageLimitService.estimate_pages_from_metadata(ext, 5_000_000) == 1 + + def test_audio_uses_1mb_per_page(self): + assert PageLimitService.estimate_pages_from_metadata(".mp3", 3 * 1024 * 1024) == 3 + + def test_video_uses_5mb_per_page(self): + assert PageLimitService.estimate_pages_from_metadata(".mp4", 15 * 1024 * 1024) == 3 + + def test_unknown_ext_uses_80kb_per_page(self): + assert PageLimitService.estimate_pages_from_metadata(".xyz", 160 * 1024) == 2 + + def test_zero_size_returns_1(self): + assert PageLimitService.estimate_pages_from_metadata(".pdf", 0) == 1 + + def test_negative_size_returns_1(self): + assert PageLimitService.estimate_pages_from_metadata(".pdf", -500) == 1 + + def test_minimum_is_always_1(self): + assert PageLimitService.estimate_pages_from_metadata(".pdf", 50) == 1 + + def test_epub_uses_50kb_per_page(self): + assert PageLimitService.estimate_pages_from_metadata(".epub", 250 * 1024) == 5 + + +# =================================================================== +# B) Page-limit enforcement in connector indexers +# System boundary mocked: DB session (for PageLimitService) +# System boundary mocked: external API clients, download/ETL +# NOT mocked: PageLimitService itself (our own code) +# =================================================================== + + +class _FakeUser: + """Stands in for the User ORM model at the DB boundary.""" + + def __init__(self, pages_used: int = 0, pages_limit: int = 100): + self.pages_used = pages_used + self.pages_limit = pages_limit + + +def _make_page_limit_session(pages_used: int = 0, pages_limit: int = 100): + """Build a mock DB session that real PageLimitService can operate against. + + Every ``session.execute()`` returns a result compatible with both + ``get_page_usage`` (.first() → tuple) and ``update_page_usage`` + (.unique().scalar_one_or_none() → User-like). 
+ """ + fake_user = _FakeUser(pages_used, pages_limit) + session = AsyncMock() + + def _make_result(*_args, **_kwargs): + result = MagicMock() + result.first.return_value = (fake_user.pages_used, fake_user.pages_limit) + result.unique.return_value.scalar_one_or_none.return_value = fake_user + return result + + session.execute = AsyncMock(side_effect=_make_result) + return session, fake_user + + +def _make_gdrive_file(file_id: str, name: str, size: int = 80 * 1024) -> dict: + return { + "id": file_id, + "name": name, + "mimeType": "application/octet-stream", + "size": str(size), + } + + +# --------------------------------------------------------------------------- +# Google Drive: _index_selected_files +# --------------------------------------------------------------------------- + + +@pytest.fixture +def gdrive_selected_mocks(monkeypatch): + """Mocks for Google Drive _index_selected_files — only system boundaries.""" + import app.tasks.connector_indexers.google_drive_indexer as _mod + + session, fake_user = _make_page_limit_session(0, 100) + + get_file_results: dict[str, tuple[dict | None, str | None]] = {} + + async def _fake_get_file(client, file_id): + return get_file_results.get(file_id, (None, f"Not configured: {file_id}")) + + monkeypatch.setattr(_mod, "get_file_by_id", _fake_get_file) + monkeypatch.setattr( + _mod, "_should_skip_file", AsyncMock(return_value=(False, None)) + ) + + download_and_index_mock = AsyncMock(return_value=(0, 0)) + monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock) + + pipeline_mock = MagicMock() + pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0) + monkeypatch.setattr( + _mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock) + ) + + return { + "mod": _mod, + "session": session, + "fake_user": fake_user, + "get_file_results": get_file_results, + "download_and_index_mock": download_and_index_mock, + } + + +async def _run_gdrive_selected(mocks, file_ids): + from 
app.tasks.connector_indexers.google_drive_indexer import ( + _index_selected_files, + ) + + return await _index_selected_files( + MagicMock(), + mocks["session"], + file_ids, + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=True, + ) + + +async def test_gdrive_files_within_quota_are_downloaded(gdrive_selected_mocks): + """Files whose cumulative estimated pages fit within remaining quota + are sent to _download_and_index.""" + m = gdrive_selected_mocks + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 100 + + for fid in ("f1", "f2", "f3"): + m["get_file_results"][fid] = ( + _make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024), None, + ) + m["download_and_index_mock"].return_value = (3, 0) + + indexed, _skipped, errors = await _run_gdrive_selected( + m, [("f1", "f1.xyz"), ("f2", "f2.xyz"), ("f3", "f3.xyz")] + ) + + assert indexed == 3 + assert errors == [] + call_files = m["download_and_index_mock"].call_args[0][2] + assert len(call_files) == 3 + + +async def test_gdrive_files_exceeding_quota_rejected(gdrive_selected_mocks): + """Files whose pages would exceed remaining quota are rejected.""" + m = gdrive_selected_mocks + m["fake_user"].pages_used = 98 + m["fake_user"].pages_limit = 100 + + m["get_file_results"]["big"] = ( + _make_gdrive_file("big", "huge.pdf", size=500 * 1024), None, + ) + + indexed, _skipped, errors = await _run_gdrive_selected(m, [("big", "huge.pdf")]) + + assert indexed == 0 + assert len(errors) == 1 + assert "page limit" in errors[0].lower() + + +async def test_gdrive_quota_mix_partial_indexing(gdrive_selected_mocks): + """3rd file pushes over quota → only first two indexed.""" + m = gdrive_selected_mocks + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 2 + + for fid in ("f1", "f2", "f3"): + m["get_file_results"][fid] = ( + _make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024), None, + ) + m["download_and_index_mock"].return_value = (2, 0) + + indexed, _skipped, 
errors = await _run_gdrive_selected( + m, [("f1", "f1.xyz"), ("f2", "f2.xyz"), ("f3", "f3.xyz")] + ) + + assert indexed == 2 + assert len(errors) == 1 + call_files = m["download_and_index_mock"].call_args[0][2] + assert {f["id"] for f in call_files} == {"f1", "f2"} + + +async def test_gdrive_proportional_page_deduction(gdrive_selected_mocks): + """Pages deducted are proportional to successfully indexed files.""" + m = gdrive_selected_mocks + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 100 + + for fid in ("f1", "f2", "f3", "f4"): + m["get_file_results"][fid] = ( + _make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024), None, + ) + m["download_and_index_mock"].return_value = (2, 2) + + await _run_gdrive_selected( + m, + [("f1", "f1.xyz"), ("f2", "f2.xyz"), ("f3", "f3.xyz"), ("f4", "f4.xyz")], + ) + + assert m["fake_user"].pages_used == 2 + + +async def test_gdrive_no_deduction_when_nothing_indexed(gdrive_selected_mocks): + """If batch_indexed == 0, user's pages_used stays unchanged.""" + m = gdrive_selected_mocks + m["fake_user"].pages_used = 5 + m["fake_user"].pages_limit = 100 + + m["get_file_results"]["f1"] = ( + _make_gdrive_file("f1", "f1.xyz", size=80 * 1024), None, + ) + m["download_and_index_mock"].return_value = (0, 1) + + await _run_gdrive_selected(m, [("f1", "f1.xyz")]) + + assert m["fake_user"].pages_used == 5 + + +async def test_gdrive_zero_quota_rejects_all(gdrive_selected_mocks): + """When pages_used == pages_limit, every file is rejected.""" + m = gdrive_selected_mocks + m["fake_user"].pages_used = 100 + m["fake_user"].pages_limit = 100 + + for fid in ("f1", "f2"): + m["get_file_results"][fid] = ( + _make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024), None, + ) + + indexed, _skipped, errors = await _run_gdrive_selected( + m, [("f1", "f1.xyz"), ("f2", "f2.xyz")] + ) + + assert indexed == 0 + assert len(errors) == 2 + + +# --------------------------------------------------------------------------- +# Google Drive: _index_full_scan +# 
--------------------------------------------------------------------------- + + +@pytest.fixture +def gdrive_full_scan_mocks(monkeypatch): + import app.tasks.connector_indexers.google_drive_indexer as _mod + + session, fake_user = _make_page_limit_session(0, 100) + mock_task_logger = MagicMock() + mock_task_logger.log_task_progress = AsyncMock() + + monkeypatch.setattr( + _mod, "_should_skip_file", AsyncMock(return_value=(False, None)) + ) + + download_mock = AsyncMock(return_value=([], 0)) + monkeypatch.setattr(_mod, "_download_files_parallel", download_mock) + + batch_mock = AsyncMock(return_value=([], 0, 0)) + pipeline_mock = MagicMock() + pipeline_mock.index_batch_parallel = batch_mock + pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0) + monkeypatch.setattr( + _mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock) + ) + monkeypatch.setattr( + _mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()) + ) + + return { + "mod": _mod, + "session": session, + "fake_user": fake_user, + "task_logger": mock_task_logger, + "download_mock": download_mock, + "batch_mock": batch_mock, + } + + +async def _run_gdrive_full_scan(mocks, max_files=500): + from app.tasks.connector_indexers.google_drive_indexer import _index_full_scan + + return await _index_full_scan( + MagicMock(), + mocks["session"], + MagicMock(), + _CONNECTOR_ID, + _SEARCH_SPACE_ID, + _USER_ID, + "folder-root", + "My Folder", + mocks["task_logger"], + MagicMock(), + max_files, + include_subfolders=False, + enable_summary=True, + ) + + +async def test_gdrive_full_scan_skips_over_quota(gdrive_full_scan_mocks, monkeypatch): + m = gdrive_full_scan_mocks + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 2 + + page_files = [ + _make_gdrive_file(f"f{i}", f"file{i}.xyz", size=80 * 1024) for i in range(5) + ] + monkeypatch.setattr( + m["mod"], "get_files_in_folder", + AsyncMock(return_value=(page_files, None, None)), + ) + m["download_mock"].return_value 
= ([], 0) + m["batch_mock"].return_value = ([], 2, 0) + + _indexed, skipped = await _run_gdrive_full_scan(m) + + call_files = m["download_mock"].call_args[0][1] + assert len(call_files) == 2 + assert skipped == 3 + + +async def test_gdrive_full_scan_deducts_after_indexing( + gdrive_full_scan_mocks, monkeypatch +): + m = gdrive_full_scan_mocks + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 100 + + page_files = [ + _make_gdrive_file(f"f{i}", f"file{i}.xyz", size=80 * 1024) for i in range(3) + ] + monkeypatch.setattr( + m["mod"], "get_files_in_folder", + AsyncMock(return_value=(page_files, None, None)), + ) + mock_docs = [MagicMock() for _ in range(3)] + m["download_mock"].return_value = (mock_docs, 0) + m["batch_mock"].return_value = ([], 3, 0) + + await _run_gdrive_full_scan(m) + + assert m["fake_user"].pages_used == 3 + + +# --------------------------------------------------------------------------- +# Google Drive: _index_with_delta_sync +# --------------------------------------------------------------------------- + + +async def test_gdrive_delta_sync_skips_over_quota(monkeypatch): + import app.tasks.connector_indexers.google_drive_indexer as _mod + + session, _ = _make_page_limit_session(0, 2) + + changes = [ + { + "fileId": f"mod{i}", + "file": _make_gdrive_file(f"mod{i}", f"mod{i}.xyz", size=80 * 1024), + } + for i in range(5) + ] + monkeypatch.setattr( + _mod, "fetch_all_changes", + AsyncMock(return_value=(changes, "new-token", None)), + ) + monkeypatch.setattr(_mod, "categorize_change", lambda change: "modified") + monkeypatch.setattr( + _mod, "_should_skip_file", AsyncMock(return_value=(False, None)) + ) + + download_mock = AsyncMock(return_value=([], 0)) + monkeypatch.setattr(_mod, "_download_files_parallel", download_mock) + + batch_mock = AsyncMock(return_value=([], 2, 0)) + pipeline_mock = MagicMock() + pipeline_mock.index_batch_parallel = batch_mock + pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0) + 
monkeypatch.setattr( + _mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock) + ) + monkeypatch.setattr( + _mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()) + ) + + mock_task_logger = MagicMock() + mock_task_logger.log_task_progress = AsyncMock() + + _indexed, skipped = await _mod._index_with_delta_sync( + MagicMock(), session, MagicMock(), + _CONNECTOR_ID, _SEARCH_SPACE_ID, _USER_ID, + "folder-root", "start-token", + mock_task_logger, MagicMock(), + max_files=500, enable_summary=True, + ) + + call_files = download_mock.call_args[0][1] + assert len(call_files) == 2 + assert skipped == 3 + + +# =================================================================== +# C) OneDrive smoke tests — verify page limit wiring +# =================================================================== + + +def _make_onedrive_file(file_id: str, name: str, size: int = 80 * 1024) -> dict: + return { + "id": file_id, + "name": name, + "file": {"mimeType": "application/octet-stream"}, + "size": str(size), + "lastModifiedDateTime": "2026-01-01T00:00:00Z", + } + + +@pytest.fixture +def onedrive_selected_mocks(monkeypatch): + import app.tasks.connector_indexers.onedrive_indexer as _mod + + session, fake_user = _make_page_limit_session(0, 100) + + get_file_results: dict[str, tuple[dict | None, str | None]] = {} + + async def _fake_get_file(client, file_id): + return get_file_results.get(file_id, (None, f"Not found: {file_id}")) + + monkeypatch.setattr(_mod, "get_file_by_id", _fake_get_file) + monkeypatch.setattr( + _mod, "_should_skip_file", AsyncMock(return_value=(False, None)) + ) + + download_and_index_mock = AsyncMock(return_value=(0, 0)) + monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock) + + pipeline_mock = MagicMock() + pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0) + monkeypatch.setattr( + _mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock) + ) + + return { + "session": session, + 
"fake_user": fake_user, + "get_file_results": get_file_results, + "download_and_index_mock": download_and_index_mock, + } + + +async def _run_onedrive_selected(mocks, file_ids): + from app.tasks.connector_indexers.onedrive_indexer import _index_selected_files + + return await _index_selected_files( + MagicMock(), mocks["session"], file_ids, + connector_id=_CONNECTOR_ID, search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, enable_summary=True, + ) + + +async def test_onedrive_over_quota_rejected(onedrive_selected_mocks): + """OneDrive: files exceeding quota produce errors, not downloads.""" + m = onedrive_selected_mocks + m["fake_user"].pages_used = 99 + m["fake_user"].pages_limit = 100 + + m["get_file_results"]["big"] = ( + _make_onedrive_file("big", "huge.pdf", size=500 * 1024), None, + ) + + indexed, _skipped, errors = await _run_onedrive_selected(m, [("big", "huge.pdf")]) + + assert indexed == 0 + assert len(errors) == 1 + assert "page limit" in errors[0].lower() + + +async def test_onedrive_deducts_after_success(onedrive_selected_mocks): + """OneDrive: pages_used increases after successful indexing.""" + m = onedrive_selected_mocks + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 100 + + for fid in ("f1", "f2"): + m["get_file_results"][fid] = ( + _make_onedrive_file(fid, f"{fid}.xyz", size=80 * 1024), None, + ) + m["download_and_index_mock"].return_value = (2, 0) + + await _run_onedrive_selected(m, [("f1", "f1.xyz"), ("f2", "f2.xyz")]) + + assert m["fake_user"].pages_used == 2 + + +# =================================================================== +# D) Dropbox smoke tests — verify page limit wiring +# =================================================================== + + +def _make_dropbox_file(file_path: str, name: str, size: int = 80 * 1024) -> dict: + return { + "id": f"id:{file_path}", + "name": name, + ".tag": "file", + "path_lower": file_path, + "size": str(size), + "server_modified": "2026-01-01T00:00:00Z", + "content_hash": 
f"hash_{name}", + } + + +@pytest.fixture +def dropbox_selected_mocks(monkeypatch): + import app.tasks.connector_indexers.dropbox_indexer as _mod + + session, fake_user = _make_page_limit_session(0, 100) + + get_file_results: dict[str, tuple[dict | None, str | None]] = {} + + async def _fake_get_file(client, file_path): + return get_file_results.get(file_path, (None, f"Not found: {file_path}")) + + monkeypatch.setattr(_mod, "get_file_by_path", _fake_get_file) + monkeypatch.setattr( + _mod, "_should_skip_file", AsyncMock(return_value=(False, None)) + ) + + download_and_index_mock = AsyncMock(return_value=(0, 0)) + monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock) + + pipeline_mock = MagicMock() + pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0) + monkeypatch.setattr( + _mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock) + ) + + return { + "session": session, + "fake_user": fake_user, + "get_file_results": get_file_results, + "download_and_index_mock": download_and_index_mock, + } + + +async def _run_dropbox_selected(mocks, file_paths): + from app.tasks.connector_indexers.dropbox_indexer import _index_selected_files + + return await _index_selected_files( + MagicMock(), mocks["session"], file_paths, + connector_id=_CONNECTOR_ID, search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, enable_summary=True, + ) + + +async def test_dropbox_over_quota_rejected(dropbox_selected_mocks): + """Dropbox: files exceeding quota produce errors, not downloads.""" + m = dropbox_selected_mocks + m["fake_user"].pages_used = 99 + m["fake_user"].pages_limit = 100 + + m["get_file_results"]["/huge.pdf"] = ( + _make_dropbox_file("/huge.pdf", "huge.pdf", size=500 * 1024), None, + ) + + indexed, _skipped, errors = await _run_dropbox_selected( + m, [("/huge.pdf", "huge.pdf")] + ) + + assert indexed == 0 + assert len(errors) == 1 + assert "page limit" in errors[0].lower() + + +async def 
test_dropbox_deducts_after_success(dropbox_selected_mocks): + """Dropbox: pages_used increases after successful indexing.""" + m = dropbox_selected_mocks + m["fake_user"].pages_used = 0 + m["fake_user"].pages_limit = 100 + + for name in ("f1.xyz", "f2.xyz"): + path = f"/{name}" + m["get_file_results"][path] = ( + _make_dropbox_file(path, name, size=80 * 1024), None, + ) + m["download_and_index_mock"].return_value = (2, 0) + + await _run_dropbox_selected(m, [("/f1.xyz", "f1.xyz"), ("/f2.xyz", "f2.xyz")]) + + assert m["fake_user"].pages_used == 2 From a2b354104631a17c5717cac53837e022843a58bf Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 4 Apr 2026 03:11:56 +0530 Subject: [PATCH 2/6] chore: ran linting --- .../app/services/page_limit_service.py | 96 +++++++++++++++---- .../app/tasks/connector_indexers/base.py | 1 + .../local_folder_indexer.py | 1 + .../integration/document_upload/conftest.py | 15 +-- .../test_local_folder_pipeline.py | 2 +- .../connector_indexers/test_page_limits.py | 84 +++++++++++----- 6 files changed, 150 insertions(+), 49 deletions(-) diff --git a/surfsense_backend/app/services/page_limit_service.py b/surfsense_backend/app/services/page_limit_service.py index ea22067be..47fe07fc6 100644 --- a/surfsense_backend/app/services/page_limit_service.py +++ b/surfsense_backend/app/services/page_limit_service.py @@ -257,28 +257,83 @@ class PageLimitService: return max(1, file_size // (100 * 1024)) if file_ext in { - ".doc", ".docx", ".docm", ".dot", ".dotm", - ".odt", ".ott", ".sxw", ".stw", ".uot", - ".rtf", ".pages", ".wpd", ".wps", - ".abw", ".zabw", ".cwk", ".hwp", ".lwp", - ".mcw", ".mw", ".sdw", ".vor", + ".doc", + ".docx", + ".docm", + ".dot", + ".dotm", + ".odt", + ".ott", + ".sxw", + ".stw", + ".uot", + ".rtf", + ".pages", + ".wpd", + ".wps", + ".abw", + ".zabw", + ".cwk", + ".hwp", + ".lwp", + ".mcw", + ".mw", + ".sdw", + ".vor", }: return max(1, file_size // (50 * 1024)) if file_ext in { 
- ".ppt", ".pptx", ".pptm", ".pot", ".potx", - ".odp", ".otp", ".sxi", ".sti", ".uop", - ".key", ".sda", ".sdd", ".sdp", + ".ppt", + ".pptx", + ".pptm", + ".pot", + ".potx", + ".odp", + ".otp", + ".sxi", + ".sti", + ".uop", + ".key", + ".sda", + ".sdd", + ".sdp", }: return max(1, file_size // (200 * 1024)) if file_ext in { - ".xls", ".xlsx", ".xlsm", ".xlsb", ".xlw", ".xlr", - ".ods", ".ots", ".fods", ".numbers", - ".123", ".wk1", ".wk2", ".wk3", ".wk4", ".wks", - ".wb1", ".wb2", ".wb3", ".wq1", ".wq2", - ".csv", ".tsv", ".slk", ".sylk", ".dif", ".dbf", - ".prn", ".qpw", ".602", ".et", ".eth", + ".xls", + ".xlsx", + ".xlsm", + ".xlsb", + ".xlw", + ".xlr", + ".ods", + ".ots", + ".fods", + ".numbers", + ".123", + ".wk1", + ".wk2", + ".wk3", + ".wk4", + ".wks", + ".wb1", + ".wb2", + ".wb3", + ".wq1", + ".wq2", + ".csv", + ".tsv", + ".slk", + ".sylk", + ".dif", + ".dbf", + ".prn", + ".qpw", + ".602", + ".et", + ".eth", }: return max(1, file_size // (100 * 1024)) @@ -289,8 +344,17 @@ class PageLimitService: return max(1, file_size // 3000) if file_ext in { - ".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", - ".webp", ".svg", ".cgm", ".odg", ".pbd", + ".jpg", + ".jpeg", + ".png", + ".gif", + ".bmp", + ".tiff", + ".webp", + ".svg", + ".cgm", + ".odg", + ".pbd", }: return 1 diff --git a/surfsense_backend/app/tasks/connector_indexers/base.py b/surfsense_backend/app/tasks/connector_indexers/base.py index 6b4bed4b5..ffc8ab72e 100644 --- a/surfsense_backend/app/tasks/connector_indexers/base.py +++ b/surfsense_backend/app/tasks/connector_indexers/base.py @@ -4,6 +4,7 @@ Base functionality and shared imports for connector indexers. 
import logging from datetime import UTC, datetime, timedelta + from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select diff --git a/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py b/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py index acfbce0bf..fa50e86d3 100644 --- a/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py @@ -205,6 +205,7 @@ def _compute_final_pages( actual = page_limit_service.estimate_pages_from_content_length(content_length) return max(estimated_pages, actual) + DEFAULT_EXCLUDE_PATTERNS = [ ".git", "node_modules", diff --git a/surfsense_backend/tests/integration/document_upload/conftest.py b/surfsense_backend/tests/integration/document_upload/conftest.py index 6640fefdf..41c379e58 100644 --- a/surfsense_backend/tests/integration/document_upload/conftest.py +++ b/surfsense_backend/tests/integration/document_upload/conftest.py @@ -302,6 +302,9 @@ def _mock_redis_heartbeat(monkeypatch): ) +_MOCK_ETL_MARKDOWN = "# Mocked Document\n\nThis is mocked ETL content." + + @pytest.fixture(autouse=True) def _mock_etl_parsing(monkeypatch): """Mock ETL parsing services — LlamaParse and Docling are external boundaries. @@ -310,8 +313,6 @@ def _mock_etl_parsing(monkeypatch): the actual services would, so tests covering failure paths keep working. """ - _MOCK_MARKDOWN = "# Mocked Document\n\nThis is mocked ETL content." 
- def _reject_empty(file_path: str) -> None: if os.path.getsize(file_path) == 0: raise RuntimeError(f"Cannot parse empty file: {file_path}") @@ -324,7 +325,7 @@ def _mock_etl_parsing(monkeypatch): class _FakeLlamaParseResult: async def aget_markdown_documents(self, *, split_by_page=False): - return [_FakeMarkdownDoc(_MOCK_MARKDOWN)] + return [_FakeMarkdownDoc(_MOCK_ETL_MARKDOWN)] async def _fake_llamacloud_parse(**kwargs): _reject_empty(kwargs["file_path"]) @@ -339,7 +340,7 @@ def _mock_etl_parsing(monkeypatch): async def _fake_docling_parse(file_path: str, filename: str): _reject_empty(file_path) - return _MOCK_MARKDOWN + return _MOCK_ETL_MARKDOWN monkeypatch.setattr( "app.tasks.document_processors.file_processors.parse_with_docling", @@ -347,10 +348,12 @@ def _mock_etl_parsing(monkeypatch): ) class _FakeDoclingResult: - class document: + class Document: @staticmethod def export_to_markdown(): - return _MOCK_MARKDOWN + return _MOCK_ETL_MARKDOWN + + document = Document() class _FakeDocumentConverter: def convert(self, file_path): diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_local_folder_pipeline.py b/surfsense_backend/tests/integration/indexing_pipeline/test_local_folder_pipeline.py index 4d9bda7ee..000f43aa8 100644 --- a/surfsense_backend/tests/integration/indexing_pipeline/test_local_folder_pipeline.py +++ b/surfsense_backend/tests/integration/indexing_pipeline/test_local_folder_pipeline.py @@ -1015,7 +1015,7 @@ class TestPageLimits: (tmp_path / "note.md").write_text("# Hello World\n\nContent here.") - count, _skipped, _root_folder_id, err = await index_local_folder( + count, _skipped, _root_folder_id, _err = await index_local_folder( session=db_session, search_space_id=db_search_space.id, user_id=str(db_user.id), diff --git a/surfsense_backend/tests/unit/connector_indexers/test_page_limits.py b/surfsense_backend/tests/unit/connector_indexers/test_page_limits.py index 1c93965f3..b31a9557f 100644 --- 
a/surfsense_backend/tests/unit/connector_indexers/test_page_limits.py +++ b/surfsense_backend/tests/unit/connector_indexers/test_page_limits.py @@ -58,10 +58,14 @@ class TestEstimatePagesFromMetadata: assert PageLimitService.estimate_pages_from_metadata(ext, 5_000_000) == 1 def test_audio_uses_1mb_per_page(self): - assert PageLimitService.estimate_pages_from_metadata(".mp3", 3 * 1024 * 1024) == 3 + assert ( + PageLimitService.estimate_pages_from_metadata(".mp3", 3 * 1024 * 1024) == 3 + ) def test_video_uses_5mb_per_page(self): - assert PageLimitService.estimate_pages_from_metadata(".mp4", 15 * 1024 * 1024) == 3 + assert ( + PageLimitService.estimate_pages_from_metadata(".mp4", 15 * 1024 * 1024) == 3 + ) def test_unknown_ext_uses_80kb_per_page(self): assert PageLimitService.estimate_pages_from_metadata(".xyz", 160 * 1024) == 2 @@ -189,7 +193,8 @@ async def test_gdrive_files_within_quota_are_downloaded(gdrive_selected_mocks): for fid in ("f1", "f2", "f3"): m["get_file_results"][fid] = ( - _make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024), None, + _make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024), + None, ) m["download_and_index_mock"].return_value = (3, 0) @@ -210,7 +215,8 @@ async def test_gdrive_files_exceeding_quota_rejected(gdrive_selected_mocks): m["fake_user"].pages_limit = 100 m["get_file_results"]["big"] = ( - _make_gdrive_file("big", "huge.pdf", size=500 * 1024), None, + _make_gdrive_file("big", "huge.pdf", size=500 * 1024), + None, ) indexed, _skipped, errors = await _run_gdrive_selected(m, [("big", "huge.pdf")]) @@ -228,7 +234,8 @@ async def test_gdrive_quota_mix_partial_indexing(gdrive_selected_mocks): for fid in ("f1", "f2", "f3"): m["get_file_results"][fid] = ( - _make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024), None, + _make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024), + None, ) m["download_and_index_mock"].return_value = (2, 0) @@ -250,7 +257,8 @@ async def test_gdrive_proportional_page_deduction(gdrive_selected_mocks): for fid in ("f1", 
"f2", "f3", "f4"): m["get_file_results"][fid] = ( - _make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024), None, + _make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024), + None, ) m["download_and_index_mock"].return_value = (2, 2) @@ -269,7 +277,8 @@ async def test_gdrive_no_deduction_when_nothing_indexed(gdrive_selected_mocks): m["fake_user"].pages_limit = 100 m["get_file_results"]["f1"] = ( - _make_gdrive_file("f1", "f1.xyz", size=80 * 1024), None, + _make_gdrive_file("f1", "f1.xyz", size=80 * 1024), + None, ) m["download_and_index_mock"].return_value = (0, 1) @@ -286,7 +295,8 @@ async def test_gdrive_zero_quota_rejects_all(gdrive_selected_mocks): for fid in ("f1", "f2"): m["get_file_results"][fid] = ( - _make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024), None, + _make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024), + None, ) indexed, _skipped, errors = await _run_gdrive_selected( @@ -367,7 +377,8 @@ async def test_gdrive_full_scan_skips_over_quota(gdrive_full_scan_mocks, monkeyp _make_gdrive_file(f"f{i}", f"file{i}.xyz", size=80 * 1024) for i in range(5) ] monkeypatch.setattr( - m["mod"], "get_files_in_folder", + m["mod"], + "get_files_in_folder", AsyncMock(return_value=(page_files, None, None)), ) m["download_mock"].return_value = ([], 0) @@ -391,7 +402,8 @@ async def test_gdrive_full_scan_deducts_after_indexing( _make_gdrive_file(f"f{i}", f"file{i}.xyz", size=80 * 1024) for i in range(3) ] monkeypatch.setattr( - m["mod"], "get_files_in_folder", + m["mod"], + "get_files_in_folder", AsyncMock(return_value=(page_files, None, None)), ) mock_docs = [MagicMock() for _ in range(3)] @@ -421,7 +433,8 @@ async def test_gdrive_delta_sync_skips_over_quota(monkeypatch): for i in range(5) ] monkeypatch.setattr( - _mod, "fetch_all_changes", + _mod, + "fetch_all_changes", AsyncMock(return_value=(changes, "new-token", None)), ) monkeypatch.setattr(_mod, "categorize_change", lambda change: "modified") @@ -447,11 +460,18 @@ async def 
test_gdrive_delta_sync_skips_over_quota(monkeypatch): mock_task_logger.log_task_progress = AsyncMock() _indexed, skipped = await _mod._index_with_delta_sync( - MagicMock(), session, MagicMock(), - _CONNECTOR_ID, _SEARCH_SPACE_ID, _USER_ID, - "folder-root", "start-token", - mock_task_logger, MagicMock(), - max_files=500, enable_summary=True, + MagicMock(), + session, + MagicMock(), + _CONNECTOR_ID, + _SEARCH_SPACE_ID, + _USER_ID, + "folder-root", + "start-token", + mock_task_logger, + MagicMock(), + max_files=500, + enable_summary=True, ) call_files = download_mock.call_args[0][1] @@ -511,9 +531,13 @@ async def _run_onedrive_selected(mocks, file_ids): from app.tasks.connector_indexers.onedrive_indexer import _index_selected_files return await _index_selected_files( - MagicMock(), mocks["session"], file_ids, - connector_id=_CONNECTOR_ID, search_space_id=_SEARCH_SPACE_ID, - user_id=_USER_ID, enable_summary=True, + MagicMock(), + mocks["session"], + file_ids, + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=True, ) @@ -524,7 +548,8 @@ async def test_onedrive_over_quota_rejected(onedrive_selected_mocks): m["fake_user"].pages_limit = 100 m["get_file_results"]["big"] = ( - _make_onedrive_file("big", "huge.pdf", size=500 * 1024), None, + _make_onedrive_file("big", "huge.pdf", size=500 * 1024), + None, ) indexed, _skipped, errors = await _run_onedrive_selected(m, [("big", "huge.pdf")]) @@ -542,7 +567,8 @@ async def test_onedrive_deducts_after_success(onedrive_selected_mocks): for fid in ("f1", "f2"): m["get_file_results"][fid] = ( - _make_onedrive_file(fid, f"{fid}.xyz", size=80 * 1024), None, + _make_onedrive_file(fid, f"{fid}.xyz", size=80 * 1024), + None, ) m["download_and_index_mock"].return_value = (2, 0) @@ -605,9 +631,13 @@ async def _run_dropbox_selected(mocks, file_paths): from app.tasks.connector_indexers.dropbox_indexer import _index_selected_files return await _index_selected_files( - MagicMock(), 
mocks["session"], file_paths, - connector_id=_CONNECTOR_ID, search_space_id=_SEARCH_SPACE_ID, - user_id=_USER_ID, enable_summary=True, + MagicMock(), + mocks["session"], + file_paths, + connector_id=_CONNECTOR_ID, + search_space_id=_SEARCH_SPACE_ID, + user_id=_USER_ID, + enable_summary=True, ) @@ -618,7 +648,8 @@ async def test_dropbox_over_quota_rejected(dropbox_selected_mocks): m["fake_user"].pages_limit = 100 m["get_file_results"]["/huge.pdf"] = ( - _make_dropbox_file("/huge.pdf", "huge.pdf", size=500 * 1024), None, + _make_dropbox_file("/huge.pdf", "huge.pdf", size=500 * 1024), + None, ) indexed, _skipped, errors = await _run_dropbox_selected( @@ -639,7 +670,8 @@ async def test_dropbox_deducts_after_success(dropbox_selected_mocks): for name in ("f1.xyz", "f2.xyz"): path = f"/{name}" m["get_file_results"][path] = ( - _make_dropbox_file(path, name, size=80 * 1024), None, + _make_dropbox_file(path, name, size=80 * 1024), + None, ) m["download_and_index_mock"].return_value = (2, 0) From 09008c8f1a6ed963bcac17957ce14c8b3bc569e3 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 4 Apr 2026 03:26:22 +0530 Subject: [PATCH 3/6] refactor: remove redundant authenticatedFetch calls in editor panel components --- surfsense_web/components/editor-panel/editor-panel.tsx | 3 --- surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx | 3 --- 2 files changed, 6 deletions(-) diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 248fe68eb..c307b3cea 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -96,9 +96,6 @@ export function EditorPanelContent({ } try { - const response = await authenticatedFetch( - `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content`, - { method: "GET", signal: 
controller.signal } const url = new URL( `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content` ); diff --git a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx index d2ce3cc64..97c5b7cd9 100644 --- a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx +++ b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx @@ -81,9 +81,6 @@ export function DocumentTabContent({ documentId, searchSpaceId, title }: Documen } try { - const response = await authenticatedFetch( - `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content`, - { method: "GET", signal: controller.signal } const url = new URL( `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/editor-content` ); From 8e6b1c77eafbbb54c69ccefaf26aa017cb8e2e50 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 4 Apr 2026 03:35:34 +0530 Subject: [PATCH 4/6] feat: implement PKCE support in native Google OAuth flows - Added `generate_code_verifier` function to create a PKCE code verifier for enhanced security. - Updated Google Calendar, Drive, and Gmail connector routes to utilize the PKCE code verifier during OAuth authorization. - Modified state management to include the code verifier for secure state generation and validation. 
--- .../google_calendar_add_connector_route.py | 18 ++++++++++---- .../google_drive_add_connector_route.py | 24 +++++++++++++------ .../google_gmail_add_connector_route.py | 18 ++++++++++---- surfsense_backend/app/utils/oauth_security.py | 10 ++++++++ 4 files changed, 55 insertions(+), 15 deletions(-) diff --git a/surfsense_backend/app/routes/google_calendar_add_connector_route.py b/surfsense_backend/app/routes/google_calendar_add_connector_route.py index 9a2308bec..725f8decc 100644 --- a/surfsense_backend/app/routes/google_calendar_add_connector_route.py +++ b/surfsense_backend/app/routes/google_calendar_add_connector_route.py @@ -28,7 +28,7 @@ from app.utils.connector_naming import ( check_duplicate_connector, generate_unique_connector_name, ) -from app.utils.oauth_security import OAuthStateManager, TokenEncryption +from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_code_verifier logger = logging.getLogger(__name__) @@ -96,9 +96,14 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us flow = get_google_flow() - # Generate secure state parameter with HMAC signature + code_verifier = generate_code_verifier() + flow.code_verifier = code_verifier + + # Generate secure state parameter with HMAC signature (includes PKCE code_verifier) state_manager = get_state_manager() - state_encoded = state_manager.generate_secure_state(space_id, user.id) + state_encoded = state_manager.generate_secure_state( + space_id, user.id, code_verifier=code_verifier + ) auth_url, _ = flow.authorization_url( access_type="offline", @@ -146,8 +151,11 @@ async def reauth_calendar( flow = get_google_flow() + code_verifier = generate_code_verifier() + flow.code_verifier = code_verifier + state_manager = get_state_manager() - extra: dict = {"connector_id": connector_id} + extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier} if return_url and return_url.startswith("/"): extra["return_url"] = return_url state_encoded 
= state_manager.generate_secure_state(space_id, user.id, **extra) @@ -225,6 +233,7 @@ async def calendar_callback( user_id = UUID(data["user_id"]) space_id = data["space_id"] + code_verifier = data.get("code_verifier") # Validate redirect URI (security: ensure it matches configured value) if not config.GOOGLE_CALENDAR_REDIRECT_URI: @@ -233,6 +242,7 @@ async def calendar_callback( ) flow = get_google_flow() + flow.code_verifier = code_verifier flow.fetch_token(code=code) creds = flow.credentials diff --git a/surfsense_backend/app/routes/google_drive_add_connector_route.py b/surfsense_backend/app/routes/google_drive_add_connector_route.py index 1c9391610..921f84af9 100644 --- a/surfsense_backend/app/routes/google_drive_add_connector_route.py +++ b/surfsense_backend/app/routes/google_drive_add_connector_route.py @@ -41,7 +41,7 @@ from app.utils.connector_naming import ( check_duplicate_connector, generate_unique_connector_name, ) -from app.utils.oauth_security import OAuthStateManager, TokenEncryption +from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_code_verifier # Relax token scope validation for Google OAuth os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" @@ -127,14 +127,19 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user) flow = get_google_flow() - # Generate secure state parameter with HMAC signature + code_verifier = generate_code_verifier() + flow.code_verifier = code_verifier + + # Generate secure state parameter with HMAC signature (includes PKCE code_verifier) state_manager = get_state_manager() - state_encoded = state_manager.generate_secure_state(space_id, user.id) + state_encoded = state_manager.generate_secure_state( + space_id, user.id, code_verifier=code_verifier + ) # Generate authorization URL auth_url, _ = flow.authorization_url( - access_type="offline", # Get refresh token - prompt="consent", # Force consent screen to get refresh token + access_type="offline", + prompt="consent", 
include_granted_scopes="true", state=state_encoded, ) @@ -193,8 +198,11 @@ async def reauth_drive( flow = get_google_flow() + code_verifier = generate_code_verifier() + flow.code_verifier = code_verifier + state_manager = get_state_manager() - extra: dict = {"connector_id": connector_id} + extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier} if return_url and return_url.startswith("/"): extra["return_url"] = return_url state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra) @@ -285,6 +293,7 @@ async def drive_callback( space_id = data["space_id"] reauth_connector_id = data.get("connector_id") reauth_return_url = data.get("return_url") + code_verifier = data.get("code_verifier") logger.info( f"Processing Google Drive callback for user {user_id}, space {space_id}" @@ -296,8 +305,9 @@ async def drive_callback( status_code=500, detail="GOOGLE_DRIVE_REDIRECT_URI not configured" ) - # Exchange authorization code for tokens + # Exchange authorization code for tokens (restore PKCE code_verifier from state) flow = get_google_flow() + flow.code_verifier = code_verifier flow.fetch_token(code=code) creds = flow.credentials diff --git a/surfsense_backend/app/routes/google_gmail_add_connector_route.py b/surfsense_backend/app/routes/google_gmail_add_connector_route.py index 750a64819..9fe0c0de6 100644 --- a/surfsense_backend/app/routes/google_gmail_add_connector_route.py +++ b/surfsense_backend/app/routes/google_gmail_add_connector_route.py @@ -28,7 +28,7 @@ from app.utils.connector_naming import ( check_duplicate_connector, generate_unique_connector_name, ) -from app.utils.oauth_security import OAuthStateManager, TokenEncryption +from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_code_verifier logger = logging.getLogger(__name__) @@ -109,9 +109,14 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user) flow = get_google_flow() - # Generate secure state parameter with HMAC 
signature + code_verifier = generate_code_verifier() + flow.code_verifier = code_verifier + + # Generate secure state parameter with HMAC signature (includes PKCE code_verifier) state_manager = get_state_manager() - state_encoded = state_manager.generate_secure_state(space_id, user.id) + state_encoded = state_manager.generate_secure_state( + space_id, user.id, code_verifier=code_verifier + ) auth_url, _ = flow.authorization_url( access_type="offline", @@ -164,8 +169,11 @@ async def reauth_gmail( flow = get_google_flow() + code_verifier = generate_code_verifier() + flow.code_verifier = code_verifier + state_manager = get_state_manager() - extra: dict = {"connector_id": connector_id} + extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier} if return_url and return_url.startswith("/"): extra["return_url"] = return_url state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra) @@ -256,6 +264,7 @@ async def gmail_callback( user_id = UUID(data["user_id"]) space_id = data["space_id"] + code_verifier = data.get("code_verifier") # Validate redirect URI (security: ensure it matches configured value) if not config.GOOGLE_GMAIL_REDIRECT_URI: @@ -264,6 +273,7 @@ async def gmail_callback( ) flow = get_google_flow() + flow.code_verifier = code_verifier flow.fetch_token(code=code) creds = flow.credentials diff --git a/surfsense_backend/app/utils/oauth_security.py b/surfsense_backend/app/utils/oauth_security.py index 5135cdef4..0ad9d3bd9 100644 --- a/surfsense_backend/app/utils/oauth_security.py +++ b/surfsense_backend/app/utils/oauth_security.py @@ -11,6 +11,8 @@ import hmac import json import logging import time +from random import SystemRandom +from string import ascii_letters, digits from uuid import UUID from cryptography.fernet import Fernet @@ -18,6 +20,14 @@ from fastapi import HTTPException logger = logging.getLogger(__name__) +_PKCE_CHARS = ascii_letters + digits + "-._~" +_PKCE_RNG = SystemRandom() + + +def 
generate_code_verifier(length: int = 128) -> str: + """Generate a PKCE code_verifier (RFC 7636, 43-128 unreserved chars).""" + return "".join(_PKCE_RNG.choice(_PKCE_CHARS) for _ in range(length)) + class OAuthStateManager: """Manages secure OAuth state parameters with HMAC signatures.""" From e814540727bbc457bef59e5a6b63f3a2aec2f957 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 4 Apr 2026 03:36:54 +0530 Subject: [PATCH 5/6] refactor: move PKCE pair generation for airtable - Removed the `generate_pkce_pair` function from `airtable_add_connector_route.py` and relocated it to `oauth_security.py` for better organization. - Updated imports in `airtable_add_connector_route.py` to reflect the new location of the PKCE generation function. --- .../routes/airtable_add_connector_route.py | 26 +------------------ surfsense_backend/app/utils/oauth_security.py | 11 ++++++++ 2 files changed, 12 insertions(+), 25 deletions(-) diff --git a/surfsense_backend/app/routes/airtable_add_connector_route.py b/surfsense_backend/app/routes/airtable_add_connector_route.py index fe359d2f3..d2d25d006 100644 --- a/surfsense_backend/app/routes/airtable_add_connector_route.py +++ b/surfsense_backend/app/routes/airtable_add_connector_route.py @@ -1,7 +1,5 @@ import base64 -import hashlib import logging -import secrets from datetime import UTC, datetime, timedelta from uuid import UUID @@ -26,7 +24,7 @@ from app.utils.connector_naming import ( check_duplicate_connector, generate_unique_connector_name, ) -from app.utils.oauth_security import OAuthStateManager, TokenEncryption +from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_pkce_pair logger = logging.getLogger(__name__) @@ -75,28 +73,6 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str: return f"Basic {b64}" -def generate_pkce_pair() -> tuple[str, str]: - """ - Generate PKCE code verifier and code challenge.
- - Returns: - Tuple of (code_verifier, code_challenge) - """ - # Generate code verifier (43-128 characters) - code_verifier = ( - base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=") - ) - - # Generate code challenge (SHA256 hash of verifier, base64url encoded) - code_challenge = ( - base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("utf-8")).digest()) - .decode("utf-8") - .rstrip("=") - ) - - return code_verifier, code_challenge - - @router.get("/auth/airtable/connector/add") async def connect_airtable(space_id: int, user: User = Depends(current_active_user)): """ diff --git a/surfsense_backend/app/utils/oauth_security.py b/surfsense_backend/app/utils/oauth_security.py index 0ad9d3bd9..c39b1e9b1 100644 --- a/surfsense_backend/app/utils/oauth_security.py +++ b/surfsense_backend/app/utils/oauth_security.py @@ -29,6 +29,17 @@ def generate_code_verifier(length: int = 128) -> str: return "".join(_PKCE_RNG.choice(_PKCE_CHARS) for _ in range(length)) +def generate_pkce_pair(length: int = 128) -> tuple[str, str]: + """Generate a PKCE code_verifier and its S256 code_challenge.""" + verifier = generate_code_verifier(length) + challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()) + .decode() + .rstrip("=") + ) + return verifier, challenge + + class OAuthStateManager: """Manages secure OAuth state parameters with HMAC signatures.""" From 82d4d3e27234fb876701a35f22b15c99e5608b15 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 4 Apr 2026 03:37:33 +0530 Subject: [PATCH 6/6] chore: ran linting --- .../app/routes/airtable_add_connector_route.py | 6 +++++- .../app/routes/google_calendar_add_connector_route.py | 6 +++++- .../app/routes/google_drive_add_connector_route.py | 6 +++++- .../app/routes/google_gmail_add_connector_route.py | 6 +++++- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git 
a/surfsense_backend/app/routes/airtable_add_connector_route.py b/surfsense_backend/app/routes/airtable_add_connector_route.py index d2d25d006..1e0b1eb5d 100644 --- a/surfsense_backend/app/routes/airtable_add_connector_route.py +++ b/surfsense_backend/app/routes/airtable_add_connector_route.py @@ -24,7 +24,11 @@ from app.utils.connector_naming import ( check_duplicate_connector, generate_unique_connector_name, ) -from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_pkce_pair +from app.utils.oauth_security import ( + OAuthStateManager, + TokenEncryption, + generate_pkce_pair, +) logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/routes/google_calendar_add_connector_route.py b/surfsense_backend/app/routes/google_calendar_add_connector_route.py index 725f8decc..d7ccf62ca 100644 --- a/surfsense_backend/app/routes/google_calendar_add_connector_route.py +++ b/surfsense_backend/app/routes/google_calendar_add_connector_route.py @@ -28,7 +28,11 @@ from app.utils.connector_naming import ( check_duplicate_connector, generate_unique_connector_name, ) -from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_code_verifier +from app.utils.oauth_security import ( + OAuthStateManager, + TokenEncryption, + generate_code_verifier, +) logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/routes/google_drive_add_connector_route.py b/surfsense_backend/app/routes/google_drive_add_connector_route.py index 921f84af9..8706326b7 100644 --- a/surfsense_backend/app/routes/google_drive_add_connector_route.py +++ b/surfsense_backend/app/routes/google_drive_add_connector_route.py @@ -41,7 +41,11 @@ from app.utils.connector_naming import ( check_duplicate_connector, generate_unique_connector_name, ) -from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_code_verifier +from app.utils.oauth_security import ( + OAuthStateManager, + TokenEncryption, + generate_code_verifier, +) # 
Relax token scope validation for Google OAuth os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" diff --git a/surfsense_backend/app/routes/google_gmail_add_connector_route.py b/surfsense_backend/app/routes/google_gmail_add_connector_route.py index 9fe0c0de6..dd8feb1c7 100644 --- a/surfsense_backend/app/routes/google_gmail_add_connector_route.py +++ b/surfsense_backend/app/routes/google_gmail_add_connector_route.py @@ -28,7 +28,11 @@ from app.utils.connector_naming import ( check_duplicate_connector, generate_unique_connector_name, ) -from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_code_verifier +from app.utils.oauth_security import ( + OAuthStateManager, + TokenEncryption, + generate_code_verifier, +) logger = logging.getLogger(__name__)