mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-20 21:18:13 +02:00
Merge remote-tracking branch 'upstream/dev' into feat/unified-etl-pipeline
This commit is contained in:
commit
63a75052ca
76 changed files with 3041 additions and 376 deletions
|
|
@ -0,0 +1,39 @@
|
|||
"""119_add_vision_llm_id_to_search_spaces
|
||||
|
||||
Revision ID: 119
|
||||
Revises: 118
|
||||
|
||||
Adds vision_llm_id column to search_spaces for vision/screenshot analysis
|
||||
LLM role assignment. Defaults to 0 (Auto mode), same convention as
|
||||
agent_llm_id and document_summary_llm_id.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "119"
|
||||
down_revision: str | None = "118"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
existing_columns = [
|
||||
col["name"] for col in sa.inspect(conn).get_columns("searchspaces")
|
||||
]
|
||||
|
||||
if "vision_llm_id" not in existing_columns:
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column("vision_llm_id", sa.Integer(), nullable=True, server_default="0"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("searchspaces", "vision_llm_id")
|
||||
|
|
@ -1351,6 +1351,9 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
image_generation_config_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For image generation, defaults to Auto mode
|
||||
vision_llm_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
) # For vision/screenshot analysis, defaults to Auto mode
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from fastapi import APIRouter
|
|||
from .airtable_add_connector_route import (
|
||||
router as airtable_add_connector_router,
|
||||
)
|
||||
from .autocomplete_routes import router as autocomplete_router
|
||||
from .chat_comments_routes import router as chat_comments_router
|
||||
from .circleback_webhook_route import router as circleback_webhook_router
|
||||
from .clickup_add_connector_route import router as clickup_add_connector_router
|
||||
|
|
@ -95,3 +96,4 @@ router.include_router(incentive_tasks_router) # Incentive tasks for earning fre
|
|||
router.include_router(stripe_router) # Stripe checkout for additional page packs
|
||||
router.include_router(youtube_router) # YouTube playlist resolution
|
||||
router.include_router(prompts_router)
|
||||
router.include_router(autocomplete_router) # Lightweight autocomplete with KB context
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
|
|
@ -26,7 +24,11 @@ from app.utils.connector_naming import (
|
|||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.oauth_security import (
|
||||
OAuthStateManager,
|
||||
TokenEncryption,
|
||||
generate_pkce_pair,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -75,28 +77,6 @@ def make_basic_auth_header(client_id: str, client_secret: str) -> str:
|
|||
return f"Basic {b64}"
|
||||
|
||||
|
||||
def generate_pkce_pair() -> tuple[str, str]:
|
||||
"""
|
||||
Generate PKCE code verifier and code challenge.
|
||||
|
||||
Returns:
|
||||
Tuple of (code_verifier, code_challenge)
|
||||
"""
|
||||
# Generate code verifier (43-128 characters)
|
||||
code_verifier = (
|
||||
base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=")
|
||||
)
|
||||
|
||||
# Generate code challenge (SHA256 hash of verifier, base64url encoded)
|
||||
code_challenge = (
|
||||
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("utf-8")).digest())
|
||||
.decode("utf-8")
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
return code_verifier, code_challenge
|
||||
|
||||
|
||||
@router.get("/auth/airtable/connector/add")
|
||||
async def connect_airtable(space_id: int, user: User = Depends(current_active_user)):
|
||||
"""
|
||||
|
|
|
|||
42
surfsense_backend/app/routes/autocomplete_routes.py
Normal file
42
surfsense_backend/app/routes/autocomplete_routes.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import User, get_async_session
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.services.vision_autocomplete_service import stream_vision_autocomplete
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
router = APIRouter(prefix="/autocomplete", tags=["autocomplete"])
|
||||
|
||||
MAX_SCREENSHOT_SIZE = 20 * 1024 * 1024 # 20 MB base64 ceiling
|
||||
|
||||
|
||||
class VisionAutocompleteRequest(BaseModel):
|
||||
screenshot: str = Field(..., max_length=MAX_SCREENSHOT_SIZE)
|
||||
search_space_id: int
|
||||
app_name: str = ""
|
||||
window_title: str = ""
|
||||
|
||||
|
||||
@router.post("/vision/stream")
|
||||
async def vision_autocomplete_stream(
|
||||
body: VisionAutocompleteRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
await check_search_space_access(session, user, body.search_space_id)
|
||||
|
||||
return StreamingResponse(
|
||||
stream_vision_autocomplete(
|
||||
body.screenshot, body.search_space_id, session,
|
||||
app_name=body.app_name, window_title=body.window_title,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
**VercelStreamingService.get_response_headers(),
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
|
@ -28,7 +28,11 @@ from app.utils.connector_naming import (
|
|||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.oauth_security import (
|
||||
OAuthStateManager,
|
||||
TokenEncryption,
|
||||
generate_code_verifier,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -96,9 +100,14 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
# Generate secure state parameter with HMAC signature
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
# Generate secure state parameter with HMAC signature (includes PKCE code_verifier)
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
state_encoded = state_manager.generate_secure_state(
|
||||
space_id, user.id, code_verifier=code_verifier
|
||||
)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
|
|
@ -146,8 +155,11 @@ async def reauth_calendar(
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
|
@ -225,6 +237,7 @@ async def calendar_callback(
|
|||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
code_verifier = data.get("code_verifier")
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.GOOGLE_CALENDAR_REDIRECT_URI:
|
||||
|
|
@ -233,6 +246,7 @@ async def calendar_callback(
|
|||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
flow.code_verifier = code_verifier
|
||||
flow.fetch_token(code=code)
|
||||
|
||||
creds = flow.credentials
|
||||
|
|
|
|||
|
|
@ -41,7 +41,11 @@ from app.utils.connector_naming import (
|
|||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.oauth_security import (
|
||||
OAuthStateManager,
|
||||
TokenEncryption,
|
||||
generate_code_verifier,
|
||||
)
|
||||
|
||||
# Relax token scope validation for Google OAuth
|
||||
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
|
||||
|
|
@ -127,14 +131,19 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user)
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
# Generate secure state parameter with HMAC signature
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
# Generate secure state parameter with HMAC signature (includes PKCE code_verifier)
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
state_encoded = state_manager.generate_secure_state(
|
||||
space_id, user.id, code_verifier=code_verifier
|
||||
)
|
||||
|
||||
# Generate authorization URL
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline", # Get refresh token
|
||||
prompt="consent", # Force consent screen to get refresh token
|
||||
access_type="offline",
|
||||
prompt="consent",
|
||||
include_granted_scopes="true",
|
||||
state=state_encoded,
|
||||
)
|
||||
|
|
@ -193,8 +202,11 @@ async def reauth_drive(
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
|
@ -285,6 +297,7 @@ async def drive_callback(
|
|||
space_id = data["space_id"]
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
code_verifier = data.get("code_verifier")
|
||||
|
||||
logger.info(
|
||||
f"Processing Google Drive callback for user {user_id}, space {space_id}"
|
||||
|
|
@ -296,8 +309,9 @@ async def drive_callback(
|
|||
status_code=500, detail="GOOGLE_DRIVE_REDIRECT_URI not configured"
|
||||
)
|
||||
|
||||
# Exchange authorization code for tokens
|
||||
# Exchange authorization code for tokens (restore PKCE code_verifier from state)
|
||||
flow = get_google_flow()
|
||||
flow.code_verifier = code_verifier
|
||||
flow.fetch_token(code=code)
|
||||
|
||||
creds = flow.credentials
|
||||
|
|
|
|||
|
|
@ -28,7 +28,11 @@ from app.utils.connector_naming import (
|
|||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
from app.utils.oauth_security import (
|
||||
OAuthStateManager,
|
||||
TokenEncryption,
|
||||
generate_code_verifier,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -109,9 +113,14 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
# Generate secure state parameter with HMAC signature
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
# Generate secure state parameter with HMAC signature (includes PKCE code_verifier)
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
state_encoded = state_manager.generate_secure_state(
|
||||
space_id, user.id, code_verifier=code_verifier
|
||||
)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
|
|
@ -164,8 +173,11 @@ async def reauth_gmail(
|
|||
|
||||
flow = get_google_flow()
|
||||
|
||||
code_verifier = generate_code_verifier()
|
||||
flow.code_verifier = code_verifier
|
||||
|
||||
state_manager = get_state_manager()
|
||||
extra: dict = {"connector_id": connector_id}
|
||||
extra: dict = {"connector_id": connector_id, "code_verifier": code_verifier}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)
|
||||
|
|
@ -256,6 +268,7 @@ async def gmail_callback(
|
|||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
code_verifier = data.get("code_verifier")
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.GOOGLE_GMAIL_REDIRECT_URI:
|
||||
|
|
@ -264,6 +277,7 @@ async def gmail_callback(
|
|||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
flow.code_verifier = code_verifier
|
||||
flow.fetch_token(code=code)
|
||||
|
||||
creds = flow.credentials
|
||||
|
|
|
|||
|
|
@ -522,14 +522,17 @@ async def get_llm_preferences(
|
|||
image_generation_config = await _get_image_gen_config_by_id(
|
||||
session, search_space.image_generation_config_id
|
||||
)
|
||||
vision_llm = await _get_llm_config_by_id(session, search_space.vision_llm_id)
|
||||
|
||||
return LLMPreferencesRead(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
document_summary_llm_id=search_space.document_summary_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
vision_llm_id=search_space.vision_llm_id,
|
||||
agent_llm=agent_llm,
|
||||
document_summary_llm=document_summary_llm,
|
||||
image_generation_config=image_generation_config,
|
||||
vision_llm=vision_llm,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -589,14 +592,17 @@ async def update_llm_preferences(
|
|||
image_generation_config = await _get_image_gen_config_by_id(
|
||||
session, search_space.image_generation_config_id
|
||||
)
|
||||
vision_llm = await _get_llm_config_by_id(session, search_space.vision_llm_id)
|
||||
|
||||
return LLMPreferencesRead(
|
||||
agent_llm_id=search_space.agent_llm_id,
|
||||
document_summary_llm_id=search_space.document_summary_llm_id,
|
||||
image_generation_config_id=search_space.image_generation_config_id,
|
||||
vision_llm_id=search_space.vision_llm_id,
|
||||
agent_llm=agent_llm,
|
||||
document_summary_llm=document_summary_llm,
|
||||
image_generation_config=image_generation_config,
|
||||
vision_llm=vision_llm,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
|
|||
|
|
@ -182,6 +182,9 @@ class LLMPreferencesRead(BaseModel):
|
|||
image_generation_config_id: int | None = Field(
|
||||
None, description="ID of the image generation config to use"
|
||||
)
|
||||
vision_llm_id: int | None = Field(
|
||||
None, description="ID of the LLM config to use for vision/screenshot analysis"
|
||||
)
|
||||
agent_llm: dict[str, Any] | None = Field(
|
||||
None, description="Full config for agent LLM"
|
||||
)
|
||||
|
|
@ -191,6 +194,9 @@ class LLMPreferencesRead(BaseModel):
|
|||
image_generation_config: dict[str, Any] | None = Field(
|
||||
None, description="Full config for image generation"
|
||||
)
|
||||
vision_llm: dict[str, Any] | None = Field(
|
||||
None, description="Full config for vision LLM"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
@ -207,3 +213,6 @@ class LLMPreferencesUpdate(BaseModel):
|
|||
image_generation_config_id: int | None = Field(
|
||||
None, description="ID of the image generation config to use"
|
||||
)
|
||||
vision_llm_id: int | None = Field(
|
||||
None, description="ID of the LLM config to use for vision/screenshot analysis"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ logger = logging.getLogger(__name__)
|
|||
class LLMRole:
|
||||
AGENT = "agent" # For agent/chat operations
|
||||
DOCUMENT_SUMMARY = "document_summary" # For document summarization
|
||||
VISION = "vision" # For vision/screenshot analysis
|
||||
|
||||
|
||||
def get_global_llm_config(llm_config_id: int) -> dict | None:
|
||||
|
|
@ -187,7 +188,7 @@ async def get_search_space_llm_instance(
|
|||
Args:
|
||||
session: Database session
|
||||
search_space_id: Search Space ID
|
||||
role: LLM role ('agent' or 'document_summary')
|
||||
role: LLM role ('agent', 'document_summary', or 'vision')
|
||||
|
||||
Returns:
|
||||
ChatLiteLLM or ChatLiteLLMRouter instance, or None if not found
|
||||
|
|
@ -209,6 +210,8 @@ async def get_search_space_llm_instance(
|
|||
llm_config_id = search_space.agent_llm_id
|
||||
elif role == LLMRole.DOCUMENT_SUMMARY:
|
||||
llm_config_id = search_space.document_summary_llm_id
|
||||
elif role == LLMRole.VISION:
|
||||
llm_config_id = search_space.vision_llm_id
|
||||
else:
|
||||
logger.error(f"Invalid LLM role: {role}")
|
||||
return None
|
||||
|
|
@ -405,6 +408,13 @@ async def get_document_summary_llm(
|
|||
)
|
||||
|
||||
|
||||
async def get_vision_llm(
|
||||
session: AsyncSession, search_space_id: int
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""Get the search space's vision LLM instance for screenshot analysis."""
|
||||
return await get_search_space_llm_instance(session, search_space_id, LLMRole.VISION)
|
||||
|
||||
|
||||
# Backward-compatible alias (LLM preferences are now per-search-space, not per-user)
|
||||
async def get_user_long_context_llm(
|
||||
session: AsyncSession,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ Service for managing user page limits for ETL services.
|
|||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pathlib import Path, PurePosixPath
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -223,10 +223,155 @@ class PageLimitService:
|
|||
# Estimate ~2000 characters per page
|
||||
return max(1, content_length // 2000)
|
||||
|
||||
@staticmethod
|
||||
def estimate_pages_from_metadata(
|
||||
file_name_or_ext: str, file_size: int | str | None = None
|
||||
) -> int:
|
||||
"""Size-based page estimation from file name/extension and byte size.
|
||||
|
||||
Pure function — no file I/O, no database access. Used by cloud
|
||||
connectors (which only have API metadata) and as the internal
|
||||
fallback for :meth:`estimate_pages_before_processing`.
|
||||
|
||||
``file_name_or_ext`` can be a full filename (``"report.pdf"``) or
|
||||
a bare extension (``".pdf"``). ``file_size`` may be an int, a
|
||||
stringified int from a cloud API, or *None*.
|
||||
"""
|
||||
if file_size is not None:
|
||||
try:
|
||||
file_size = int(file_size)
|
||||
except (ValueError, TypeError):
|
||||
file_size = 0
|
||||
else:
|
||||
file_size = 0
|
||||
|
||||
if file_size <= 0:
|
||||
return 1
|
||||
|
||||
ext = PurePosixPath(file_name_or_ext).suffix.lower() if file_name_or_ext else ""
|
||||
if not ext and file_name_or_ext.startswith("."):
|
||||
ext = file_name_or_ext.lower()
|
||||
file_ext = ext
|
||||
|
||||
if file_ext == ".pdf":
|
||||
return max(1, file_size // (100 * 1024))
|
||||
|
||||
if file_ext in {
|
||||
".doc",
|
||||
".docx",
|
||||
".docm",
|
||||
".dot",
|
||||
".dotm",
|
||||
".odt",
|
||||
".ott",
|
||||
".sxw",
|
||||
".stw",
|
||||
".uot",
|
||||
".rtf",
|
||||
".pages",
|
||||
".wpd",
|
||||
".wps",
|
||||
".abw",
|
||||
".zabw",
|
||||
".cwk",
|
||||
".hwp",
|
||||
".lwp",
|
||||
".mcw",
|
||||
".mw",
|
||||
".sdw",
|
||||
".vor",
|
||||
}:
|
||||
return max(1, file_size // (50 * 1024))
|
||||
|
||||
if file_ext in {
|
||||
".ppt",
|
||||
".pptx",
|
||||
".pptm",
|
||||
".pot",
|
||||
".potx",
|
||||
".odp",
|
||||
".otp",
|
||||
".sxi",
|
||||
".sti",
|
||||
".uop",
|
||||
".key",
|
||||
".sda",
|
||||
".sdd",
|
||||
".sdp",
|
||||
}:
|
||||
return max(1, file_size // (200 * 1024))
|
||||
|
||||
if file_ext in {
|
||||
".xls",
|
||||
".xlsx",
|
||||
".xlsm",
|
||||
".xlsb",
|
||||
".xlw",
|
||||
".xlr",
|
||||
".ods",
|
||||
".ots",
|
||||
".fods",
|
||||
".numbers",
|
||||
".123",
|
||||
".wk1",
|
||||
".wk2",
|
||||
".wk3",
|
||||
".wk4",
|
||||
".wks",
|
||||
".wb1",
|
||||
".wb2",
|
||||
".wb3",
|
||||
".wq1",
|
||||
".wq2",
|
||||
".csv",
|
||||
".tsv",
|
||||
".slk",
|
||||
".sylk",
|
||||
".dif",
|
||||
".dbf",
|
||||
".prn",
|
||||
".qpw",
|
||||
".602",
|
||||
".et",
|
||||
".eth",
|
||||
}:
|
||||
return max(1, file_size // (100 * 1024))
|
||||
|
||||
if file_ext in {".epub"}:
|
||||
return max(1, file_size // (50 * 1024))
|
||||
|
||||
if file_ext in {".txt", ".log", ".md", ".markdown", ".htm", ".html", ".xml"}:
|
||||
return max(1, file_size // 3000)
|
||||
|
||||
if file_ext in {
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".gif",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".webp",
|
||||
".svg",
|
||||
".cgm",
|
||||
".odg",
|
||||
".pbd",
|
||||
}:
|
||||
return 1
|
||||
|
||||
if file_ext in {".mp3", ".m4a", ".wav", ".mpga"}:
|
||||
return max(1, file_size // (1024 * 1024))
|
||||
|
||||
if file_ext in {".mp4", ".mpeg", ".webm"}:
|
||||
return max(1, file_size // (5 * 1024 * 1024))
|
||||
|
||||
return max(1, file_size // (80 * 1024))
|
||||
|
||||
def estimate_pages_before_processing(self, file_path: str) -> int:
|
||||
"""
|
||||
Estimate page count from file before processing (to avoid unnecessary API calls).
|
||||
This is called BEFORE sending to ETL services to prevent cost on rejected files.
|
||||
Estimate page count from a local file before processing.
|
||||
|
||||
For PDFs, attempts to read the actual page count via pypdf.
|
||||
For everything else, delegates to :meth:`estimate_pages_from_metadata`.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
|
@ -240,7 +385,6 @@ class PageLimitService:
|
|||
file_ext = Path(file_path).suffix.lower()
|
||||
file_size = os.path.getsize(file_path)
|
||||
|
||||
# PDF files - try to get actual page count
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
import pypdf
|
||||
|
|
@ -249,153 +393,6 @@ class PageLimitService:
|
|||
pdf_reader = pypdf.PdfReader(f)
|
||||
return len(pdf_reader.pages)
|
||||
except Exception:
|
||||
# If PDF reading fails, fall back to size estimation
|
||||
# Typical PDF: ~100KB per page (conservative estimate)
|
||||
return max(1, file_size // (100 * 1024))
|
||||
pass # fall through to size-based estimation
|
||||
|
||||
# Word Processing Documents
|
||||
# Microsoft Word, LibreOffice Writer, WordPerfect, Pages, etc.
|
||||
elif file_ext in [
|
||||
".doc",
|
||||
".docx",
|
||||
".docm",
|
||||
".dot",
|
||||
".dotm", # Microsoft Word
|
||||
".odt",
|
||||
".ott",
|
||||
".sxw",
|
||||
".stw",
|
||||
".uot", # OpenDocument/StarOffice Writer
|
||||
".rtf", # Rich Text Format
|
||||
".pages", # Apple Pages
|
||||
".wpd",
|
||||
".wps", # WordPerfect, Microsoft Works
|
||||
".abw",
|
||||
".zabw", # AbiWord
|
||||
".cwk",
|
||||
".hwp",
|
||||
".lwp",
|
||||
".mcw",
|
||||
".mw",
|
||||
".sdw",
|
||||
".vor", # Other word processors
|
||||
]:
|
||||
# Typical word document: ~50KB per page (conservative)
|
||||
return max(1, file_size // (50 * 1024))
|
||||
|
||||
# Presentation Documents
|
||||
# PowerPoint, Impress, Keynote, etc.
|
||||
elif file_ext in [
|
||||
".ppt",
|
||||
".pptx",
|
||||
".pptm",
|
||||
".pot",
|
||||
".potx", # Microsoft PowerPoint
|
||||
".odp",
|
||||
".otp",
|
||||
".sxi",
|
||||
".sti",
|
||||
".uop", # OpenDocument/StarOffice Impress
|
||||
".key", # Apple Keynote
|
||||
".sda",
|
||||
".sdd",
|
||||
".sdp", # StarOffice Draw/Impress
|
||||
]:
|
||||
# Typical presentation: ~200KB per slide (conservative)
|
||||
return max(1, file_size // (200 * 1024))
|
||||
|
||||
# Spreadsheet Documents
|
||||
# Excel, Calc, Numbers, Lotus, etc.
|
||||
elif file_ext in [
|
||||
".xls",
|
||||
".xlsx",
|
||||
".xlsm",
|
||||
".xlsb",
|
||||
".xlw",
|
||||
".xlr", # Microsoft Excel
|
||||
".ods",
|
||||
".ots",
|
||||
".fods", # OpenDocument Spreadsheet
|
||||
".numbers", # Apple Numbers
|
||||
".123",
|
||||
".wk1",
|
||||
".wk2",
|
||||
".wk3",
|
||||
".wk4",
|
||||
".wks", # Lotus 1-2-3
|
||||
".wb1",
|
||||
".wb2",
|
||||
".wb3",
|
||||
".wq1",
|
||||
".wq2", # Quattro Pro
|
||||
".csv",
|
||||
".tsv",
|
||||
".slk",
|
||||
".sylk",
|
||||
".dif",
|
||||
".dbf",
|
||||
".prn",
|
||||
".qpw", # Data formats
|
||||
".602",
|
||||
".et",
|
||||
".eth", # Other spreadsheets
|
||||
]:
|
||||
# Spreadsheets typically have 1 sheet = 1 page for ETL
|
||||
# Conservative: ~100KB per sheet
|
||||
return max(1, file_size // (100 * 1024))
|
||||
|
||||
# E-books
|
||||
elif file_ext in [".epub"]:
|
||||
# E-books vary widely, estimate by size
|
||||
# Typical e-book: ~50KB per page
|
||||
return max(1, file_size // (50 * 1024))
|
||||
|
||||
# Plain Text and Markup Files
|
||||
elif file_ext in [
|
||||
".txt",
|
||||
".log", # Plain text
|
||||
".md",
|
||||
".markdown", # Markdown
|
||||
".htm",
|
||||
".html",
|
||||
".xml", # Markup
|
||||
]:
|
||||
# Plain text: ~3000 bytes per page
|
||||
return max(1, file_size // 3000)
|
||||
|
||||
# Image Files
|
||||
# Each image is typically processed as 1 page
|
||||
elif file_ext in [
|
||||
".jpg",
|
||||
".jpeg", # JPEG
|
||||
".png", # PNG
|
||||
".gif", # GIF
|
||||
".bmp", # Bitmap
|
||||
".tiff", # TIFF
|
||||
".webp", # WebP
|
||||
".svg", # SVG
|
||||
".cgm", # Computer Graphics Metafile
|
||||
".odg",
|
||||
".pbd", # OpenDocument Graphics
|
||||
]:
|
||||
# Each image = 1 page
|
||||
return 1
|
||||
|
||||
# Audio Files (transcription = typically 1 page per minute)
|
||||
# Note: These should be handled by audio transcription flow, not ETL
|
||||
elif file_ext in [".mp3", ".m4a", ".wav", ".mpga"]:
|
||||
# Audio files: estimate based on duration
|
||||
# Fallback: ~1MB per minute of audio, 1 page per minute transcript
|
||||
return max(1, file_size // (1024 * 1024))
|
||||
|
||||
# Video Files (typically not processed for pages, but just in case)
|
||||
elif file_ext in [".mp4", ".mpeg", ".webm"]:
|
||||
# Video files: very rough estimate
|
||||
# Typically wouldn't be page-based, but use conservative estimate
|
||||
return max(1, file_size // (5 * 1024 * 1024))
|
||||
|
||||
# Other/Unknown Document Types
|
||||
else:
|
||||
# Conservative estimate: ~80KB per page
|
||||
# This catches: .sgl, .sxg, .uof, .uos1, .uos2, .web, and any future formats
|
||||
return max(1, file_size // (80 * 1024))
|
||||
return self.estimate_pages_from_metadata(file_ext, file_size)
|
||||
|
|
|
|||
225
surfsense_backend/app/services/vision_autocomplete_service.py
Normal file
225
surfsense_backend/app/services/vision_autocomplete_service.py
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
from app.services.llm_service import get_vision_llm
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KB_TOP_K = 5
|
||||
KB_MAX_CHARS = 4000
|
||||
|
||||
EXTRACT_QUERY_PROMPT = """Look at this screenshot and describe in 1-2 short sentences what the user is working on and what topic they need to write about. Be specific about the subject matter. Output ONLY the description, nothing else."""
|
||||
|
||||
EXTRACT_QUERY_PROMPT_WITH_APP = """The user is currently in the application "{app_name}" with the window titled "{window_title}".
|
||||
|
||||
Look at this screenshot and describe in 1-2 short sentences what the user is working on and what topic they need to write about. Be specific about the subject matter. Output ONLY the description, nothing else."""
|
||||
|
||||
VISION_SYSTEM_PROMPT = """You are a smart writing assistant that analyzes the user's screen to draft or complete text.
|
||||
|
||||
You will receive a screenshot of the user's screen. Your job:
|
||||
1. Analyze the ENTIRE screenshot to understand what the user is working on (email thread, chat conversation, document, code editor, form, etc.).
|
||||
2. Identify the text area where the user will type.
|
||||
3. Based on the full visual context, generate the text the user most likely wants to write.
|
||||
|
||||
Key behavior:
|
||||
- If the text area is EMPTY, draft a full response or message based on what you see on screen (e.g., reply to an email, respond to a chat message, continue a document).
|
||||
- If the text area already has text, continue it naturally.
|
||||
|
||||
Rules:
|
||||
- Output ONLY the text to be inserted. No quotes, no explanations, no meta-commentary.
|
||||
- Be concise but complete — a full thought, not a fragment.
|
||||
- Match the tone and formality of the surrounding context.
|
||||
- If the screen shows code, write code. If it shows a casual chat, be casual. If it shows a formal email, be formal.
|
||||
- Do NOT describe the screenshot or explain your reasoning.
|
||||
- If you cannot determine what to write, output nothing."""
|
||||
|
||||
APP_CONTEXT_BLOCK = """
|
||||
|
||||
The user is currently working in "{app_name}" (window: "{window_title}"). Use this to understand the type of application and adapt your tone and format accordingly."""
|
||||
|
||||
KB_CONTEXT_BLOCK = """
|
||||
|
||||
You also have access to the user's knowledge base documents below. Use them to write more accurate, informed, and contextually relevant text. Do NOT cite or reference the documents explicitly — just let the knowledge inform your writing naturally.
|
||||
|
||||
<knowledge_base>
|
||||
{kb_context}
|
||||
</knowledge_base>"""
|
||||
|
||||
|
||||
def _build_system_prompt(app_name: str, window_title: str, kb_context: str) -> str:
|
||||
"""Assemble the system prompt from optional context blocks."""
|
||||
prompt = VISION_SYSTEM_PROMPT
|
||||
if app_name:
|
||||
prompt += APP_CONTEXT_BLOCK.format(app_name=app_name, window_title=window_title)
|
||||
if kb_context:
|
||||
prompt += KB_CONTEXT_BLOCK.format(kb_context=kb_context)
|
||||
return prompt
|
||||
|
||||
|
||||
def _is_vision_unsupported_error(e: Exception) -> bool:
|
||||
"""Check if an exception indicates the model doesn't support vision/images."""
|
||||
msg = str(e).lower()
|
||||
return "content must be a string" in msg or "does not support image" in msg
|
||||
|
||||
|
||||
async def _extract_query_from_screenshot(
|
||||
llm, screenshot_data_url: str,
|
||||
app_name: str = "", window_title: str = "",
|
||||
) -> str | None:
|
||||
"""Ask the Vision LLM to describe what the user is working on.
|
||||
|
||||
Raises vision-unsupported errors so the caller can return a
|
||||
friendly message immediately instead of retrying with astream.
|
||||
"""
|
||||
if app_name:
|
||||
prompt_text = EXTRACT_QUERY_PROMPT_WITH_APP.format(
|
||||
app_name=app_name, window_title=window_title,
|
||||
)
|
||||
else:
|
||||
prompt_text = EXTRACT_QUERY_PROMPT
|
||||
|
||||
try:
|
||||
response = await llm.ainvoke([
|
||||
HumanMessage(content=[
|
||||
{"type": "text", "text": prompt_text},
|
||||
{"type": "image_url", "image_url": {"url": screenshot_data_url}},
|
||||
]),
|
||||
])
|
||||
query = response.content.strip() if hasattr(response, "content") else ""
|
||||
return query if query else None
|
||||
except Exception as e:
|
||||
if _is_vision_unsupported_error(e):
|
||||
raise
|
||||
logger.warning(f"Failed to extract query from screenshot: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _search_knowledge_base(
|
||||
session: AsyncSession, search_space_id: int, query: str
|
||||
) -> str:
|
||||
"""Search the KB and return formatted context string."""
|
||||
try:
|
||||
retriever = ChucksHybridSearchRetriever(session)
|
||||
results = await retriever.hybrid_search(
|
||||
query_text=query,
|
||||
top_k=KB_TOP_K,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
|
||||
if not results:
|
||||
return ""
|
||||
|
||||
parts: list[str] = []
|
||||
char_count = 0
|
||||
for doc in results:
|
||||
title = doc.get("document", {}).get("title", "Untitled")
|
||||
for chunk in doc.get("chunks", []):
|
||||
content = chunk.get("content", "").strip()
|
||||
if not content:
|
||||
continue
|
||||
entry = f"[{title}]\n{content}"
|
||||
if char_count + len(entry) > KB_MAX_CHARS:
|
||||
break
|
||||
parts.append(entry)
|
||||
char_count += len(entry)
|
||||
if char_count >= KB_MAX_CHARS:
|
||||
break
|
||||
|
||||
return "\n\n---\n\n".join(parts)
|
||||
except Exception as e:
|
||||
logger.warning(f"KB search failed, proceeding without context: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
async def stream_vision_autocomplete(
|
||||
screenshot_data_url: str,
|
||||
search_space_id: int,
|
||||
session: AsyncSession,
|
||||
*,
|
||||
app_name: str = "",
|
||||
window_title: str = "",
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Analyze a screenshot with the vision LLM and stream a text completion.
|
||||
|
||||
Pipeline:
|
||||
1. Extract a search query from the screenshot (non-streaming)
|
||||
2. Search the knowledge base for relevant context
|
||||
3. Stream the final completion with screenshot + KB + app context
|
||||
"""
|
||||
streaming = VercelStreamingService()
|
||||
vision_error_msg = (
|
||||
"The selected model does not support vision. "
|
||||
"Please set a vision-capable model (e.g. GPT-4o, Gemini) in your search space settings."
|
||||
)
|
||||
|
||||
llm = await get_vision_llm(session, search_space_id)
|
||||
if not llm:
|
||||
yield streaming.format_message_start()
|
||||
yield streaming.format_error("No Vision LLM configured for this search space")
|
||||
yield streaming.format_done()
|
||||
return
|
||||
|
||||
kb_context = ""
|
||||
try:
|
||||
query = await _extract_query_from_screenshot(
|
||||
llm, screenshot_data_url, app_name=app_name, window_title=window_title,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Vision autocomplete: selected model does not support vision: {e}")
|
||||
yield streaming.format_message_start()
|
||||
yield streaming.format_error(vision_error_msg)
|
||||
yield streaming.format_done()
|
||||
return
|
||||
|
||||
if query:
|
||||
kb_context = await _search_knowledge_base(session, search_space_id, query)
|
||||
|
||||
system_prompt = _build_system_prompt(app_name, window_title, kb_context)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(content=[
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Analyze this screenshot. Understand the full context of what the user is working on, then generate the text they most likely want to write in the active text area.",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": screenshot_data_url},
|
||||
},
|
||||
]),
|
||||
]
|
||||
|
||||
text_started = False
|
||||
text_id = ""
|
||||
try:
|
||||
yield streaming.format_message_start()
|
||||
text_id = streaming.generate_text_id()
|
||||
yield streaming.format_text_start(text_id)
|
||||
text_started = True
|
||||
|
||||
async for chunk in llm.astream(messages):
|
||||
token = chunk.content if hasattr(chunk, "content") else str(chunk)
|
||||
if token:
|
||||
yield streaming.format_text_delta(text_id, token)
|
||||
|
||||
yield streaming.format_text_end(text_id)
|
||||
yield streaming.format_finish()
|
||||
yield streaming.format_done()
|
||||
|
||||
except Exception as e:
|
||||
if text_started:
|
||||
yield streaming.format_text_end(text_id)
|
||||
|
||||
if _is_vision_unsupported_error(e):
|
||||
logger.warning(f"Vision autocomplete: selected model does not support vision: {e}")
|
||||
yield streaming.format_error(vision_error_msg)
|
||||
else:
|
||||
logger.error(f"Vision autocomplete streaming error: {e}", exc_info=True)
|
||||
yield streaming.format_error("Autocomplete failed. Please try again.")
|
||||
yield streaming.format_done()
|
||||
|
|
@ -28,6 +28,7 @@ from app.indexing_pipeline.connector_document import ConnectorDocument
|
|||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
|
|
@ -396,6 +397,12 @@ async def _index_full_scan(
|
|||
},
|
||||
)
|
||||
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
page_limit_reached = False
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
files_to_download: list[dict] = []
|
||||
|
|
@ -425,6 +432,21 @@ async def _index_full_scan(
|
|||
elif skip_item(file):
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
if not page_limit_reached:
|
||||
logger.warning(
|
||||
"Page limit reached during Dropbox full scan, "
|
||||
"skipping remaining files"
|
||||
)
|
||||
page_limit_reached = True
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, failed = await _download_and_index(
|
||||
|
|
@ -438,6 +460,14 @@ async def _index_full_scan(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
|
|
@ -458,6 +488,11 @@ async def _index_selected_files(
|
|||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
"""Index user-selected files using the parallel pipeline."""
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
|
||||
files_to_download: list[dict] = []
|
||||
errors: list[str] = []
|
||||
renamed_count = 0
|
||||
|
|
@ -482,6 +517,15 @@ async def _index_selected_files(
|
|||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
display = file_name or file_path
|
||||
errors.append(f"File '{display}': page limit would be exceeded")
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, _failed = await _download_and_index(
|
||||
|
|
@ -495,6 +539,14 @@ async def _index_selected_files(
|
|||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, errors
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ from app.indexing_pipeline.indexing_pipeline_service import (
|
|||
PlaceholderInfo,
|
||||
)
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
|
|
@ -327,6 +328,12 @@ async def _process_single_file(
|
|||
return 1, 0, 0
|
||||
return 0, 1, 0
|
||||
|
||||
page_limit_service = PageLimitService(session)
|
||||
estimated_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file_name, file.get("size")
|
||||
)
|
||||
await page_limit_service.check_page_limit(user_id, estimated_pages)
|
||||
|
||||
markdown, drive_metadata, error = await download_and_extract_content(
|
||||
drive_client, file
|
||||
)
|
||||
|
|
@ -363,6 +370,9 @@ async def _process_single_file(
|
|||
)
|
||||
await pipeline.index(document, connector_doc, user_llm)
|
||||
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, estimated_pages, allow_exceed=True
|
||||
)
|
||||
logger.info(f"Successfully indexed Google Drive file: {file_name}")
|
||||
return 1, 0, 0
|
||||
|
||||
|
|
@ -466,6 +476,11 @@ async def _index_selected_files(
|
|||
|
||||
Returns (indexed_count, skipped_count, errors).
|
||||
"""
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
|
||||
files_to_download: list[dict] = []
|
||||
errors: list[str] = []
|
||||
renamed_count = 0
|
||||
|
|
@ -486,6 +501,15 @@ async def _index_selected_files(
|
|||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
display = file_name or file_id
|
||||
errors.append(f"File '{display}': page limit would be exceeded")
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
await _create_drive_placeholders(
|
||||
|
|
@ -507,6 +531,14 @@ async def _index_selected_files(
|
|||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, errors
|
||||
|
||||
|
||||
|
|
@ -545,6 +577,12 @@ async def _index_full_scan(
|
|||
# ------------------------------------------------------------------
|
||||
# Phase 1 (serial): collect files, run skip checks, track renames
|
||||
# ------------------------------------------------------------------
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
page_limit_reached = False
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
files_processed = 0
|
||||
|
|
@ -593,6 +631,20 @@ async def _index_full_scan(
|
|||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
if not page_limit_reached:
|
||||
logger.warning(
|
||||
"Page limit reached during Google Drive full scan, "
|
||||
"skipping remaining files"
|
||||
)
|
||||
page_limit_reached = True
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
page_token = next_token
|
||||
|
|
@ -636,6 +688,14 @@ async def _index_full_scan(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
|
|
@ -686,6 +746,12 @@ async def _index_with_delta_sync(
|
|||
# ------------------------------------------------------------------
|
||||
# Phase 1 (serial): handle removals, collect files for download
|
||||
# ------------------------------------------------------------------
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
page_limit_reached = False
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
files_to_download: list[dict] = []
|
||||
|
|
@ -715,6 +781,20 @@ async def _index_with_delta_sync(
|
|||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
if not page_limit_reached:
|
||||
logger.warning(
|
||||
"Page limit reached during Google Drive delta sync, "
|
||||
"skipping remaining files"
|
||||
)
|
||||
page_limit_reached = True
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -742,6 +822,14 @@ async def _index_with_delta_sync(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
|
|
|
|||
|
|
@ -79,6 +79,7 @@ def _compute_final_pages(
|
|||
actual = page_limit_service.estimate_pages_from_content_length(content_length)
|
||||
return max(estimated_pages, actual)
|
||||
|
||||
|
||||
DEFAULT_EXCLUDE_PATTERNS = [
|
||||
".git",
|
||||
"node_modules",
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from app.indexing_pipeline.connector_document import ConnectorDocument
|
|||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
|
|
@ -291,6 +292,11 @@ async def _index_selected_files(
|
|||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
"""Index user-selected files using the parallel pipeline."""
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
|
||||
files_to_download: list[dict] = []
|
||||
errors: list[str] = []
|
||||
renamed_count = 0
|
||||
|
|
@ -311,6 +317,15 @@ async def _index_selected_files(
|
|||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
display = file_name or file_id
|
||||
errors.append(f"File '{display}': page limit would be exceeded")
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, _failed = await _download_and_index(
|
||||
|
|
@ -324,6 +339,14 @@ async def _index_selected_files(
|
|||
on_heartbeat=on_heartbeat,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, errors
|
||||
|
||||
|
||||
|
|
@ -358,6 +381,12 @@ async def _index_full_scan(
|
|||
},
|
||||
)
|
||||
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
page_limit_reached = False
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
files_to_download: list[dict] = []
|
||||
|
|
@ -383,6 +412,21 @@ async def _index_full_scan(
|
|||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
if not page_limit_reached:
|
||||
logger.warning(
|
||||
"Page limit reached during OneDrive full scan, "
|
||||
"skipping remaining files"
|
||||
)
|
||||
page_limit_reached = True
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(file)
|
||||
|
||||
batch_indexed, failed = await _download_and_index(
|
||||
|
|
@ -396,6 +440,14 @@ async def _index_full_scan(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
|
|
@ -441,6 +493,12 @@ async def _index_with_delta_sync(
|
|||
|
||||
logger.info(f"Processing {len(changes)} delta changes")
|
||||
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
remaining_quota = pages_limit - pages_used
|
||||
batch_estimated_pages = 0
|
||||
page_limit_reached = False
|
||||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
files_to_download: list[dict] = []
|
||||
|
|
@ -471,6 +529,20 @@ async def _index_with_delta_sync(
|
|||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
change.get("name", ""), change.get("size")
|
||||
)
|
||||
if batch_estimated_pages + file_pages > remaining_quota:
|
||||
if not page_limit_reached:
|
||||
logger.warning(
|
||||
"Page limit reached during OneDrive delta sync, "
|
||||
"skipping remaining files"
|
||||
)
|
||||
page_limit_reached = True
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
batch_estimated_pages += file_pages
|
||||
files_to_download.append(change)
|
||||
|
||||
batch_indexed, failed = await _download_and_index(
|
||||
|
|
@ -484,6 +556,14 @@ async def _index_with_delta_sync(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if batch_indexed > 0 and files_to_download and batch_estimated_pages > 0:
|
||||
pages_to_deduct = max(
|
||||
1, batch_estimated_pages * batch_indexed // len(files_to_download)
|
||||
)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ import hmac
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from random import SystemRandom
|
||||
from string import ascii_letters, digits
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
|
@ -18,6 +20,25 @@ from fastapi import HTTPException
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PKCE_CHARS = ascii_letters + digits + "-._~"
|
||||
_PKCE_RNG = SystemRandom()
|
||||
|
||||
|
||||
def generate_code_verifier(length: int = 128) -> str:
|
||||
"""Generate a PKCE code_verifier (RFC 7636, 43-128 unreserved chars)."""
|
||||
return "".join(_PKCE_RNG.choice(_PKCE_CHARS) for _ in range(length))
|
||||
|
||||
|
||||
def generate_pkce_pair(length: int = 128) -> tuple[str, str]:
|
||||
"""Generate a PKCE code_verifier and its S256 code_challenge."""
|
||||
verifier = generate_code_verifier(length)
|
||||
challenge = (
|
||||
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest())
|
||||
.decode()
|
||||
.rstrip("=")
|
||||
)
|
||||
return verifier, challenge
|
||||
|
||||
|
||||
class OAuthStateManager:
|
||||
"""Manages secure OAuth state parameters with HMAC signatures."""
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
Prerequisites: PostgreSQL + pgvector only.
|
||||
|
||||
External system boundaries are mocked:
|
||||
- ETL parsing — LlamaParse (external API) and Docling (heavy library)
|
||||
- LLM summarization, text embedding, text chunking (external APIs)
|
||||
- Redis heartbeat (external infrastructure)
|
||||
- Task dispatch is swapped via DI (InlineTaskDispatcher)
|
||||
|
|
@ -11,6 +12,7 @@ External system boundaries are mocked:
|
|||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
|
@ -298,3 +300,67 @@ def _mock_redis_heartbeat(monkeypatch):
|
|||
"app.tasks.celery_tasks.document_tasks._run_heartbeat_loop",
|
||||
AsyncMock(),
|
||||
)
|
||||
|
||||
|
||||
_MOCK_ETL_MARKDOWN = "# Mocked Document\n\nThis is mocked ETL content."
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_etl_parsing(monkeypatch):
|
||||
"""Mock ETL parsing services — LlamaParse and Docling are external boundaries.
|
||||
|
||||
Preserves the real contract: empty/corrupt files raise an error just like
|
||||
the actual services would, so tests covering failure paths keep working.
|
||||
"""
|
||||
|
||||
def _reject_empty(file_path: str) -> None:
|
||||
if os.path.getsize(file_path) == 0:
|
||||
raise RuntimeError(f"Cannot parse empty file: {file_path}")
|
||||
|
||||
# -- LlamaParse mock (external API) --------------------------------
|
||||
|
||||
class _FakeMarkdownDoc:
|
||||
def __init__(self, text: str):
|
||||
self.text = text
|
||||
|
||||
class _FakeLlamaParseResult:
|
||||
async def aget_markdown_documents(self, *, split_by_page=False):
|
||||
return [_FakeMarkdownDoc(_MOCK_ETL_MARKDOWN)]
|
||||
|
||||
async def _fake_llamacloud_parse(**kwargs):
|
||||
_reject_empty(kwargs["file_path"])
|
||||
return _FakeLlamaParseResult()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.document_processors.file_processors.parse_with_llamacloud_retry",
|
||||
_fake_llamacloud_parse,
|
||||
)
|
||||
|
||||
# -- Docling mock (heavy library boundary) -------------------------
|
||||
|
||||
async def _fake_docling_parse(file_path: str, filename: str):
|
||||
_reject_empty(file_path)
|
||||
return _MOCK_ETL_MARKDOWN
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.document_processors.file_processors.parse_with_docling",
|
||||
_fake_docling_parse,
|
||||
)
|
||||
|
||||
class _FakeDoclingResult:
|
||||
class Document:
|
||||
@staticmethod
|
||||
def export_to_markdown():
|
||||
return _MOCK_ETL_MARKDOWN
|
||||
|
||||
document = Document()
|
||||
|
||||
class _FakeDocumentConverter:
|
||||
def convert(self, file_path):
|
||||
_reject_empty(file_path)
|
||||
return _FakeDoclingResult()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"docling.document_converter.DocumentConverter",
|
||||
_FakeDocumentConverter,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1015,7 +1015,7 @@ class TestPageLimits:
|
|||
|
||||
(tmp_path / "note.md").write_text("# Hello World\n\nContent here.")
|
||||
|
||||
count, _skipped, _root_folder_id, err = await index_local_folder(
|
||||
count, _skipped, _root_folder_id, _err = await index_local_folder(
|
||||
session=db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
user_id=str(db_user.id),
|
||||
|
|
|
|||
|
|
@ -248,12 +248,33 @@ def _folder_dict(file_id: str, name: str) -> dict:
|
|||
}
|
||||
|
||||
|
||||
def _make_page_limit_session(pages_used=0, pages_limit=999_999):
|
||||
"""Build a mock DB session that real PageLimitService can operate against."""
|
||||
|
||||
class _FakeUser:
|
||||
def __init__(self, pu, pl):
|
||||
self.pages_used = pu
|
||||
self.pages_limit = pl
|
||||
|
||||
fake_user = _FakeUser(pages_used, pages_limit)
|
||||
session = AsyncMock()
|
||||
|
||||
def _make_result(*_a, **_kw):
|
||||
r = MagicMock()
|
||||
r.first.return_value = (fake_user.pages_used, fake_user.pages_limit)
|
||||
r.unique.return_value.scalar_one_or_none.return_value = fake_user
|
||||
return r
|
||||
|
||||
session.execute = AsyncMock(side_effect=_make_result)
|
||||
return session, fake_user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def full_scan_mocks(mock_drive_client, monkeypatch):
|
||||
"""Wire up all mocks needed to call _index_full_scan in isolation."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session, _ = _make_page_limit_session()
|
||||
mock_connector = MagicMock()
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
|
@ -472,7 +493,7 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
|||
AsyncMock(return_value=MagicMock()),
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session, _ = _make_page_limit_session()
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
||||
|
|
@ -512,7 +533,7 @@ def selected_files_mocks(mock_drive_client, monkeypatch):
|
|||
"""Wire up mocks for _index_selected_files tests."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session, _ = _make_page_limit_session()
|
||||
|
||||
get_file_results: dict[str, tuple[dict | None, str | None]] = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,680 @@
|
|||
"""Tests for page limit enforcement in connector indexers.
|
||||
|
||||
Covers:
|
||||
A) PageLimitService.estimate_pages_from_metadata — pure function (no mocks)
|
||||
B) Page-limit quota gating in _index_selected_files tested through the
|
||||
real PageLimitService with a mock DB session (system boundary).
|
||||
Google Drive is the primary, with OneDrive/Dropbox smoke tests.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
_USER_ID = "00000000-0000-0000-0000-000000000001"
|
||||
_CONNECTOR_ID = 42
|
||||
_SEARCH_SPACE_ID = 1
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# A) PageLimitService.estimate_pages_from_metadata — pure function
|
||||
# No mocks: it's a staticmethod with no I/O.
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestEstimatePagesFromMetadata:
|
||||
"""Vertical slices for the page estimation staticmethod."""
|
||||
|
||||
def test_pdf_100kb_returns_1(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".pdf", 100 * 1024) == 1
|
||||
|
||||
def test_pdf_500kb_returns_5(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".pdf", 500 * 1024) == 5
|
||||
|
||||
def test_pdf_1mb(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".pdf", 1024 * 1024) == 10
|
||||
|
||||
def test_docx_50kb_returns_1(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".docx", 50 * 1024) == 1
|
||||
|
||||
def test_docx_200kb(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".docx", 200 * 1024) == 4
|
||||
|
||||
def test_pptx_uses_200kb_per_page(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".pptx", 600 * 1024) == 3
|
||||
|
||||
def test_xlsx_uses_100kb_per_page(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".xlsx", 300 * 1024) == 3
|
||||
|
||||
def test_txt_uses_3000_bytes_per_page(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".txt", 9000) == 3
|
||||
|
||||
def test_image_always_returns_1(self):
|
||||
for ext in (".jpg", ".png", ".gif", ".webp"):
|
||||
assert PageLimitService.estimate_pages_from_metadata(ext, 5_000_000) == 1
|
||||
|
||||
def test_audio_uses_1mb_per_page(self):
|
||||
assert (
|
||||
PageLimitService.estimate_pages_from_metadata(".mp3", 3 * 1024 * 1024) == 3
|
||||
)
|
||||
|
||||
def test_video_uses_5mb_per_page(self):
|
||||
assert (
|
||||
PageLimitService.estimate_pages_from_metadata(".mp4", 15 * 1024 * 1024) == 3
|
||||
)
|
||||
|
||||
def test_unknown_ext_uses_80kb_per_page(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".xyz", 160 * 1024) == 2
|
||||
|
||||
def test_zero_size_returns_1(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".pdf", 0) == 1
|
||||
|
||||
def test_negative_size_returns_1(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".pdf", -500) == 1
|
||||
|
||||
def test_minimum_is_always_1(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".pdf", 50) == 1
|
||||
|
||||
def test_epub_uses_50kb_per_page(self):
|
||||
assert PageLimitService.estimate_pages_from_metadata(".epub", 250 * 1024) == 5
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# B) Page-limit enforcement in connector indexers
|
||||
# System boundary mocked: DB session (for PageLimitService)
|
||||
# System boundary mocked: external API clients, download/ETL
|
||||
# NOT mocked: PageLimitService itself (our own code)
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class _FakeUser:
|
||||
"""Stands in for the User ORM model at the DB boundary."""
|
||||
|
||||
def __init__(self, pages_used: int = 0, pages_limit: int = 100):
|
||||
self.pages_used = pages_used
|
||||
self.pages_limit = pages_limit
|
||||
|
||||
|
||||
def _make_page_limit_session(pages_used: int = 0, pages_limit: int = 100):
|
||||
"""Build a mock DB session that real PageLimitService can operate against.
|
||||
|
||||
Every ``session.execute()`` returns a result compatible with both
|
||||
``get_page_usage`` (.first() → tuple) and ``update_page_usage``
|
||||
(.unique().scalar_one_or_none() → User-like).
|
||||
"""
|
||||
fake_user = _FakeUser(pages_used, pages_limit)
|
||||
session = AsyncMock()
|
||||
|
||||
def _make_result(*_args, **_kwargs):
|
||||
result = MagicMock()
|
||||
result.first.return_value = (fake_user.pages_used, fake_user.pages_limit)
|
||||
result.unique.return_value.scalar_one_or_none.return_value = fake_user
|
||||
return result
|
||||
|
||||
session.execute = AsyncMock(side_effect=_make_result)
|
||||
return session, fake_user
|
||||
|
||||
|
||||
def _make_gdrive_file(file_id: str, name: str, size: int = 80 * 1024) -> dict:
|
||||
return {
|
||||
"id": file_id,
|
||||
"name": name,
|
||||
"mimeType": "application/octet-stream",
|
||||
"size": str(size),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Google Drive: _index_selected_files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gdrive_selected_mocks(monkeypatch):
|
||||
"""Mocks for Google Drive _index_selected_files — only system boundaries."""
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
session, fake_user = _make_page_limit_session(0, 100)
|
||||
|
||||
get_file_results: dict[str, tuple[dict | None, str | None]] = {}
|
||||
|
||||
async def _fake_get_file(client, file_id):
|
||||
return get_file_results.get(file_id, (None, f"Not configured: {file_id}"))
|
||||
|
||||
monkeypatch.setattr(_mod, "get_file_by_id", _fake_get_file)
|
||||
monkeypatch.setattr(
|
||||
_mod, "_should_skip_file", AsyncMock(return_value=(False, None))
|
||||
)
|
||||
|
||||
download_and_index_mock = AsyncMock(return_value=(0, 0))
|
||||
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
|
||||
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock)
|
||||
)
|
||||
|
||||
return {
|
||||
"mod": _mod,
|
||||
"session": session,
|
||||
"fake_user": fake_user,
|
||||
"get_file_results": get_file_results,
|
||||
"download_and_index_mock": download_and_index_mock,
|
||||
}
|
||||
|
||||
|
||||
async def _run_gdrive_selected(mocks, file_ids):
|
||||
from app.tasks.connector_indexers.google_drive_indexer import (
|
||||
_index_selected_files,
|
||||
)
|
||||
|
||||
return await _index_selected_files(
|
||||
MagicMock(),
|
||||
mocks["session"],
|
||||
file_ids,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_gdrive_files_within_quota_are_downloaded(gdrive_selected_mocks):
|
||||
"""Files whose cumulative estimated pages fit within remaining quota
|
||||
are sent to _download_and_index."""
|
||||
m = gdrive_selected_mocks
|
||||
m["fake_user"].pages_used = 0
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
for fid in ("f1", "f2", "f3"):
|
||||
m["get_file_results"][fid] = (
|
||||
_make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024),
|
||||
None,
|
||||
)
|
||||
m["download_and_index_mock"].return_value = (3, 0)
|
||||
|
||||
indexed, _skipped, errors = await _run_gdrive_selected(
|
||||
m, [("f1", "f1.xyz"), ("f2", "f2.xyz"), ("f3", "f3.xyz")]
|
||||
)
|
||||
|
||||
assert indexed == 3
|
||||
assert errors == []
|
||||
call_files = m["download_and_index_mock"].call_args[0][2]
|
||||
assert len(call_files) == 3
|
||||
|
||||
|
||||
async def test_gdrive_files_exceeding_quota_rejected(gdrive_selected_mocks):
|
||||
"""Files whose pages would exceed remaining quota are rejected."""
|
||||
m = gdrive_selected_mocks
|
||||
m["fake_user"].pages_used = 98
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
m["get_file_results"]["big"] = (
|
||||
_make_gdrive_file("big", "huge.pdf", size=500 * 1024),
|
||||
None,
|
||||
)
|
||||
|
||||
indexed, _skipped, errors = await _run_gdrive_selected(m, [("big", "huge.pdf")])
|
||||
|
||||
assert indexed == 0
|
||||
assert len(errors) == 1
|
||||
assert "page limit" in errors[0].lower()
|
||||
|
||||
|
||||
async def test_gdrive_quota_mix_partial_indexing(gdrive_selected_mocks):
|
||||
"""3rd file pushes over quota → only first two indexed."""
|
||||
m = gdrive_selected_mocks
|
||||
m["fake_user"].pages_used = 0
|
||||
m["fake_user"].pages_limit = 2
|
||||
|
||||
for fid in ("f1", "f2", "f3"):
|
||||
m["get_file_results"][fid] = (
|
||||
_make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024),
|
||||
None,
|
||||
)
|
||||
m["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
indexed, _skipped, errors = await _run_gdrive_selected(
|
||||
m, [("f1", "f1.xyz"), ("f2", "f2.xyz"), ("f3", "f3.xyz")]
|
||||
)
|
||||
|
||||
assert indexed == 2
|
||||
assert len(errors) == 1
|
||||
call_files = m["download_and_index_mock"].call_args[0][2]
|
||||
assert {f["id"] for f in call_files} == {"f1", "f2"}
|
||||
|
||||
|
||||
async def test_gdrive_proportional_page_deduction(gdrive_selected_mocks):
|
||||
"""Pages deducted are proportional to successfully indexed files."""
|
||||
m = gdrive_selected_mocks
|
||||
m["fake_user"].pages_used = 0
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
for fid in ("f1", "f2", "f3", "f4"):
|
||||
m["get_file_results"][fid] = (
|
||||
_make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024),
|
||||
None,
|
||||
)
|
||||
m["download_and_index_mock"].return_value = (2, 2)
|
||||
|
||||
await _run_gdrive_selected(
|
||||
m,
|
||||
[("f1", "f1.xyz"), ("f2", "f2.xyz"), ("f3", "f3.xyz"), ("f4", "f4.xyz")],
|
||||
)
|
||||
|
||||
assert m["fake_user"].pages_used == 2
|
||||
|
||||
|
||||
async def test_gdrive_no_deduction_when_nothing_indexed(gdrive_selected_mocks):
|
||||
"""If batch_indexed == 0, user's pages_used stays unchanged."""
|
||||
m = gdrive_selected_mocks
|
||||
m["fake_user"].pages_used = 5
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
m["get_file_results"]["f1"] = (
|
||||
_make_gdrive_file("f1", "f1.xyz", size=80 * 1024),
|
||||
None,
|
||||
)
|
||||
m["download_and_index_mock"].return_value = (0, 1)
|
||||
|
||||
await _run_gdrive_selected(m, [("f1", "f1.xyz")])
|
||||
|
||||
assert m["fake_user"].pages_used == 5
|
||||
|
||||
|
||||
async def test_gdrive_zero_quota_rejects_all(gdrive_selected_mocks):
|
||||
"""When pages_used == pages_limit, every file is rejected."""
|
||||
m = gdrive_selected_mocks
|
||||
m["fake_user"].pages_used = 100
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
for fid in ("f1", "f2"):
|
||||
m["get_file_results"][fid] = (
|
||||
_make_gdrive_file(fid, f"{fid}.xyz", size=80 * 1024),
|
||||
None,
|
||||
)
|
||||
|
||||
indexed, _skipped, errors = await _run_gdrive_selected(
|
||||
m, [("f1", "f1.xyz"), ("f2", "f2.xyz")]
|
||||
)
|
||||
|
||||
assert indexed == 0
|
||||
assert len(errors) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Google Drive: _index_full_scan
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gdrive_full_scan_mocks(monkeypatch):
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
session, fake_user = _make_page_limit_session(0, 100)
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(
|
||||
_mod, "_should_skip_file", AsyncMock(return_value=(False, None))
|
||||
)
|
||||
|
||||
download_mock = AsyncMock(return_value=([], 0))
|
||||
monkeypatch.setattr(_mod, "_download_files_parallel", download_mock)
|
||||
|
||||
batch_mock = AsyncMock(return_value=([], 0, 0))
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.index_batch_parallel = batch_mock
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock())
|
||||
)
|
||||
|
||||
return {
|
||||
"mod": _mod,
|
||||
"session": session,
|
||||
"fake_user": fake_user,
|
||||
"task_logger": mock_task_logger,
|
||||
"download_mock": download_mock,
|
||||
"batch_mock": batch_mock,
|
||||
}
|
||||
|
||||
|
||||
async def _run_gdrive_full_scan(mocks, max_files=500):
|
||||
from app.tasks.connector_indexers.google_drive_indexer import _index_full_scan
|
||||
|
||||
return await _index_full_scan(
|
||||
MagicMock(),
|
||||
mocks["session"],
|
||||
MagicMock(),
|
||||
_CONNECTOR_ID,
|
||||
_SEARCH_SPACE_ID,
|
||||
_USER_ID,
|
||||
"folder-root",
|
||||
"My Folder",
|
||||
mocks["task_logger"],
|
||||
MagicMock(),
|
||||
max_files,
|
||||
include_subfolders=False,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_gdrive_full_scan_skips_over_quota(gdrive_full_scan_mocks, monkeypatch):
|
||||
m = gdrive_full_scan_mocks
|
||||
m["fake_user"].pages_used = 0
|
||||
m["fake_user"].pages_limit = 2
|
||||
|
||||
page_files = [
|
||||
_make_gdrive_file(f"f{i}", f"file{i}.xyz", size=80 * 1024) for i in range(5)
|
||||
]
|
||||
monkeypatch.setattr(
|
||||
m["mod"],
|
||||
"get_files_in_folder",
|
||||
AsyncMock(return_value=(page_files, None, None)),
|
||||
)
|
||||
m["download_mock"].return_value = ([], 0)
|
||||
m["batch_mock"].return_value = ([], 2, 0)
|
||||
|
||||
_indexed, skipped = await _run_gdrive_full_scan(m)
|
||||
|
||||
call_files = m["download_mock"].call_args[0][1]
|
||||
assert len(call_files) == 2
|
||||
assert skipped == 3
|
||||
|
||||
|
||||
async def test_gdrive_full_scan_deducts_after_indexing(
|
||||
gdrive_full_scan_mocks, monkeypatch
|
||||
):
|
||||
m = gdrive_full_scan_mocks
|
||||
m["fake_user"].pages_used = 0
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
page_files = [
|
||||
_make_gdrive_file(f"f{i}", f"file{i}.xyz", size=80 * 1024) for i in range(3)
|
||||
]
|
||||
monkeypatch.setattr(
|
||||
m["mod"],
|
||||
"get_files_in_folder",
|
||||
AsyncMock(return_value=(page_files, None, None)),
|
||||
)
|
||||
mock_docs = [MagicMock() for _ in range(3)]
|
||||
m["download_mock"].return_value = (mock_docs, 0)
|
||||
m["batch_mock"].return_value = ([], 3, 0)
|
||||
|
||||
await _run_gdrive_full_scan(m)
|
||||
|
||||
assert m["fake_user"].pages_used == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Google Drive: _index_with_delta_sync
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_gdrive_delta_sync_skips_over_quota(monkeypatch):
|
||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||
|
||||
session, _ = _make_page_limit_session(0, 2)
|
||||
|
||||
changes = [
|
||||
{
|
||||
"fileId": f"mod{i}",
|
||||
"file": _make_gdrive_file(f"mod{i}", f"mod{i}.xyz", size=80 * 1024),
|
||||
}
|
||||
for i in range(5)
|
||||
]
|
||||
monkeypatch.setattr(
|
||||
_mod,
|
||||
"fetch_all_changes",
|
||||
AsyncMock(return_value=(changes, "new-token", None)),
|
||||
)
|
||||
monkeypatch.setattr(_mod, "categorize_change", lambda change: "modified")
|
||||
monkeypatch.setattr(
|
||||
_mod, "_should_skip_file", AsyncMock(return_value=(False, None))
|
||||
)
|
||||
|
||||
download_mock = AsyncMock(return_value=([], 0))
|
||||
monkeypatch.setattr(_mod, "_download_files_parallel", download_mock)
|
||||
|
||||
batch_mock = AsyncMock(return_value=([], 2, 0))
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.index_batch_parallel = batch_mock
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock())
|
||||
)
|
||||
|
||||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
||||
_indexed, skipped = await _mod._index_with_delta_sync(
|
||||
MagicMock(),
|
||||
session,
|
||||
MagicMock(),
|
||||
_CONNECTOR_ID,
|
||||
_SEARCH_SPACE_ID,
|
||||
_USER_ID,
|
||||
"folder-root",
|
||||
"start-token",
|
||||
mock_task_logger,
|
||||
MagicMock(),
|
||||
max_files=500,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
call_files = download_mock.call_args[0][1]
|
||||
assert len(call_files) == 2
|
||||
assert skipped == 3
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# C) OneDrive smoke tests — verify page limit wiring
|
||||
# ===================================================================
|
||||
|
||||
|
||||
def _make_onedrive_file(file_id: str, name: str, size: int = 80 * 1024) -> dict:
|
||||
return {
|
||||
"id": file_id,
|
||||
"name": name,
|
||||
"file": {"mimeType": "application/octet-stream"},
|
||||
"size": str(size),
|
||||
"lastModifiedDateTime": "2026-01-01T00:00:00Z",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def onedrive_selected_mocks(monkeypatch):
|
||||
import app.tasks.connector_indexers.onedrive_indexer as _mod
|
||||
|
||||
session, fake_user = _make_page_limit_session(0, 100)
|
||||
|
||||
get_file_results: dict[str, tuple[dict | None, str | None]] = {}
|
||||
|
||||
async def _fake_get_file(client, file_id):
|
||||
return get_file_results.get(file_id, (None, f"Not found: {file_id}"))
|
||||
|
||||
monkeypatch.setattr(_mod, "get_file_by_id", _fake_get_file)
|
||||
monkeypatch.setattr(
|
||||
_mod, "_should_skip_file", AsyncMock(return_value=(False, None))
|
||||
)
|
||||
|
||||
download_and_index_mock = AsyncMock(return_value=(0, 0))
|
||||
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
|
||||
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock)
|
||||
)
|
||||
|
||||
return {
|
||||
"session": session,
|
||||
"fake_user": fake_user,
|
||||
"get_file_results": get_file_results,
|
||||
"download_and_index_mock": download_and_index_mock,
|
||||
}
|
||||
|
||||
|
||||
async def _run_onedrive_selected(mocks, file_ids):
|
||||
from app.tasks.connector_indexers.onedrive_indexer import _index_selected_files
|
||||
|
||||
return await _index_selected_files(
|
||||
MagicMock(),
|
||||
mocks["session"],
|
||||
file_ids,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_onedrive_over_quota_rejected(onedrive_selected_mocks):
|
||||
"""OneDrive: files exceeding quota produce errors, not downloads."""
|
||||
m = onedrive_selected_mocks
|
||||
m["fake_user"].pages_used = 99
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
m["get_file_results"]["big"] = (
|
||||
_make_onedrive_file("big", "huge.pdf", size=500 * 1024),
|
||||
None,
|
||||
)
|
||||
|
||||
indexed, _skipped, errors = await _run_onedrive_selected(m, [("big", "huge.pdf")])
|
||||
|
||||
assert indexed == 0
|
||||
assert len(errors) == 1
|
||||
assert "page limit" in errors[0].lower()
|
||||
|
||||
|
||||
async def test_onedrive_deducts_after_success(onedrive_selected_mocks):
|
||||
"""OneDrive: pages_used increases after successful indexing."""
|
||||
m = onedrive_selected_mocks
|
||||
m["fake_user"].pages_used = 0
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
for fid in ("f1", "f2"):
|
||||
m["get_file_results"][fid] = (
|
||||
_make_onedrive_file(fid, f"{fid}.xyz", size=80 * 1024),
|
||||
None,
|
||||
)
|
||||
m["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
await _run_onedrive_selected(m, [("f1", "f1.xyz"), ("f2", "f2.xyz")])
|
||||
|
||||
assert m["fake_user"].pages_used == 2
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# D) Dropbox smoke tests — verify page limit wiring
|
||||
# ===================================================================
|
||||
|
||||
|
||||
def _make_dropbox_file(file_path: str, name: str, size: int = 80 * 1024) -> dict:
|
||||
return {
|
||||
"id": f"id:{file_path}",
|
||||
"name": name,
|
||||
".tag": "file",
|
||||
"path_lower": file_path,
|
||||
"size": str(size),
|
||||
"server_modified": "2026-01-01T00:00:00Z",
|
||||
"content_hash": f"hash_{name}",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dropbox_selected_mocks(monkeypatch):
|
||||
import app.tasks.connector_indexers.dropbox_indexer as _mod
|
||||
|
||||
session, fake_user = _make_page_limit_session(0, 100)
|
||||
|
||||
get_file_results: dict[str, tuple[dict | None, str | None]] = {}
|
||||
|
||||
async def _fake_get_file(client, file_path):
|
||||
return get_file_results.get(file_path, (None, f"Not found: {file_path}"))
|
||||
|
||||
monkeypatch.setattr(_mod, "get_file_by_path", _fake_get_file)
|
||||
monkeypatch.setattr(
|
||||
_mod, "_should_skip_file", AsyncMock(return_value=(False, None))
|
||||
)
|
||||
|
||||
download_and_index_mock = AsyncMock(return_value=(0, 0))
|
||||
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
|
||||
|
||||
pipeline_mock = MagicMock()
|
||||
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||
monkeypatch.setattr(
|
||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock)
|
||||
)
|
||||
|
||||
return {
|
||||
"session": session,
|
||||
"fake_user": fake_user,
|
||||
"get_file_results": get_file_results,
|
||||
"download_and_index_mock": download_and_index_mock,
|
||||
}
|
||||
|
||||
|
||||
async def _run_dropbox_selected(mocks, file_paths):
|
||||
from app.tasks.connector_indexers.dropbox_indexer import _index_selected_files
|
||||
|
||||
return await _index_selected_files(
|
||||
MagicMock(),
|
||||
mocks["session"],
|
||||
file_paths,
|
||||
connector_id=_CONNECTOR_ID,
|
||||
search_space_id=_SEARCH_SPACE_ID,
|
||||
user_id=_USER_ID,
|
||||
enable_summary=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_dropbox_over_quota_rejected(dropbox_selected_mocks):
|
||||
"""Dropbox: files exceeding quota produce errors, not downloads."""
|
||||
m = dropbox_selected_mocks
|
||||
m["fake_user"].pages_used = 99
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
m["get_file_results"]["/huge.pdf"] = (
|
||||
_make_dropbox_file("/huge.pdf", "huge.pdf", size=500 * 1024),
|
||||
None,
|
||||
)
|
||||
|
||||
indexed, _skipped, errors = await _run_dropbox_selected(
|
||||
m, [("/huge.pdf", "huge.pdf")]
|
||||
)
|
||||
|
||||
assert indexed == 0
|
||||
assert len(errors) == 1
|
||||
assert "page limit" in errors[0].lower()
|
||||
|
||||
|
||||
async def test_dropbox_deducts_after_success(dropbox_selected_mocks):
|
||||
"""Dropbox: pages_used increases after successful indexing."""
|
||||
m = dropbox_selected_mocks
|
||||
m["fake_user"].pages_used = 0
|
||||
m["fake_user"].pages_limit = 100
|
||||
|
||||
for name in ("f1.xyz", "f2.xyz"):
|
||||
path = f"/{name}"
|
||||
m["get_file_results"][path] = (
|
||||
_make_dropbox_file(path, name, size=80 * 1024),
|
||||
None,
|
||||
)
|
||||
m["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
await _run_dropbox_selected(m, [("/f1.xyz", "f1.xyz"), ("/f2.xyz", "f2.xyz")])
|
||||
|
||||
assert m["fake_user"].pages_used == 2
|
||||
Loading…
Add table
Add a link
Reference in a new issue