feat: add processing mode support for document uploads and ETL pipeline, improved error handling UX
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions

- Introduced a `ProcessingMode` enum to differentiate between basic and premium processing modes.
- Updated `EtlRequest` to include a `processing_mode` field, defaulting to basic.
- Enhanced ETL pipeline services to utilize the selected processing mode for Azure Document Intelligence and LlamaCloud parsing.
- Modified various routes and services to handle processing mode, affecting document upload and indexing tasks.
- Improved error handling and logging to include processing mode details.
- Added tests to validate processing mode functionality and its impact on ETL operations.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-04-14 21:26:00 -07:00
parent b659f41bab
commit 656e061f84
104 changed files with 1900 additions and 909 deletions

View file

@ -2,12 +2,15 @@ import asyncio
import gc
import logging
import time
import uuid
from collections import defaultdict
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from threading import Lock
import redis
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from limits.storage import MemoryStorage
@ -32,6 +35,7 @@ from app.config import (
initialize_vision_llm_router,
)
from app.db import User, create_db_and_tables, get_async_session
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
from app.routes import router as crud_router
from app.routes.auth_routes import router as auth_router
from app.schemas import UserCreate, UserRead, UserUpdate
@ -39,6 +43,8 @@ from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
from app.utils.perf import get_perf_logger, log_system_snapshot
_error_logger = logging.getLogger("surfsense.errors")
rate_limit_logger = logging.getLogger("surfsense.rate_limit")
@ -61,13 +67,137 @@ limiter = Limiter(
)
def _get_request_id(request: Request) -> str:
    """Resolve the correlation ID for *request*.

    Preference order: the ID already stored on ``request.state``, then the
    incoming ``X-Request-ID`` header, then a freshly generated ``req_<12 hex>``.
    """
    state = request.state
    if not hasattr(state, "request_id"):
        generated = f"req_{uuid.uuid4().hex[:12]}"
        return request.headers.get("X-Request-ID", generated)
    return state.request_id
def _build_error_response(
    status_code: int,
    message: str,
    *,
    code: str = "INTERNAL_ERROR",
    request_id: str = "",
    extra_headers: dict[str, str] | None = None,
) -> JSONResponse:
    """Create the standard JSON error response.

    The body carries the structured ``error`` object plus a top-level
    ``detail`` string duplicating ``message`` for legacy consumers. The
    response always has an ``X-Request-ID`` header; entries in
    ``extra_headers`` are merged on top of it.
    """
    error_payload = {
        "code": code,
        "message": message,
        "status": status_code,
        "request_id": request_id,
        "timestamp": datetime.now(UTC).isoformat(),
        "report_url": ISSUES_URL,
    }
    response_headers = {"X-Request-ID": request_id}
    response_headers.update(extra_headers or {})
    return JSONResponse(
        status_code=status_code,
        content={"error": error_payload, "detail": message},
        headers=response_headers,
    )
# ---------------------------------------------------------------------------
# Global exception handlers
# ---------------------------------------------------------------------------
def _surfsense_error_handler(request: Request, exc: SurfSenseError) -> JSONResponse:
    """Translate a ``SurfSenseError`` into the standard error envelope.

    Errors with a 5xx status are logged with their traceback. The exception's
    own message is returned to the client only when ``exc.safe_for_client`` is
    true; otherwise the generic 5xx message is used.
    """
    rid = _get_request_id(request)
    if exc.status_code >= 500:
        # Pass the exception object itself: a sync handler may be executed
        # outside the ``except`` block that caught ``exc`` (e.g. in a worker
        # thread), where ``exc_info=True`` would find no active exception and
        # the traceback would be lost.
        _error_logger.error(
            "[%s] %s - %s: %s",
            rid,
            request.url.path,
            exc.code,
            exc,
            exc_info=exc,
        )
    message = exc.message if exc.safe_for_client else GENERIC_5XX_MESSAGE
    return _build_error_response(
        exc.status_code, message, code=exc.code, request_id=rid
    )
def _http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Wrap FastAPI/Starlette HTTPExceptions into the standard envelope.

    Non-string ``detail`` values are stringified. For 5xx statuses the original
    detail is logged and replaced with the generic message before it reaches
    the client. Headers attached to the exception (for example
    ``WWW-Authenticate``) are forwarded on the response.
    """
    rid = _get_request_id(request)
    detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
    if exc.status_code >= 500:
        _error_logger.error(
            "[%s] %s - HTTPException %d: %s",
            rid,
            request.url.path,
            exc.status_code,
            detail,
        )
        detail = GENERIC_5XX_MESSAGE
    code = _status_to_code(exc.status_code, detail)
    # HTTPException may carry response headers; dropping them would strip
    # information the raiser explicitly asked to send.
    exc_headers = getattr(exc, "headers", None)
    return _build_error_response(
        exc.status_code,
        detail,
        code=code,
        request_id=rid,
        extra_headers=dict(exc_headers) if exc_headers else None,
    )
def _validation_error_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Respond with 422 and a message listing each invalid field.

    Every entry from ``exc.errors()`` is rendered as ``<loc parts joined by
    ' -> '>: <msg>`` and the entries are joined with ``'; '``.
    """

    def _describe(err) -> str:
        location = " -> ".join(str(part) for part in err.get("loc", []))
        return f"{location}: {err.get('msg', 'invalid')}"

    rid = _get_request_id(request)
    problems = [_describe(err) for err in exc.errors()]
    if problems:
        message = f"Validation failed: {'; '.join(problems)}"
    else:
        message = "Validation failed."
    return _build_error_response(422, message, code="VALIDATION_ERROR", request_id=rid)
def _unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Catch-all: log full traceback, return sanitized 500.

    The client only ever receives the generic 5xx message; details of ``exc``
    go to the error log together with the request ID.
    """
    rid = _get_request_id(request)
    # ``exc_info=exc`` rather than ``True``: this sync handler may run outside
    # the ``except`` block that caught the error (e.g. in a worker thread), in
    # which case ``sys.exc_info()`` is empty and no traceback would be logged.
    _error_logger.error(
        "[%s] Unhandled exception on %s %s",
        rid,
        request.method,
        request.url.path,
        exc_info=exc,
    )
    return _build_error_response(
        500, GENERIC_5XX_MESSAGE, code="INTERNAL_ERROR", request_id=rid
    )
def _status_to_code(status_code: int, detail: str = "") -> str:
    """Map an HTTP status (and optional detail) to a machine-readable error code.

    A detail of exactly ``"RATE_LIMIT_EXCEEDED"`` wins regardless of status.
    Statuses without a dedicated code fall back to ``"INTERNAL_ERROR"`` for
    5xx and ``"CLIENT_ERROR"`` for everything else.
    """
    if detail == "RATE_LIMIT_EXCEEDED":
        return "RATE_LIMIT_EXCEEDED"

    fallback = "CLIENT_ERROR"
    if status_code >= 500:
        fallback = "INTERNAL_ERROR"

    known_codes = {
        400: "BAD_REQUEST",
        401: "UNAUTHORIZED",
        403: "FORBIDDEN",
        404: "NOT_FOUND",
        405: "METHOD_NOT_ALLOWED",
        409: "CONFLICT",
        422: "VALIDATION_ERROR",
        429: "RATE_LIMIT_EXCEEDED",
    }
    if status_code in known_codes:
        return known_codes[status_code]
    return fallback
def _retry_after_seconds(detail: object, default: str = "60") -> str:
    """Turn a limit description into a ``Retry-After`` delay in whole seconds.

    NOTE(review): assumes ``detail`` looks like ``"5 per 1 minute"`` (amount,
    window size, window unit) — confirm against the rate-limit library in use.
    Anything that does not match yields ``default``.
    """
    import re

    if not isinstance(detail, str):
        return default
    unit_seconds = {
        "second": 1,
        "minute": 60,
        "hour": 60 * 60,
        "day": 60 * 60 * 24,
        "month": 60 * 60 * 24 * 30,
        "year": 60 * 60 * 24 * 30 * 12,
    }
    match = re.search(
        r"per\s+(\d+)\s+(second|minute|hour|day|month|year)", detail, re.IGNORECASE
    )
    if match is None:
        return default
    return str(int(match.group(1)) * unit_seconds[match.group(2).lower()])


def _rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
    """Custom 429 handler that returns JSON matching our error envelope.

    ``Retry-After`` must be an HTTP-date or a number of seconds, so the limit
    window from ``exc.detail`` is converted to seconds instead of being sent
    as free text such as ``"1 minute"``.
    """
    rid = _get_request_id(request)
    retry_after = _retry_after_seconds(exc.detail)
    return _build_error_response(
        429,
        "Too many requests. Please slow down and try again.",
        code="RATE_LIMIT_EXCEEDED",
        request_id=rid,
        extra_headers={"Retry-After": retry_after},
    )
@ -258,6 +388,33 @@ app = FastAPI(lifespan=lifespan)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Register structured global exception handlers. The handler for a raised
# exception is resolved by walking the exception's class hierarchy, so the most
# specific registered class wins regardless of registration order.
from starlette.exceptions import HTTPException as StarletteHTTPException  # noqa: E402

app.add_exception_handler(SurfSenseError, _surfsense_error_handler)
app.add_exception_handler(RequestValidationError, _validation_error_handler)
# Register on Starlette's base HTTPException (FastAPI's HTTPException subclasses
# it) so framework-raised errors such as unknown-route 404s and 405s are wrapped
# in the same envelope as errors raised by our own route code.
app.add_exception_handler(StarletteHTTPException, _http_exception_handler)
app.add_exception_handler(Exception, _unhandled_exception_handler)
# ---------------------------------------------------------------------------
# Request-ID middleware
# ---------------------------------------------------------------------------
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Give each request a correlation ID and return it on the response.

    The ID is taken from the incoming ``X-Request-ID`` header when present,
    otherwise a ``req_<12 hex>`` value is generated. It is stored on
    ``request.state.request_id`` before the rest of the stack runs.
    """

    async def dispatch(
        self, request: StarletteRequest, call_next: RequestResponseEndpoint
    ) -> StarletteResponse:
        header_name = "X-Request-ID"
        generated = f"req_{uuid.uuid4().hex[:12]}"
        rid = request.headers.get(header_name, generated)
        request.state.request_id = rid

        downstream_response = await call_next(request)
        downstream_response.headers[header_name] = rid
        return downstream_response
app.add_middleware(RequestIDMiddleware)
# ---------------------------------------------------------------------------
# Request-level performance middleware

View file

@ -1,10 +1,37 @@
from enum import StrEnum
from pydantic import BaseModel, field_validator
class ProcessingMode(StrEnum):
BASIC = "basic"
PREMIUM = "premium"
@classmethod
def coerce(cls, value: str | None) -> "ProcessingMode":
if value is None:
return cls.BASIC
try:
return cls(value.lower())
except ValueError:
return cls.BASIC
@property
def page_multiplier(self) -> int:
return _PAGE_MULTIPLIERS[self]
_PAGE_MULTIPLIERS: dict["ProcessingMode", int] = {
ProcessingMode.BASIC: 1,
ProcessingMode.PREMIUM: 10,
}
class EtlRequest(BaseModel):
file_path: str
filename: str
estimated_pages: int = 0
processing_mode: ProcessingMode = ProcessingMode.BASIC
@field_validator("filename")
@classmethod

View file

@ -145,13 +145,17 @@ class EtlPipelineService:
and getattr(app_config, "AZURE_DI_KEY", None)
)
mode_value = request.processing_mode.value
if azure_configured and ext in AZURE_DI_DOCUMENT_EXTENSIONS:
try:
from app.etl_pipeline.parsers.azure_doc_intelligence import (
parse_with_azure_doc_intelligence,
)
return await parse_with_azure_doc_intelligence(request.file_path)
return await parse_with_azure_doc_intelligence(
request.file_path, processing_mode=mode_value
)
except Exception:
logging.warning(
"Azure Document Intelligence failed for %s, "
@ -162,4 +166,6 @@ class EtlPipelineService:
from app.etl_pipeline.parsers.llamacloud import parse_with_llamacloud
return await parse_with_llamacloud(request.file_path, request.estimated_pages)
return await parse_with_llamacloud(
request.file_path, request.estimated_pages, processing_mode=mode_value
)

View file

@ -10,7 +10,15 @@ BASE_DELAY = 10
MAX_DELAY = 120
async def parse_with_azure_doc_intelligence(file_path: str) -> str:
AZURE_MODEL_BY_MODE = {
"basic": "prebuilt-read",
"premium": "prebuilt-layout",
}
async def parse_with_azure_doc_intelligence(
file_path: str, processing_mode: str = "basic"
) -> str:
from azure.ai.documentintelligence.aio import DocumentIntelligenceClient
from azure.ai.documentintelligence.models import DocumentContentFormat
from azure.core.credentials import AzureKeyCredential
@ -21,9 +29,15 @@ async def parse_with_azure_doc_intelligence(file_path: str) -> str:
ServiceResponseError,
)
model_id = AZURE_MODEL_BY_MODE.get(processing_mode, "prebuilt-read")
file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
retryable_exceptions = (ServiceRequestError, ServiceResponseError)
logging.info(
f"Azure Document Intelligence using model={model_id} "
f"(mode={processing_mode}, file={file_size_mb:.1f}MB)"
)
last_exception = None
attempt_errors: list[str] = []
@ -36,7 +50,7 @@ async def parse_with_azure_doc_intelligence(file_path: str) -> str:
async with client:
with open(file_path, "rb") as f:
poller = await client.begin_analyze_document(
"prebuilt-layout",
model_id,
body=f,
output_content_format=DocumentContentFormat.MARKDOWN,
)

View file

@ -16,8 +16,15 @@ from app.etl_pipeline.constants import (
calculate_upload_timeout,
)
LLAMA_TIER_BY_MODE = {
"basic": "cost_effective",
"premium": "agentic_plus",
}
async def parse_with_llamacloud(file_path: str, estimated_pages: int) -> str:
async def parse_with_llamacloud(
file_path: str, estimated_pages: int, processing_mode: str = "basic"
) -> str:
from llama_cloud_services import LlamaParse
from llama_cloud_services.parse.utils import ResultType
@ -34,10 +41,12 @@ async def parse_with_llamacloud(file_path: str, estimated_pages: int) -> str:
pool=120.0,
)
tier = LLAMA_TIER_BY_MODE.get(processing_mode, "cost_effective")
logging.info(
f"LlamaCloud upload configured: file_size={file_size_mb:.1f}MB, "
f"pages={estimated_pages}, upload_timeout={upload_timeout:.0f}s, "
f"job_timeout={job_timeout:.0f}s"
f"job_timeout={job_timeout:.0f}s, tier={tier} (mode={processing_mode})"
)
last_exception = None
@ -56,6 +65,7 @@ async def parse_with_llamacloud(file_path: str, estimated_pages: int) -> str:
job_timeout_in_seconds=job_timeout,
job_timeout_extra_time_per_page_in_seconds=PER_PAGE_JOB_TIMEOUT,
custom_client=custom_client,
tier=tier,
)
result = await parser.aparse(file_path)

View file

@ -0,0 +1,104 @@
"""Structured error hierarchy for SurfSense.
Every error response follows a backward-compatible contract:
{
"error": {
"code": "SOME_ERROR_CODE",
"message": "Human-readable, client-safe message.",
"status": 422,
"request_id": "req_...",
"timestamp": "2026-04-14T12:00:00+00:00",
"report_url": "https://github.com/MODSetter/SurfSense/issues"
},
"detail": "Human-readable, client-safe message." # legacy compat
}
"""
from __future__ import annotations
ISSUES_URL = "https://github.com/MODSetter/SurfSense/issues"
GENERIC_5XX_MESSAGE = (
"An internal error occurred. Please try again or report this issue if it persists."
)
class SurfSenseError(Exception):
    """Root of the SurfSense exception hierarchy.

    Attributes:
        message: Human-readable description, also passed to ``Exception``.
        code: Machine-readable error code.
        status_code: HTTP status associated with the error.
        safe_for_client: Whether ``message`` may be shown to API clients.
    """

    def __init__(
        self,
        message: str = GENERIC_5XX_MESSAGE,
        *,
        code: str = "INTERNAL_ERROR",
        status_code: int = 500,
        safe_for_client: bool = True,
    ) -> None:
        super().__init__(message)
        self.message, self.code = message, code
        self.status_code, self.safe_for_client = status_code, safe_for_client
class ConnectorError(SurfSenseError):
    """``SurfSenseError`` preset to HTTP 502 with code ``CONNECTOR_ERROR``.

    Pass ``safe_for_client=False`` when ``message`` contains internal details
    that must not be exposed to API clients.
    """

    def __init__(
        self,
        message: str,
        *,
        code: str = "CONNECTOR_ERROR",
        safe_for_client: bool = True,
    ) -> None:
        super().__init__(
            message, code=code, status_code=502, safe_for_client=safe_for_client
        )
class DatabaseError(SurfSenseError):
    """``SurfSenseError`` preset to HTTP 500 with code ``DATABASE_ERROR``.

    Pass ``safe_for_client=False`` when ``message`` contains internal details
    (for example driver error text) that must not be exposed to API clients.
    """

    def __init__(
        self,
        message: str = "A database error occurred.",
        *,
        code: str = "DATABASE_ERROR",
        safe_for_client: bool = True,
    ) -> None:
        super().__init__(
            message, code=code, status_code=500, safe_for_client=safe_for_client
        )
class ConfigurationError(SurfSenseError):
    """``SurfSenseError`` preset to HTTP 500 with code ``CONFIGURATION_ERROR``.

    Pass ``safe_for_client=False`` when ``message`` contains internal details
    (for example setting names or values) that must not reach API clients.
    """

    def __init__(
        self,
        message: str = "A configuration error occurred.",
        *,
        code: str = "CONFIGURATION_ERROR",
        safe_for_client: bool = True,
    ) -> None:
        super().__init__(
            message, code=code, status_code=500, safe_for_client=safe_for_client
        )
class ExternalServiceError(SurfSenseError):
    """``SurfSenseError`` preset to HTTP 502 with code ``EXTERNAL_SERVICE_ERROR``.

    Pass ``safe_for_client=False`` when ``message`` contains internal details
    that must not be exposed to API clients.
    """

    def __init__(
        self,
        message: str = "An external service is unavailable.",
        *,
        code: str = "EXTERNAL_SERVICE_ERROR",
        safe_for_client: bool = True,
    ) -> None:
        super().__init__(
            message, code=code, status_code=502, safe_for_client=safe_for_client
        )
class NotFoundError(SurfSenseError):
    """``SurfSenseError`` preset to HTTP 404 with code ``NOT_FOUND``."""

    def __init__(
        self,
        message: str = "The requested resource was not found.",
        *,
        code: str = "NOT_FOUND",
    ) -> None:
        preset = {"code": code, "status_code": 404}
        super().__init__(message, **preset)
class ForbiddenError(SurfSenseError):
    """``SurfSenseError`` preset to HTTP 403 with code ``FORBIDDEN``."""

    def __init__(
        self,
        message: str = "You don't have permission to access this resource.",
        *,
        code: str = "FORBIDDEN",
    ) -> None:
        preset = {"code": code, "status_code": 403}
        super().__init__(message, **preset)
class ValidationError(SurfSenseError):
    """``SurfSenseError`` preset to HTTP 422 with code ``VALIDATION_ERROR``."""

    def __init__(
        self,
        message: str = "Validation failed.",
        *,
        code: str = "VALIDATION_ERROR",
    ) -> None:
        preset = {"code": code, "status_code": 422}
        super().__init__(message, **preset)

View file

@ -124,6 +124,7 @@ async def create_documents_file_upload(
search_space_id: int = Form(...),
should_summarize: bool = Form(False),
use_vision_llm: bool = Form(False),
processing_mode: str = Form("basic"),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
dispatcher: TaskDispatcher = Depends(get_task_dispatcher),
@ -142,12 +143,15 @@ async def create_documents_file_upload(
from datetime import datetime
from app.db import DocumentStatus
from app.etl_pipeline.etl_document import ProcessingMode
from app.tasks.document_processors.base import (
check_document_by_unique_identifier,
get_current_timestamp,
)
from app.utils.document_converters import generate_unique_identifier_hash
validated_mode = ProcessingMode.coerce(processing_mode)
try:
await check_permission(
session,
@ -274,6 +278,7 @@ async def create_documents_file_upload(
user_id=str(user.id),
should_summarize=should_summarize,
use_vision_llm=use_vision_llm,
processing_mode=validated_mode.value,
)
return {
@ -1493,6 +1498,7 @@ async def folder_upload(
root_folder_id: int | None = Form(None),
enable_summary: bool = Form(False),
use_vision_llm: bool = Form(False),
processing_mode: str = Form("basic"),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
@ -1504,6 +1510,10 @@ async def folder_upload(
import json
import tempfile
from app.etl_pipeline.etl_document import ProcessingMode
validated_mode = ProcessingMode.coerce(processing_mode)
await check_permission(
session,
user,
@ -1558,6 +1568,7 @@ async def folder_upload(
watched_metadata = {
"watched": True,
"folder_path": folder_name,
"processing_mode": validated_mode.value,
}
existing_root = (
await session.execute(
@ -1621,6 +1632,7 @@ async def folder_upload(
enable_summary=enable_summary,
use_vision_llm=use_vision_llm,
file_mappings=list(file_mappings),
processing_mode=validated_mode.value,
)
return {

View file

@ -20,6 +20,7 @@ class TaskDispatcher(Protocol):
user_id: str,
should_summarize: bool = False,
use_vision_llm: bool = False,
processing_mode: str = "basic",
) -> None: ...
@ -36,6 +37,7 @@ class CeleryTaskDispatcher:
user_id: str,
should_summarize: bool = False,
use_vision_llm: bool = False,
processing_mode: str = "basic",
) -> None:
from app.tasks.celery_tasks.document_tasks import (
process_file_upload_with_document_task,
@ -49,6 +51,7 @@ class CeleryTaskDispatcher:
user_id=user_id,
should_summarize=should_summarize,
use_vision_llm=use_vision_llm,
processing_mode=processing_mode,
)

View file

@ -780,6 +780,7 @@ def process_file_upload_with_document_task(
user_id: str,
should_summarize: bool = False,
use_vision_llm: bool = False,
processing_mode: str = "basic",
):
"""
Celery task to process uploaded file with existing pending document.
@ -836,6 +837,7 @@ def process_file_upload_with_document_task(
user_id,
should_summarize=should_summarize,
use_vision_llm=use_vision_llm,
processing_mode=processing_mode,
)
)
logger.info(
@ -873,6 +875,7 @@ async def _process_file_with_document(
user_id: str,
should_summarize: bool = False,
use_vision_llm: bool = False,
processing_mode: str = "basic",
):
"""
Process file and update existing pending document status.
@ -976,6 +979,7 @@ async def _process_file_with_document(
notification=notification,
should_summarize=should_summarize,
use_vision_llm=use_vision_llm,
processing_mode=processing_mode,
)
# Update notification on success
@ -1434,6 +1438,7 @@ def index_uploaded_folder_files_task(
enable_summary: bool,
file_mappings: list[dict],
use_vision_llm: bool = False,
processing_mode: str = "basic",
):
"""Celery task to index files uploaded from the desktop app."""
loop = asyncio.new_event_loop()
@ -1448,6 +1453,7 @@ def index_uploaded_folder_files_task(
enable_summary=enable_summary,
file_mappings=file_mappings,
use_vision_llm=use_vision_llm,
processing_mode=processing_mode,
)
)
finally:
@ -1462,6 +1468,7 @@ async def _index_uploaded_folder_files_async(
enable_summary: bool,
file_mappings: list[dict],
use_vision_llm: bool = False,
processing_mode: str = "basic",
):
"""Run upload-based folder indexing with notification + heartbeat."""
file_count = len(file_mappings)
@ -1512,6 +1519,7 @@ async def _index_uploaded_folder_files_async(
file_mappings=file_mappings,
on_heartbeat_callback=_heartbeat_progress,
use_vision_llm=use_vision_llm,
processing_mode=processing_mode,
)
if notification:

View file

@ -60,14 +60,16 @@ async def _check_page_limit_or_skip(
page_limit_service: PageLimitService,
user_id: str,
file_path: str,
) -> int:
page_multiplier: int = 1,
) -> tuple[int, int]:
"""Estimate pages and check the limit; raises PageLimitExceededError if over quota.
Returns the estimated page count on success.
Returns (estimated_pages, billable_pages).
"""
estimated = _estimate_pages_safe(page_limit_service, file_path)
await page_limit_service.check_page_limit(user_id, estimated)
return estimated
billable = estimated * page_multiplier
await page_limit_service.check_page_limit(user_id, billable)
return estimated, billable
def _compute_final_pages(
@ -153,17 +155,20 @@ def scan_folder(
return files
async def _read_file_content(file_path: str, filename: str, *, vision_llm=None) -> str:
async def _read_file_content(
file_path: str, filename: str, *, vision_llm=None, processing_mode: str = "basic"
) -> str:
"""Read file content via the unified ETL pipeline.
All file types (plaintext, audio, direct-convert, document, image) are
handled by ``EtlPipelineService``.
"""
from app.etl_pipeline.etl_document import EtlRequest
from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
mode = ProcessingMode.coerce(processing_mode)
result = await EtlPipelineService(vision_llm=vision_llm).extract(
EtlRequest(file_path=file_path, filename=filename)
EtlRequest(file_path=file_path, filename=filename, processing_mode=mode)
)
return result.markdown_content
@ -201,12 +206,15 @@ async def _compute_file_content_hash(
search_space_id: int,
*,
vision_llm=None,
processing_mode: str = "basic",
) -> tuple[str, str]:
"""Read a file (via ETL if needed) and compute its content hash.
Returns (content_text, content_hash).
"""
content = await _read_file_content(file_path, filename, vision_llm=vision_llm)
content = await _read_file_content(
file_path, filename, vision_llm=vision_llm, processing_mode=processing_mode
)
return content, _content_hash(content, search_space_id)
@ -694,7 +702,7 @@ async def index_local_folder(
continue
try:
estimated_pages = await _check_page_limit_or_skip(
estimated_pages, _billable = await _check_page_limit_or_skip(
page_limit_service, user_id, file_path_abs
)
except PageLimitExceededError:
@ -730,7 +738,7 @@ async def index_local_folder(
await create_version_snapshot(session, existing_document)
else:
try:
estimated_pages = await _check_page_limit_or_skip(
estimated_pages, _billable = await _check_page_limit_or_skip(
page_limit_service, user_id, file_path_abs
)
except PageLimitExceededError:
@ -1080,7 +1088,7 @@ async def _index_single_file(
page_limit_service = PageLimitService(session)
try:
estimated_pages = await _check_page_limit_or_skip(
estimated_pages, _billable = await _check_page_limit_or_skip(
page_limit_service, user_id, str(full_path)
)
except PageLimitExceededError as e:
@ -1271,6 +1279,7 @@ async def index_uploaded_files(
file_mappings: list[dict],
on_heartbeat_callback: HeartbeatCallbackType | None = None,
use_vision_llm: bool = False,
processing_mode: str = "basic",
) -> tuple[int, int, str | None]:
"""Index files uploaded from the desktop app via temp paths.
@ -1281,12 +1290,16 @@ async def index_uploaded_files(
Returns ``(indexed_count, failed_count, error_summary_or_none)``.
"""
from app.etl_pipeline.etl_document import ProcessingMode
mode = ProcessingMode.coerce(processing_mode)
task_logger = TaskLoggingService(session, search_space_id)
log_entry = await task_logger.log_task_start(
task_name="local_folder_indexing",
source="uploaded_folder_indexing",
message=f"Indexing {len(file_mappings)} uploaded file(s) for {folder_name}",
metadata={"file_count": len(file_mappings)},
metadata={"file_count": len(file_mappings), "processing_mode": mode.value},
)
try:
@ -1350,8 +1363,11 @@ async def index_uploaded_files(
continue
try:
estimated_pages = await _check_page_limit_or_skip(
page_limit_service, user_id, temp_path
estimated_pages, _billable_pages = await _check_page_limit_or_skip(
page_limit_service,
user_id,
temp_path,
page_multiplier=mode.page_multiplier,
)
except PageLimitExceededError:
logger.warning(f"Page limit exceeded, skipping: {relative_path}")
@ -1364,6 +1380,7 @@ async def index_uploaded_files(
filename,
search_space_id,
vision_llm=vision_llm_instance,
processing_mode=mode.value,
)
except Exception as e:
logger.warning(f"Could not read {relative_path}: {e}")
@ -1429,8 +1446,9 @@ async def index_uploaded_files(
final_pages = _compute_final_pages(
page_limit_service, estimated_pages, len(content)
)
final_billable = final_pages * mode.page_multiplier
await page_limit_service.update_page_usage(
user_id, final_pages, allow_exceed=True
user_id, final_billable, allow_exceed=True
)
else:
failed_count += 1

View file

@ -47,6 +47,7 @@ class _ProcessingContext:
connector: dict | None = None
notification: Notification | None = None
use_vision_llm: bool = False
processing_mode: str = "basic"
enable_summary: bool = field(init=False)
def __post_init__(self) -> None:
@ -187,21 +188,28 @@ async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | No
async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
"""Route a document file to the configured ETL service via the unified pipeline."""
from app.etl_pipeline.etl_document import EtlRequest
from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
from app.services.page_limit_service import PageLimitExceededError, PageLimitService
mode = ProcessingMode.coerce(ctx.processing_mode)
page_limit_service = PageLimitService(ctx.session)
estimated_pages = _estimate_pages_safe(page_limit_service, ctx.file_path)
billable_pages = estimated_pages * mode.page_multiplier
await ctx.task_logger.log_task_progress(
ctx.log_entry,
f"Estimated {estimated_pages} pages for file: {ctx.filename}",
{"estimated_pages": estimated_pages, "file_type": "document"},
{
"estimated_pages": estimated_pages,
"billable_pages": billable_pages,
"processing_mode": mode.value,
"file_type": "document",
},
)
try:
await page_limit_service.check_page_limit(ctx.user_id, estimated_pages)
await page_limit_service.check_page_limit(ctx.user_id, billable_pages)
except PageLimitExceededError as e:
await ctx.task_logger.log_task_failure(
ctx.log_entry,
@ -212,6 +220,8 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
"pages_used": e.pages_used,
"pages_limit": e.pages_limit,
"estimated_pages": estimated_pages,
"billable_pages": billable_pages,
"processing_mode": mode.value,
},
)
with contextlib.suppress(Exception):
@ -225,6 +235,7 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
file_path=ctx.file_path,
filename=ctx.filename,
estimated_pages=estimated_pages,
processing_mode=mode,
)
)
@ -246,7 +257,7 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
if result:
await page_limit_service.update_page_usage(
ctx.user_id, estimated_pages, allow_exceed=True
ctx.user_id, billable_pages, allow_exceed=True
)
if ctx.connector:
await update_document_from_connector(result, ctx.connector, ctx.session)
@ -259,6 +270,8 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
"file_type": "document",
"etl_service": etl_result.etl_service,
"pages_processed": estimated_pages,
"billable_pages": billable_pages,
"processing_mode": mode.value,
},
)
else:
@ -290,6 +303,7 @@ async def process_file_in_background(
connector: dict | None = None,
notification: Notification | None = None,
use_vision_llm: bool = False,
processing_mode: str = "basic",
) -> Document | None:
ctx = _ProcessingContext(
session=session,
@ -302,6 +316,7 @@ async def process_file_in_background(
connector=connector,
notification=notification,
use_vision_llm=use_vision_llm,
processing_mode=processing_mode,
)
try:
@ -353,22 +368,25 @@ async def _extract_file_content(
log_entry: Log,
notification: Notification | None,
use_vision_llm: bool = False,
) -> tuple[str, str]:
processing_mode: str = "basic",
) -> tuple[str, str, int]:
"""
Extract markdown content from a file regardless of type.
Returns:
Tuple of (markdown_content, etl_service_name).
Tuple of (markdown_content, etl_service_name, billable_pages).
"""
from app.etl_pipeline.etl_document import EtlRequest
from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
from app.etl_pipeline.file_classifier import (
FileCategory,
classify_file as etl_classify,
)
mode = ProcessingMode.coerce(processing_mode)
category = etl_classify(filename)
estimated_pages = 0
billable_pages = 0
if notification:
stage_messages = {
@ -397,7 +415,8 @@ async def _extract_file_content(
page_limit_service = PageLimitService(session)
estimated_pages = _estimate_pages_safe(page_limit_service, file_path)
await page_limit_service.check_page_limit(user_id, estimated_pages)
billable_pages = estimated_pages * mode.page_multiplier
await page_limit_service.check_page_limit(user_id, billable_pages)
vision_llm = None
if use_vision_llm and category == FileCategory.IMAGE:
@ -410,21 +429,17 @@ async def _extract_file_content(
file_path=file_path,
filename=filename,
estimated_pages=estimated_pages,
processing_mode=mode,
)
)
if category == FileCategory.DOCUMENT:
await page_limit_service.update_page_usage(
user_id, estimated_pages, allow_exceed=True
)
with contextlib.suppress(Exception):
os.unlink(file_path)
if not result.markdown_content:
raise RuntimeError(f"Failed to extract content from file: {filename}")
return result.markdown_content, result.etl_service
return result.markdown_content, result.etl_service, billable_pages
async def process_file_in_background_with_document(
@ -440,12 +455,16 @@ async def process_file_in_background_with_document(
notification: Notification | None = None,
should_summarize: bool = False,
use_vision_llm: bool = False,
processing_mode: str = "basic",
) -> Document | None:
"""
Process file and update existing pending document (2-phase pattern).
Phase 1 (API layer): Created document with pending status.
Phase 2 (this function): Process file and update document to ready/failed.
Page usage is deferred until after dedup check and successful indexing
to avoid charging for duplicate or failed uploads.
"""
from app.indexing_pipeline.adapters.file_upload_adapter import (
UploadDocumentAdapter,
@ -458,8 +477,7 @@ async def process_file_in_background_with_document(
doc_id = document.id
try:
# Step 1: extract content
markdown_content, etl_service = await _extract_file_content(
markdown_content, etl_service, billable_pages = await _extract_file_content(
file_path,
filename,
search_space_id,
@ -469,12 +487,12 @@ async def process_file_in_background_with_document(
log_entry,
notification,
use_vision_llm=use_vision_llm,
processing_mode=processing_mode,
)
if not markdown_content:
raise RuntimeError(f"Failed to extract content from file: {filename}")
# Step 2: duplicate check
content_hash = generate_content_hash(markdown_content, search_space_id)
existing_by_content = await check_duplicate_document(session, content_hash)
if existing_by_content and existing_by_content.id != doc_id:
@ -484,7 +502,6 @@ async def process_file_in_background_with_document(
)
return None
# Step 3: index via pipeline
if notification:
await NotificationService.document_processing.notify_processing_progress(
session,
@ -505,6 +522,14 @@ async def process_file_in_background_with_document(
should_summarize=should_summarize,
)
if billable_pages > 0:
from app.services.page_limit_service import PageLimitService
page_limit_service = PageLimitService(session)
await page_limit_service.update_page_usage(
user_id, billable_pages, allow_exceed=True
)
await task_logger.log_task_success(
log_entry,
f"Successfully processed file: {filename}",
@ -512,6 +537,8 @@ async def process_file_in_background_with_document(
"document_id": doc_id,
"content_hash": content_hash,
"file_type": etl_service,
"billable_pages": billable_pages,
"processing_mode": processing_mode,
},
)
return document

View file

@ -70,6 +70,7 @@ class InlineTaskDispatcher:
user_id: str,
should_summarize: bool = False,
use_vision_llm: bool = False,
processing_mode: str = "basic",
) -> None:
from app.tasks.celery_tasks.document_tasks import (
_process_file_with_document,
@ -84,6 +85,7 @@ class InlineTaskDispatcher:
user_id,
should_summarize=should_summarize,
use_vision_llm=use_vision_llm,
processing_mode=processing_mode,
)
@ -321,7 +323,9 @@ def _mock_etl_parsing(monkeypatch):
# -- LlamaParse mock (external API) --------------------------------
async def _fake_llamacloud_parse(file_path: str, estimated_pages: int) -> str:
async def _fake_llamacloud_parse(
file_path: str, estimated_pages: int, processing_mode: str = "basic"
) -> str:
_reject_empty(file_path)
return _MOCK_ETL_MARKDOWN

View file

@ -739,3 +739,187 @@ async def test_extract_image_falls_back_to_document_without_vision_llm(
assert result.markdown_content == "# OCR text from image"
assert result.etl_service == "DOCLING"
assert result.content_type == "document"
# ---------------------------------------------------------------------------
# Processing Mode enum tests
# ---------------------------------------------------------------------------
def test_processing_mode_coerce_basic():
    """``ProcessingMode.coerce`` yields BASIC for every input listed below.

    The inputs cover the lower- and upper-case spelling of "basic", ``None``
    and an unrecognised string.
    """
    from app.etl_pipeline.etl_document import ProcessingMode

    for raw_value in ("basic", "BASIC", None, "invalid"):
        assert ProcessingMode.coerce(raw_value) == ProcessingMode.BASIC
def test_processing_mode_coerce_premium():
    """``ProcessingMode.coerce`` yields PREMIUM for "premium" in either case."""
    from app.etl_pipeline.etl_document import ProcessingMode

    for raw_value in ("premium", "PREMIUM"):
        assert ProcessingMode.coerce(raw_value) == ProcessingMode.PREMIUM
def test_processing_mode_page_multiplier():
    """BASIC has a page multiplier of 1 and PREMIUM a multiplier of 10."""
    from app.etl_pipeline.etl_document import ProcessingMode

    basic_multiplier = ProcessingMode.BASIC.page_multiplier
    assert basic_multiplier == 1

    premium_multiplier = ProcessingMode.PREMIUM.page_multiplier
    assert premium_multiplier == 10
def test_etl_request_default_processing_mode():
    """An ``EtlRequest`` built without a processing mode reports BASIC."""
    from app.etl_pipeline.etl_document import ProcessingMode

    request_fields = {"file_path": "/tmp/test.pdf", "filename": "test.pdf"}
    default_request = EtlRequest(**request_fields)

    assert default_request.processing_mode == ProcessingMode.BASIC
def test_etl_request_premium_processing_mode():
    """An ``EtlRequest`` built with PREMIUM keeps that processing mode."""
    from app.etl_pipeline.etl_document import ProcessingMode

    requested_mode = ProcessingMode.PREMIUM
    premium_request = EtlRequest(
        file_path="/tmp/test.pdf",
        filename="test.pdf",
        processing_mode=requested_mode,
    )

    assert premium_request.processing_mode == ProcessingMode.PREMIUM
# ---------------------------------------------------------------------------
# Azure DI model selection by processing mode
# ---------------------------------------------------------------------------
async def test_azure_di_basic_uses_prebuilt_read(tmp_path, mocker):
    """Basic mode should use prebuilt-read model for Azure DI."""
    source_pdf = tmp_path / "report.pdf"
    source_pdf.write_bytes(b"%PDF-1.4 fake content " * 10)

    mocker.patch("app.config.config.ETL_SERVICE", "LLAMACLOUD")
    # These settings are patched with create=True, so they are added to the
    # config object even when it does not define them.
    created_settings = (
        ("app.config.config.LLAMA_CLOUD_API_KEY", "fake-key"),
        (
            "app.config.config.AZURE_DI_ENDPOINT",
            "https://fake.cognitiveservices.azure.com/",
        ),
        ("app.config.config.AZURE_DI_KEY", "fake-key"),
    )
    for target, value in created_settings:
        mocker.patch(target, value, create=True)

    azure_client = _mock_azure_di(mocker, "# Azure basic")
    _mock_llamacloud(mocker)

    from app.etl_pipeline.etl_document import ProcessingMode

    service = EtlPipelineService()
    request = EtlRequest(
        file_path=str(source_pdf),
        filename="report.pdf",
        processing_mode=ProcessingMode.BASIC,
    )
    result = await service.extract(request)

    assert result.markdown_content == "# Azure basic"

    # The first positional argument of begin_analyze_document is the model id.
    analyze_call = azure_client.begin_analyze_document.call_args
    model_id = analyze_call[0][0]
    assert model_id == "prebuilt-read"
async def test_azure_di_premium_uses_prebuilt_layout(tmp_path, mocker):
    """Premium mode should use prebuilt-layout model for Azure DI."""
    source_pdf = tmp_path / "report.pdf"
    source_pdf.write_bytes(b"%PDF-1.4 fake content " * 10)

    mocker.patch("app.config.config.ETL_SERVICE", "LLAMACLOUD")
    # These settings are patched with create=True, so they are added to the
    # config object even when it does not define them.
    created_settings = (
        ("app.config.config.LLAMA_CLOUD_API_KEY", "fake-key"),
        (
            "app.config.config.AZURE_DI_ENDPOINT",
            "https://fake.cognitiveservices.azure.com/",
        ),
        ("app.config.config.AZURE_DI_KEY", "fake-key"),
    )
    for target, value in created_settings:
        mocker.patch(target, value, create=True)

    azure_client = _mock_azure_di(mocker, "# Azure premium")
    _mock_llamacloud(mocker)

    from app.etl_pipeline.etl_document import ProcessingMode

    service = EtlPipelineService()
    request = EtlRequest(
        file_path=str(source_pdf),
        filename="report.pdf",
        processing_mode=ProcessingMode.PREMIUM,
    )
    result = await service.extract(request)

    assert result.markdown_content == "# Azure premium"

    # The first positional argument of begin_analyze_document is the model id.
    analyze_call = azure_client.begin_analyze_document.call_args
    model_id = analyze_call[0][0]
    assert model_id == "prebuilt-layout"
# ---------------------------------------------------------------------------
# LlamaCloud tier selection by processing mode
# ---------------------------------------------------------------------------
async def test_llamacloud_basic_uses_cost_effective_tier(tmp_path, mocker):
    """Basic mode should use cost_effective tier for LlamaCloud."""
    source_pdf = tmp_path / "report.pdf"
    source_pdf.write_bytes(b"%PDF-1.4 fake content " * 10)

    mocker.patch("app.config.config.ETL_SERVICE", "LLAMACLOUD")
    # Azure DI settings are patched to None here; presumably that makes the
    # pipeline go straight to LlamaCloud -- confirm against the pipeline code.
    created_settings = (
        ("app.config.config.LLAMA_CLOUD_API_KEY", "fake-key"),
        ("app.config.config.AZURE_DI_ENDPOINT", None),
        ("app.config.config.AZURE_DI_KEY", None),
    )
    for target, value in created_settings:
        mocker.patch(target, value, create=True)

    llama_parser = _mock_llamacloud(mocker, "# Llama basic")
    # Patch the LlamaParse constructor so its keyword arguments can be inspected.
    parser_factory = mocker.patch(
        "llama_cloud_services.LlamaParse", return_value=llama_parser
    )

    from app.etl_pipeline.etl_document import ProcessingMode

    service = EtlPipelineService()
    request = EtlRequest(
        file_path=str(source_pdf),
        filename="report.pdf",
        estimated_pages=5,
        processing_mode=ProcessingMode.BASIC,
    )
    result = await service.extract(request)

    assert result.markdown_content == "# Llama basic"

    constructor_kwargs = parser_factory.call_args[1]
    assert constructor_kwargs["tier"] == "cost_effective"
async def test_llamacloud_premium_uses_agentic_plus_tier(tmp_path, mocker):
    """Premium mode should use agentic_plus tier for LlamaCloud."""
    source_pdf = tmp_path / "report.pdf"
    source_pdf.write_bytes(b"%PDF-1.4 fake content " * 10)

    mocker.patch("app.config.config.ETL_SERVICE", "LLAMACLOUD")
    # Azure DI settings are patched to None here; presumably that makes the
    # pipeline go straight to LlamaCloud -- confirm against the pipeline code.
    created_settings = (
        ("app.config.config.LLAMA_CLOUD_API_KEY", "fake-key"),
        ("app.config.config.AZURE_DI_ENDPOINT", None),
        ("app.config.config.AZURE_DI_KEY", None),
    )
    for target, value in created_settings:
        mocker.patch(target, value, create=True)

    llama_parser = _mock_llamacloud(mocker, "# Llama premium")
    # Patch the LlamaParse constructor so its keyword arguments can be inspected.
    parser_factory = mocker.patch(
        "llama_cloud_services.LlamaParse", return_value=llama_parser
    )

    from app.etl_pipeline.etl_document import ProcessingMode

    service = EtlPipelineService()
    request = EtlRequest(
        file_path=str(source_pdf),
        filename="report.pdf",
        estimated_pages=5,
        processing_mode=ProcessingMode.PREMIUM,
    )
    result = await service.extract(request)

    assert result.markdown_content == "# Llama premium"

    constructor_kwargs = parser_factory.call_args[1]
    assert constructor_kwargs["tier"] == "agentic_plus"

View file

@ -0,0 +1,286 @@
"""Unit tests for the structured error response contract.
Validates that:
- Global exception handlers produce the backward-compatible error envelope.
- 5xx responses never leak raw internal exception text.
- X-Request-ID is propagated correctly.
- SurfSenseError, HTTPException, validation, and unhandled exceptions all
use the same response shape.
"""
from __future__ import annotations
import json
import pytest
from fastapi import HTTPException
from starlette.testclient import TestClient
from app.exceptions import (
GENERIC_5XX_MESSAGE,
ISSUES_URL,
ConfigurationError,
ConnectorError,
DatabaseError,
ExternalServiceError,
ForbiddenError,
NotFoundError,
SurfSenseError,
ValidationError,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Helpers - lightweight FastAPI app that re-uses the real global handlers
# ---------------------------------------------------------------------------
def _make_test_app():
    """Assemble a small FastAPI app wired with the real global error handlers.

    The app installs ``RequestIDMiddleware`` and the four exception handlers
    imported from ``app.app``, then exposes one route per error scenario the
    tests exercise.

    Returns:
        The configured ``FastAPI`` instance.
    """
    from fastapi import FastAPI
    from fastapi.exceptions import RequestValidationError
    from pydantic import BaseModel

    from app.app import (
        RequestIDMiddleware,
        _http_exception_handler,
        _surfsense_error_handler,
        _unhandled_exception_handler,
        _validation_error_handler,
    )

    application = FastAPI()
    application.add_middleware(RequestIDMiddleware)

    # Registered in the same order as the handlers are listed here.
    handler_table = (
        (SurfSenseError, _surfsense_error_handler),
        (RequestValidationError, _validation_error_handler),
        (HTTPException, _http_exception_handler),
        (Exception, _unhandled_exception_handler),
    )
    for exception_type, handler in handler_table:
        application.add_exception_handler(exception_type, handler)

    # NOTE: route handlers and the request model deliberately carry no
    # docstrings, so the generated OpenAPI descriptions stay untouched.

    # --- success path ---------------------------------------------------
    @application.get("/ok")
    async def ok():
        return {"status": "ok"}

    # --- plain HTTPException routes -------------------------------------
    @application.get("/http-400")
    async def raise_http_400():
        raise HTTPException(status_code=400, detail="Bad input")

    @application.get("/http-500")
    async def raise_http_500():
        raise HTTPException(status_code=500, detail="secret db password leaked")

    # --- SurfSenseError subclasses --------------------------------------
    @application.get("/surfsense-connector")
    async def raise_connector():
        raise ConnectorError("GitHub API returned 401")

    @application.get("/surfsense-notfound")
    async def raise_notfound():
        raise NotFoundError("Document #42 was not found")

    @application.get("/surfsense-forbidden")
    async def raise_forbidden():
        raise ForbiddenError()

    @application.get("/surfsense-config")
    async def raise_config():
        raise ConfigurationError()

    @application.get("/surfsense-db")
    async def raise_db():
        raise DatabaseError()

    @application.get("/surfsense-external")
    async def raise_external():
        raise ExternalServiceError()

    @application.get("/surfsense-validation")
    async def raise_validation():
        raise ValidationError("Email is invalid")

    # --- exception with no dedicated handler ----------------------------
    @application.get("/unhandled")
    async def raise_unhandled():
        raise RuntimeError("should never reach the client")

    # --- request-body validation ----------------------------------------
    class Item(BaseModel):
        name: str
        count: int

    @application.post("/validated")
    async def validated(item: Item):
        return item.model_dump()

    return application
@pytest.fixture(scope="module")
def client():
    """Provide a module-scoped ``TestClient`` bound to the minimal test app.

    The client is created with ``raise_server_exceptions=False``.
    """
    return TestClient(_make_test_app(), raise_server_exceptions=False)
# ---------------------------------------------------------------------------
# Envelope shape validation
# ---------------------------------------------------------------------------
def _assert_envelope(resp, expected_status: int):
    """Check that a response carries the standard error envelope.

    Args:
        resp: Response object exposing ``status_code`` and ``json()``.
        expected_status: Status code expected both on the response and in the
            envelope's ``status`` field.

    Returns:
        The decoded JSON body, so callers can make further assertions.
    """
    assert resp.status_code == expected_status

    body = resp.json()
    assert "error" in body, f"Missing 'error' key: {body}"
    assert "detail" in body, f"Missing legacy 'detail' key: {body}"

    envelope = body["error"]

    # "code" and "message" must both be non-empty strings.
    for text_field in ("code", "message"):
        text_value = envelope[text_field]
        assert isinstance(text_value, str) and len(text_value) > 0

    assert envelope["status"] == expected_status

    request_id = envelope["request_id"]
    assert isinstance(request_id, str) and len(request_id) > 0

    assert "timestamp" in envelope
    assert envelope["report_url"] == ISSUES_URL

    # Legacy compat: detail mirrors message
    assert body["detail"] == envelope["message"]
    return body
# ---------------------------------------------------------------------------
# X-Request-ID propagation
# ---------------------------------------------------------------------------
class TestRequestID:
    """Checks on the ``X-Request-ID`` header and its echo in error bodies."""

    def test_generated_when_missing(self, client):
        """A request without the header gets one that starts with ``req_``."""
        response_headers = client.get("/ok").headers
        assert "X-Request-ID" in response_headers
        assert response_headers["X-Request-ID"].startswith("req_")

    def test_echoed_when_provided(self, client):
        """A caller-supplied request id is returned unchanged."""
        supplied = {"X-Request-ID": "my-trace-123"}
        response = client.get("/ok", headers=supplied)
        assert response.headers["X-Request-ID"] == "my-trace-123"

    def test_present_in_error_response_body(self, client):
        """The supplied request id also appears inside the error envelope."""
        supplied = {"X-Request-ID": "trace-abc"}
        response = client.get("/http-400", headers=supplied)
        envelope = _assert_envelope(response, 400)["error"]
        assert envelope["request_id"] == "trace-abc"
# ---------------------------------------------------------------------------
# HTTPException handling
# ---------------------------------------------------------------------------
class TestHTTPExceptionHandler:
    """Checks on how plain ``HTTPException`` instances are rendered."""

    def test_400_preserves_detail(self, client):
        """A 400 keeps its detail text and is coded ``BAD_REQUEST``."""
        envelope = _assert_envelope(client.get("/http-400"), 400)["error"]
        assert envelope["message"] == "Bad input"
        assert envelope["code"] == "BAD_REQUEST"

    def test_500_sanitizes_detail(self, client):
        """A 500 replaces its detail with the generic 5xx message."""
        envelope = _assert_envelope(client.get("/http-500"), 500)["error"]
        message = envelope["message"]
        assert "secret" not in message
        assert "password" not in message
        assert message == GENERIC_5XX_MESSAGE
        assert envelope["code"] == "INTERNAL_ERROR"
# ---------------------------------------------------------------------------
# SurfSenseError hierarchy
# ---------------------------------------------------------------------------
class TestSurfSenseErrorHandler:
    """Checks the status and error code produced for each SurfSenseError route."""

    def test_connector_error(self, client):
        """ConnectorError gives 502 / CONNECTOR_ERROR and keeps its message."""
        envelope = _assert_envelope(client.get("/surfsense-connector"), 502)["error"]
        assert envelope["code"] == "CONNECTOR_ERROR"
        assert "GitHub" in envelope["message"]

    def test_not_found_error(self, client):
        """NotFoundError gives 404 / NOT_FOUND."""
        envelope = _assert_envelope(client.get("/surfsense-notfound"), 404)["error"]
        assert envelope["code"] == "NOT_FOUND"

    def test_forbidden_error(self, client):
        """ForbiddenError gives 403 / FORBIDDEN."""
        envelope = _assert_envelope(client.get("/surfsense-forbidden"), 403)["error"]
        assert envelope["code"] == "FORBIDDEN"

    def test_configuration_error(self, client):
        """ConfigurationError gives 500 / CONFIGURATION_ERROR."""
        envelope = _assert_envelope(client.get("/surfsense-config"), 500)["error"]
        assert envelope["code"] == "CONFIGURATION_ERROR"

    def test_database_error(self, client):
        """DatabaseError gives 500 / DATABASE_ERROR."""
        envelope = _assert_envelope(client.get("/surfsense-db"), 500)["error"]
        assert envelope["code"] == "DATABASE_ERROR"

    def test_external_service_error(self, client):
        """ExternalServiceError gives 502 / EXTERNAL_SERVICE_ERROR."""
        envelope = _assert_envelope(client.get("/surfsense-external"), 502)["error"]
        assert envelope["code"] == "EXTERNAL_SERVICE_ERROR"

    def test_validation_error_custom(self, client):
        """The project's ValidationError gives 422 / VALIDATION_ERROR."""
        envelope = _assert_envelope(client.get("/surfsense-validation"), 422)["error"]
        assert envelope["code"] == "VALIDATION_ERROR"
# ---------------------------------------------------------------------------
# Unhandled exception (catch-all)
# ---------------------------------------------------------------------------
class TestUnhandledException:
    """Checks the catch-all handler for exceptions with no dedicated handler."""

    def test_returns_500_generic_message(self, client):
        """A RuntimeError gives a generic 500 with no original text in the body."""
        body = _assert_envelope(client.get("/unhandled"), 500)
        envelope = body["error"]
        assert envelope["code"] == "INTERNAL_ERROR"
        assert envelope["message"] == GENERIC_5XX_MESSAGE

        # Search the whole serialized body, not only the message field.
        serialized_body = json.dumps(body)
        assert "should never reach" not in serialized_body
# ---------------------------------------------------------------------------
# RequestValidationError (pydantic / FastAPI)
# ---------------------------------------------------------------------------
class TestValidationErrorHandler:
    """Checks how request-body validation failures are rendered."""

    def test_missing_fields(self, client):
        """An empty JSON body gives 422 with "required" in the message."""
        response = client.post("/validated", json={})
        envelope = _assert_envelope(response, 422)["error"]
        assert envelope["code"] == "VALIDATION_ERROR"
        assert "required" in envelope["message"].lower()

    def test_wrong_type(self, client):
        """A non-numeric ``count`` gives 422 / VALIDATION_ERROR."""
        invalid_item = {"name": "test", "count": "not-a-number"}
        response = client.post("/validated", json=invalid_item)
        envelope = _assert_envelope(response, 422)["error"]
        assert envelope["code"] == "VALIDATION_ERROR"
# ---------------------------------------------------------------------------
# SurfSenseError class hierarchy unit tests
# ---------------------------------------------------------------------------
class TestSurfSenseErrorClasses:
    """Direct checks on exception class attributes, without the HTTP layer."""

    def test_base_defaults(self):
        """The base error defaults to INTERNAL_ERROR / 500 and is client-safe."""
        base_error = SurfSenseError()
        assert base_error.code == "INTERNAL_ERROR"
        assert base_error.status_code == 500
        assert base_error.safe_for_client is True

    def test_connector_error(self):
        """ConnectorError carries CONNECTOR_ERROR and status 502."""
        connector_error = ConnectorError("fail")
        assert connector_error.code == "CONNECTOR_ERROR"
        assert connector_error.status_code == 502

    def test_database_error(self):
        """DatabaseError carries status 500."""
        assert DatabaseError().status_code == 500

    def test_not_found_error(self):
        """NotFoundError carries status 404."""
        assert NotFoundError().status_code == 404

    def test_forbidden_error(self):
        """ForbiddenError carries status 403."""
        assert ForbiddenError().status_code == 403

    def test_custom_code(self):
        """An explicit ``code`` keyword overrides the class default."""
        custom_error = ConnectorError("x", code="GITHUB_TOKEN_EXPIRED")
        assert custom_error.code == "GITHUB_TOKEN_EXPIRED"