mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
feat: add processing mode support for document uploads and ETL pipeline, improved error handling UX
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
- Introduced a `ProcessingMode` enum to differentiate between basic and premium processing modes. - Updated `EtlRequest` to include a `processing_mode` field, defaulting to basic. - Enhanced ETL pipeline services to utilize the selected processing mode for Azure Document Intelligence and LlamaCloud parsing. - Modified various routes and services to handle processing mode, affecting document upload and indexing tasks. - Improved error handling and logging to include processing mode details. - Added tests to validate processing mode functionality and its impact on ETL operations.
This commit is contained in:
parent
b659f41bab
commit
656e061f84
104 changed files with 1900 additions and 909 deletions
|
|
@ -2,12 +2,15 @@ import asyncio
|
|||
import gc
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from threading import Lock
|
||||
|
||||
import redis
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from limits.storage import MemoryStorage
|
||||
|
|
@ -32,6 +35,7 @@ from app.config import (
|
|||
initialize_vision_llm_router,
|
||||
)
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
|
||||
from app.routes import router as crud_router
|
||||
from app.routes.auth_routes import router as auth_router
|
||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
|
|
@ -39,6 +43,8 @@ from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
|
|||
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot
|
||||
|
||||
_error_logger = logging.getLogger("surfsense.errors")
|
||||
|
||||
rate_limit_logger = logging.getLogger("surfsense.rate_limit")
|
||||
|
||||
|
||||
|
|
@ -61,13 +67,137 @@ limiter = Limiter(
|
|||
)
|
||||
|
||||
|
||||
def _get_request_id(request: Request) -> str:
    """Return the correlation ID for *request*.

    Lookup order: ``request.state.request_id``, then the incoming
    ``X-Request-ID`` header, then a freshly generated ``req_<12 hex chars>``
    value. Empty values are treated as missing so that the returned ID is
    always usable for correlating logs with responses.
    """
    state_id = getattr(request.state, "request_id", None)
    if state_id:
        return state_id
    return request.headers.get("X-Request-ID") or f"req_{uuid.uuid4().hex[:12]}"
|
||||
|
||||
|
||||
def _build_error_response(
    status_code: int,
    message: str,
    *,
    code: str = "INTERNAL_ERROR",
    request_id: str = "",
    extra_headers: dict[str, str] | None = None,
) -> JSONResponse:
    """Assemble the standard JSON error envelope.

    The body carries the structured ``error`` object plus a top-level
    ``detail`` string that mirrors ``message`` for legacy clients. The
    response always sets ``X-Request-ID``; ``extra_headers`` are merged on
    top of it.
    """
    error_payload = {
        "code": code,
        "message": message,
        "status": status_code,
        "request_id": request_id,
        "timestamp": datetime.now(UTC).isoformat(),
        "report_url": ISSUES_URL,
    }
    response_headers = {"X-Request-ID": request_id, **(extra_headers or {})}
    return JSONResponse(
        status_code=status_code,
        content={"error": error_payload, "detail": message},
        headers=response_headers,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Global exception handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _surfsense_error_handler(request: Request, exc: SurfSenseError) -> JSONResponse:
    """Translate a ``SurfSenseError`` into the standard error envelope.

    Errors with a 5xx status are logged together with their traceback. The
    client receives ``exc.message`` only when ``exc.safe_for_client`` is
    true; otherwise the generic 5xx message is substituted.
    """
    rid = _get_request_id(request)
    if exc.status_code >= 500:
        _error_logger.error(
            "[%s] %s - %s: %s",
            rid,
            request.url.path,
            exc.code,
            exc,
            # Pass the exception object rather than ``True``: a synchronous
            # handler can be executed in a worker thread, where there is no
            # "current" exception and ``exc_info=True`` logs no traceback.
            exc_info=exc,
        )
    message = exc.message if exc.safe_for_client else GENERIC_5XX_MESSAGE
    return _build_error_response(
        exc.status_code, message, code=exc.code, request_id=rid
    )
|
||||
|
||||
|
||||
def _http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Wrap an ``HTTPException`` in the standard error envelope.

    Non-string ``detail`` values are stringified. For 5xx statuses the
    original detail is logged and replaced by the generic 5xx message before
    it reaches the client. Headers attached to the exception (for example
    ``WWW-Authenticate``) are forwarded on the response.

    NOTE(review): this handler is registered for the ``HTTPException``
    imported from ``fastapi``; confirm that exceptions raised by Starlette
    itself (routing 404/405) are also routed here.
    """
    rid = _get_request_id(request)
    detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
    if exc.status_code >= 500:
        _error_logger.error(
            "[%s] %s - HTTPException %d: %s",
            rid,
            request.url.path,
            exc.status_code,
            detail,
        )
        # Never expose server-side failure details to the client.
        detail = GENERIC_5XX_MESSAGE
    code = _status_to_code(exc.status_code, detail)
    return _build_error_response(
        exc.status_code,
        detail,
        code=code,
        request_id=rid,
        # Preserve headers set by the raiser; previously they were dropped.
        extra_headers=getattr(exc, "headers", None),
    )
|
||||
|
||||
|
||||
def _validation_error_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Report request-validation failures as a 422 in the standard envelope.

    Each validation error becomes ``"<loc path>: <msg>"`` where the location
    parts are joined with ``" -> "``; the entries are joined with ``"; "``.
    """
    rid = _get_request_id(request)
    problems = [
        f"{' -> '.join(str(part) for part in err.get('loc', []))}: "
        f"{err.get('msg', 'invalid')}"
        for err in exc.errors()
    ]
    if problems:
        message = f"Validation failed: {'; '.join(problems)}"
    else:
        message = "Validation failed."
    return _build_error_response(422, message, code="VALIDATION_ERROR", request_id=rid)
|
||||
|
||||
|
||||
def _unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Catch-all handler: log the full traceback and return a sanitized 500.

    The client only ever sees the generic 5xx message; the exception itself
    is recorded in the error log under the request ID.
    """
    rid = _get_request_id(request)
    _error_logger.error(
        "[%s] Unhandled exception on %s %s",
        rid,
        request.method,
        request.url.path,
        # Pass the exception object rather than ``True``: a synchronous
        # handler can be executed in a worker thread, where there is no
        # "current" exception and ``exc_info=True`` logs no traceback.
        exc_info=exc,
    )
    return _build_error_response(
        500, GENERIC_5XX_MESSAGE, code="INTERNAL_ERROR", request_id=rid
    )
|
||||
|
||||
|
||||
# HTTP status -> machine-readable error code. Built once at import time
# instead of on every call.
_STATUS_CODE_NAMES: dict[int, str] = {
    400: "BAD_REQUEST",
    401: "UNAUTHORIZED",
    403: "FORBIDDEN",
    404: "NOT_FOUND",
    405: "METHOD_NOT_ALLOWED",
    409: "CONFLICT",
    422: "VALIDATION_ERROR",
    429: "RATE_LIMIT_EXCEEDED",
}


def _status_to_code(status_code: int, detail: str = "") -> str:
    """Map an HTTP status (and optionally its detail) to an error code.

    A detail of exactly ``"RATE_LIMIT_EXCEEDED"`` wins regardless of status.
    Unmapped statuses fall back to ``"INTERNAL_ERROR"`` for 5xx and
    ``"CLIENT_ERROR"`` for everything else.
    """
    if detail == "RATE_LIMIT_EXCEEDED":
        return "RATE_LIMIT_EXCEEDED"
    fallback = "INTERNAL_ERROR" if status_code >= 500 else "CLIENT_ERROR"
    return _STATUS_CODE_NAMES.get(status_code, fallback)
|
||||
|
||||
|
||||
def _rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
    """Custom 429 handler that returns JSON matching our error envelope.

    The ``Retry-After`` header is taken from the text after the last
    ``"per"`` in ``exc.detail``, defaulting to ``"60"`` when the detail is
    empty.
    """
    rid = _get_request_id(request)
    # NOTE(review): this yields the textual window from the limit description
    # rather than a number of seconds; confirm what clients expect here.
    retry_after = exc.detail.split("per")[-1].strip() if exc.detail else "60"
    return _build_error_response(
        429,
        "Too many requests. Please slow down and try again.",
        code="RATE_LIMIT_EXCEEDED",
        request_id=rid,
        extra_headers={"Retry-After": retry_after},
    )
|
||||
|
||||
|
||||
|
|
@ -258,6 +388,33 @@ app = FastAPI(lifespan=lifespan)
|
|||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
# Register structured global exception handlers (order matters: most specific first)
|
||||
app.add_exception_handler(SurfSenseError, _surfsense_error_handler)
|
||||
app.add_exception_handler(RequestValidationError, _validation_error_handler)
|
||||
app.add_exception_handler(HTTPException, _http_exception_handler)
|
||||
app.add_exception_handler(Exception, _unhandled_exception_handler)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request-ID middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Attach a unique request ID to every request and echo it in the response."""

    async def dispatch(
        self, request: StarletteRequest, call_next: RequestResponseEndpoint
    ) -> StarletteResponse:
        """Resolve the request ID, store it on ``request.state``, and echo it.

        The client-supplied ``X-Request-ID`` header is reused when present
        and non-empty; otherwise a ``req_<12 hex chars>`` ID is generated.
        """
        # ``or`` (not a ``.get`` default) so an empty header also triggers
        # generation instead of propagating a blank correlation ID.
        request_id = (
            request.headers.get("X-Request-ID") or f"req_{uuid.uuid4().hex[:12]}"
        )
        request.state.request_id = request_id
        response = await call_next(request)
        response.headers["X-Request-ID"] = request_id
        return response
|
||||
|
||||
|
||||
app.add_middleware(RequestIDMiddleware)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request-level performance middleware
|
||||
|
|
|
|||
|
|
@ -1,10 +1,37 @@
|
|||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class ProcessingMode(StrEnum):
|
||||
BASIC = "basic"
|
||||
PREMIUM = "premium"
|
||||
|
||||
@classmethod
|
||||
def coerce(cls, value: str | None) -> "ProcessingMode":
|
||||
if value is None:
|
||||
return cls.BASIC
|
||||
try:
|
||||
return cls(value.lower())
|
||||
except ValueError:
|
||||
return cls.BASIC
|
||||
|
||||
@property
|
||||
def page_multiplier(self) -> int:
|
||||
return _PAGE_MULTIPLIERS[self]
|
||||
|
||||
|
||||
_PAGE_MULTIPLIERS: dict["ProcessingMode", int] = {
|
||||
ProcessingMode.BASIC: 1,
|
||||
ProcessingMode.PREMIUM: 10,
|
||||
}
|
||||
|
||||
|
||||
class EtlRequest(BaseModel):
|
||||
file_path: str
|
||||
filename: str
|
||||
estimated_pages: int = 0
|
||||
processing_mode: ProcessingMode = ProcessingMode.BASIC
|
||||
|
||||
@field_validator("filename")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -145,13 +145,17 @@ class EtlPipelineService:
|
|||
and getattr(app_config, "AZURE_DI_KEY", None)
|
||||
)
|
||||
|
||||
mode_value = request.processing_mode.value
|
||||
|
||||
if azure_configured and ext in AZURE_DI_DOCUMENT_EXTENSIONS:
|
||||
try:
|
||||
from app.etl_pipeline.parsers.azure_doc_intelligence import (
|
||||
parse_with_azure_doc_intelligence,
|
||||
)
|
||||
|
||||
return await parse_with_azure_doc_intelligence(request.file_path)
|
||||
return await parse_with_azure_doc_intelligence(
|
||||
request.file_path, processing_mode=mode_value
|
||||
)
|
||||
except Exception:
|
||||
logging.warning(
|
||||
"Azure Document Intelligence failed for %s, "
|
||||
|
|
@ -162,4 +166,6 @@ class EtlPipelineService:
|
|||
|
||||
from app.etl_pipeline.parsers.llamacloud import parse_with_llamacloud
|
||||
|
||||
return await parse_with_llamacloud(request.file_path, request.estimated_pages)
|
||||
return await parse_with_llamacloud(
|
||||
request.file_path, request.estimated_pages, processing_mode=mode_value
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,15 @@ BASE_DELAY = 10
|
|||
MAX_DELAY = 120
|
||||
|
||||
|
||||
async def parse_with_azure_doc_intelligence(file_path: str) -> str:
|
||||
# Azure Document Intelligence model ID to use for each processing mode value.
# NOTE(review): the analyze call requests Markdown output; confirm that
# "prebuilt-read" supports that content format.
AZURE_MODEL_BY_MODE = {
    "basic": "prebuilt-read",
    "premium": "prebuilt-layout",
}
|
||||
|
||||
|
||||
async def parse_with_azure_doc_intelligence(
|
||||
file_path: str, processing_mode: str = "basic"
|
||||
) -> str:
|
||||
from azure.ai.documentintelligence.aio import DocumentIntelligenceClient
|
||||
from azure.ai.documentintelligence.models import DocumentContentFormat
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
|
|
@ -21,9 +29,15 @@ async def parse_with_azure_doc_intelligence(file_path: str) -> str:
|
|||
ServiceResponseError,
|
||||
)
|
||||
|
||||
model_id = AZURE_MODEL_BY_MODE.get(processing_mode, "prebuilt-read")
|
||||
file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
|
||||
retryable_exceptions = (ServiceRequestError, ServiceResponseError)
|
||||
|
||||
logging.info(
|
||||
f"Azure Document Intelligence using model={model_id} "
|
||||
f"(mode={processing_mode}, file={file_size_mb:.1f}MB)"
|
||||
)
|
||||
|
||||
last_exception = None
|
||||
attempt_errors: list[str] = []
|
||||
|
||||
|
|
@ -36,7 +50,7 @@ async def parse_with_azure_doc_intelligence(file_path: str) -> str:
|
|||
async with client:
|
||||
with open(file_path, "rb") as f:
|
||||
poller = await client.begin_analyze_document(
|
||||
"prebuilt-layout",
|
||||
model_id,
|
||||
body=f,
|
||||
output_content_format=DocumentContentFormat.MARKDOWN,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -16,8 +16,15 @@ from app.etl_pipeline.constants import (
|
|||
calculate_upload_timeout,
|
||||
)
|
||||
|
||||
# LlamaParse ``tier`` argument to use for each processing mode value.
LLAMA_TIER_BY_MODE = {
    "basic": "cost_effective",
    "premium": "agentic_plus",
}
|
||||
|
||||
async def parse_with_llamacloud(file_path: str, estimated_pages: int) -> str:
|
||||
|
||||
async def parse_with_llamacloud(
|
||||
file_path: str, estimated_pages: int, processing_mode: str = "basic"
|
||||
) -> str:
|
||||
from llama_cloud_services import LlamaParse
|
||||
from llama_cloud_services.parse.utils import ResultType
|
||||
|
||||
|
|
@ -34,10 +41,12 @@ async def parse_with_llamacloud(file_path: str, estimated_pages: int) -> str:
|
|||
pool=120.0,
|
||||
)
|
||||
|
||||
tier = LLAMA_TIER_BY_MODE.get(processing_mode, "cost_effective")
|
||||
|
||||
logging.info(
|
||||
f"LlamaCloud upload configured: file_size={file_size_mb:.1f}MB, "
|
||||
f"pages={estimated_pages}, upload_timeout={upload_timeout:.0f}s, "
|
||||
f"job_timeout={job_timeout:.0f}s"
|
||||
f"job_timeout={job_timeout:.0f}s, tier={tier} (mode={processing_mode})"
|
||||
)
|
||||
|
||||
last_exception = None
|
||||
|
|
@ -56,6 +65,7 @@ async def parse_with_llamacloud(file_path: str, estimated_pages: int) -> str:
|
|||
job_timeout_in_seconds=job_timeout,
|
||||
job_timeout_extra_time_per_page_in_seconds=PER_PAGE_JOB_TIMEOUT,
|
||||
custom_client=custom_client,
|
||||
tier=tier,
|
||||
)
|
||||
result = await parser.aparse(file_path)
|
||||
|
||||
|
|
|
|||
104
surfsense_backend/app/exceptions.py
Normal file
104
surfsense_backend/app/exceptions.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""Structured error hierarchy for SurfSense.
|
||||
|
||||
Every error response follows a backward-compatible contract:
|
||||
|
||||
{
|
||||
"error": {
|
||||
"code": "SOME_ERROR_CODE",
|
||||
"message": "Human-readable, client-safe message.",
|
||||
"status": 422,
|
||||
"request_id": "req_...",
|
||||
"timestamp": "2026-04-14T12:00:00Z",
|
||||
"report_url": "https://github.com/MODSetter/SurfSense/issues"
|
||||
},
|
||||
"detail": "Human-readable, client-safe message." # legacy compat
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
ISSUES_URL = "https://github.com/MODSetter/SurfSense/issues"

GENERIC_5XX_MESSAGE = (
    "An internal error occurred. Please try again or report this issue if it persists."
)


class SurfSenseError(Exception):
    """Base exception that global handlers translate into the structured envelope.

    Attributes:
        message: Human-readable description of the failure.
        code: Machine-readable error code placed in the envelope.
        status_code: HTTP status used for the response.
        safe_for_client: Whether ``message`` may be shown to the client.
    """

    def __init__(
        self,
        message: str = GENERIC_5XX_MESSAGE,
        *,
        code: str = "INTERNAL_ERROR",
        status_code: int = 500,
        safe_for_client: bool = True,
    ) -> None:
        super().__init__(message)
        self.message = message
        self.code = code
        self.status_code = status_code
        self.safe_for_client = safe_for_client


class ConnectorError(SurfSenseError):
    """A connector failed (HTTP 502)."""

    def __init__(
        self,
        message: str,
        *,
        code: str = "CONNECTOR_ERROR",
        safe_for_client: bool = True,
    ) -> None:
        super().__init__(
            message, code=code, status_code=502, safe_for_client=safe_for_client
        )


class DatabaseError(SurfSenseError):
    """A database operation failed (HTTP 500)."""

    def __init__(
        self,
        message: str = "A database error occurred.",
        *,
        code: str = "DATABASE_ERROR",
        safe_for_client: bool = True,
    ) -> None:
        super().__init__(
            message, code=code, status_code=500, safe_for_client=safe_for_client
        )


class ConfigurationError(SurfSenseError):
    """The application is misconfigured (HTTP 500)."""

    def __init__(
        self,
        message: str = "A configuration error occurred.",
        *,
        code: str = "CONFIGURATION_ERROR",
        safe_for_client: bool = True,
    ) -> None:
        super().__init__(
            message, code=code, status_code=500, safe_for_client=safe_for_client
        )


class ExternalServiceError(SurfSenseError):
    """An external service is unavailable or failed (HTTP 502)."""

    def __init__(
        self,
        message: str = "An external service is unavailable.",
        *,
        code: str = "EXTERNAL_SERVICE_ERROR",
        safe_for_client: bool = True,
    ) -> None:
        super().__init__(
            message, code=code, status_code=502, safe_for_client=safe_for_client
        )


class NotFoundError(SurfSenseError):
    """The requested resource does not exist (HTTP 404)."""

    def __init__(
        self,
        message: str = "The requested resource was not found.",
        *,
        code: str = "NOT_FOUND",
        safe_for_client: bool = True,
    ) -> None:
        super().__init__(
            message, code=code, status_code=404, safe_for_client=safe_for_client
        )


class ForbiddenError(SurfSenseError):
    """The caller lacks permission for the resource (HTTP 403)."""

    def __init__(
        self,
        message: str = "You don't have permission to access this resource.",
        *,
        code: str = "FORBIDDEN",
        safe_for_client: bool = True,
    ) -> None:
        super().__init__(
            message, code=code, status_code=403, safe_for_client=safe_for_client
        )


class ValidationError(SurfSenseError):
    """The input failed validation (HTTP 422)."""

    def __init__(
        self,
        message: str = "Validation failed.",
        *,
        code: str = "VALIDATION_ERROR",
        safe_for_client: bool = True,
    ) -> None:
        super().__init__(
            message, code=code, status_code=422, safe_for_client=safe_for_client
        )
|
||||
|
|
@ -124,6 +124,7 @@ async def create_documents_file_upload(
|
|||
search_space_id: int = Form(...),
|
||||
should_summarize: bool = Form(False),
|
||||
use_vision_llm: bool = Form(False),
|
||||
processing_mode: str = Form("basic"),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
dispatcher: TaskDispatcher = Depends(get_task_dispatcher),
|
||||
|
|
@ -142,12 +143,15 @@ async def create_documents_file_upload(
|
|||
from datetime import datetime
|
||||
|
||||
from app.db import DocumentStatus
|
||||
from app.etl_pipeline.etl_document import ProcessingMode
|
||||
from app.tasks.document_processors.base import (
|
||||
check_document_by_unique_identifier,
|
||||
get_current_timestamp,
|
||||
)
|
||||
from app.utils.document_converters import generate_unique_identifier_hash
|
||||
|
||||
validated_mode = ProcessingMode.coerce(processing_mode)
|
||||
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
|
|
@ -274,6 +278,7 @@ async def create_documents_file_upload(
|
|||
user_id=str(user.id),
|
||||
should_summarize=should_summarize,
|
||||
use_vision_llm=use_vision_llm,
|
||||
processing_mode=validated_mode.value,
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -1493,6 +1498,7 @@ async def folder_upload(
|
|||
root_folder_id: int | None = Form(None),
|
||||
enable_summary: bool = Form(False),
|
||||
use_vision_llm: bool = Form(False),
|
||||
processing_mode: str = Form("basic"),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
|
|
@ -1504,6 +1510,10 @@ async def folder_upload(
|
|||
import json
|
||||
import tempfile
|
||||
|
||||
from app.etl_pipeline.etl_document import ProcessingMode
|
||||
|
||||
validated_mode = ProcessingMode.coerce(processing_mode)
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
|
|
@ -1558,6 +1568,7 @@ async def folder_upload(
|
|||
watched_metadata = {
|
||||
"watched": True,
|
||||
"folder_path": folder_name,
|
||||
"processing_mode": validated_mode.value,
|
||||
}
|
||||
existing_root = (
|
||||
await session.execute(
|
||||
|
|
@ -1621,6 +1632,7 @@ async def folder_upload(
|
|||
enable_summary=enable_summary,
|
||||
use_vision_llm=use_vision_llm,
|
||||
file_mappings=list(file_mappings),
|
||||
processing_mode=validated_mode.value,
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ class TaskDispatcher(Protocol):
|
|||
user_id: str,
|
||||
should_summarize: bool = False,
|
||||
use_vision_llm: bool = False,
|
||||
processing_mode: str = "basic",
|
||||
) -> None: ...
|
||||
|
||||
|
||||
|
|
@ -36,6 +37,7 @@ class CeleryTaskDispatcher:
|
|||
user_id: str,
|
||||
should_summarize: bool = False,
|
||||
use_vision_llm: bool = False,
|
||||
processing_mode: str = "basic",
|
||||
) -> None:
|
||||
from app.tasks.celery_tasks.document_tasks import (
|
||||
process_file_upload_with_document_task,
|
||||
|
|
@ -49,6 +51,7 @@ class CeleryTaskDispatcher:
|
|||
user_id=user_id,
|
||||
should_summarize=should_summarize,
|
||||
use_vision_llm=use_vision_llm,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -780,6 +780,7 @@ def process_file_upload_with_document_task(
|
|||
user_id: str,
|
||||
should_summarize: bool = False,
|
||||
use_vision_llm: bool = False,
|
||||
processing_mode: str = "basic",
|
||||
):
|
||||
"""
|
||||
Celery task to process uploaded file with existing pending document.
|
||||
|
|
@ -836,6 +837,7 @@ def process_file_upload_with_document_task(
|
|||
user_id,
|
||||
should_summarize=should_summarize,
|
||||
use_vision_llm=use_vision_llm,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
|
|
@ -873,6 +875,7 @@ async def _process_file_with_document(
|
|||
user_id: str,
|
||||
should_summarize: bool = False,
|
||||
use_vision_llm: bool = False,
|
||||
processing_mode: str = "basic",
|
||||
):
|
||||
"""
|
||||
Process file and update existing pending document status.
|
||||
|
|
@ -976,6 +979,7 @@ async def _process_file_with_document(
|
|||
notification=notification,
|
||||
should_summarize=should_summarize,
|
||||
use_vision_llm=use_vision_llm,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
|
||||
# Update notification on success
|
||||
|
|
@ -1434,6 +1438,7 @@ def index_uploaded_folder_files_task(
|
|||
enable_summary: bool,
|
||||
file_mappings: list[dict],
|
||||
use_vision_llm: bool = False,
|
||||
processing_mode: str = "basic",
|
||||
):
|
||||
"""Celery task to index files uploaded from the desktop app."""
|
||||
loop = asyncio.new_event_loop()
|
||||
|
|
@ -1448,6 +1453,7 @@ def index_uploaded_folder_files_task(
|
|||
enable_summary=enable_summary,
|
||||
file_mappings=file_mappings,
|
||||
use_vision_llm=use_vision_llm,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
|
|
@ -1462,6 +1468,7 @@ async def _index_uploaded_folder_files_async(
|
|||
enable_summary: bool,
|
||||
file_mappings: list[dict],
|
||||
use_vision_llm: bool = False,
|
||||
processing_mode: str = "basic",
|
||||
):
|
||||
"""Run upload-based folder indexing with notification + heartbeat."""
|
||||
file_count = len(file_mappings)
|
||||
|
|
@ -1512,6 +1519,7 @@ async def _index_uploaded_folder_files_async(
|
|||
file_mappings=file_mappings,
|
||||
on_heartbeat_callback=_heartbeat_progress,
|
||||
use_vision_llm=use_vision_llm,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
|
||||
if notification:
|
||||
|
|
|
|||
|
|
@ -60,14 +60,16 @@ async def _check_page_limit_or_skip(
|
|||
page_limit_service: PageLimitService,
|
||||
user_id: str,
|
||||
file_path: str,
|
||||
) -> int:
|
||||
page_multiplier: int = 1,
|
||||
) -> tuple[int, int]:
|
||||
"""Estimate pages and check the limit; raises PageLimitExceededError if over quota.
|
||||
|
||||
Returns the estimated page count on success.
|
||||
Returns (estimated_pages, billable_pages).
|
||||
"""
|
||||
estimated = _estimate_pages_safe(page_limit_service, file_path)
|
||||
await page_limit_service.check_page_limit(user_id, estimated)
|
||||
return estimated
|
||||
billable = estimated * page_multiplier
|
||||
await page_limit_service.check_page_limit(user_id, billable)
|
||||
return estimated, billable
|
||||
|
||||
|
||||
def _compute_final_pages(
|
||||
|
|
@ -153,17 +155,20 @@ def scan_folder(
|
|||
return files
|
||||
|
||||
|
||||
async def _read_file_content(
    file_path: str, filename: str, *, vision_llm=None, processing_mode: str = "basic"
) -> str:
    """Read file content via the unified ETL pipeline.

    All file types (plaintext, audio, direct-convert, document, image) are
    handled by ``EtlPipelineService``.

    ``processing_mode`` is normalized with ``ProcessingMode.coerce`` before
    being placed on the ``EtlRequest``. Returns the ``markdown_content`` of
    the extraction result.
    """
    from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode
    from app.etl_pipeline.etl_pipeline_service import EtlPipelineService

    mode = ProcessingMode.coerce(processing_mode)
    result = await EtlPipelineService(vision_llm=vision_llm).extract(
        EtlRequest(file_path=file_path, filename=filename, processing_mode=mode)
    )
    return result.markdown_content
|
||||
|
||||
|
|
@ -201,12 +206,15 @@ async def _compute_file_content_hash(
|
|||
search_space_id: int,
|
||||
*,
|
||||
vision_llm=None,
|
||||
processing_mode: str = "basic",
|
||||
) -> tuple[str, str]:
|
||||
"""Read a file (via ETL if needed) and compute its content hash.
|
||||
|
||||
Returns (content_text, content_hash).
|
||||
"""
|
||||
content = await _read_file_content(file_path, filename, vision_llm=vision_llm)
|
||||
content = await _read_file_content(
|
||||
file_path, filename, vision_llm=vision_llm, processing_mode=processing_mode
|
||||
)
|
||||
return content, _content_hash(content, search_space_id)
|
||||
|
||||
|
||||
|
|
@ -694,7 +702,7 @@ async def index_local_folder(
|
|||
continue
|
||||
|
||||
try:
|
||||
estimated_pages = await _check_page_limit_or_skip(
|
||||
estimated_pages, _billable = await _check_page_limit_or_skip(
|
||||
page_limit_service, user_id, file_path_abs
|
||||
)
|
||||
except PageLimitExceededError:
|
||||
|
|
@ -730,7 +738,7 @@ async def index_local_folder(
|
|||
await create_version_snapshot(session, existing_document)
|
||||
else:
|
||||
try:
|
||||
estimated_pages = await _check_page_limit_or_skip(
|
||||
estimated_pages, _billable = await _check_page_limit_or_skip(
|
||||
page_limit_service, user_id, file_path_abs
|
||||
)
|
||||
except PageLimitExceededError:
|
||||
|
|
@ -1080,7 +1088,7 @@ async def _index_single_file(
|
|||
|
||||
page_limit_service = PageLimitService(session)
|
||||
try:
|
||||
estimated_pages = await _check_page_limit_or_skip(
|
||||
estimated_pages, _billable = await _check_page_limit_or_skip(
|
||||
page_limit_service, user_id, str(full_path)
|
||||
)
|
||||
except PageLimitExceededError as e:
|
||||
|
|
@ -1271,6 +1279,7 @@ async def index_uploaded_files(
|
|||
file_mappings: list[dict],
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
use_vision_llm: bool = False,
|
||||
processing_mode: str = "basic",
|
||||
) -> tuple[int, int, str | None]:
|
||||
"""Index files uploaded from the desktop app via temp paths.
|
||||
|
||||
|
|
@ -1281,12 +1290,16 @@ async def index_uploaded_files(
|
|||
|
||||
Returns ``(indexed_count, failed_count, error_summary_or_none)``.
|
||||
"""
|
||||
from app.etl_pipeline.etl_document import ProcessingMode
|
||||
|
||||
mode = ProcessingMode.coerce(processing_mode)
|
||||
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="local_folder_indexing",
|
||||
source="uploaded_folder_indexing",
|
||||
message=f"Indexing {len(file_mappings)} uploaded file(s) for {folder_name}",
|
||||
metadata={"file_count": len(file_mappings)},
|
||||
metadata={"file_count": len(file_mappings), "processing_mode": mode.value},
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -1350,8 +1363,11 @@ async def index_uploaded_files(
|
|||
continue
|
||||
|
||||
try:
|
||||
estimated_pages = await _check_page_limit_or_skip(
|
||||
page_limit_service, user_id, temp_path
|
||||
estimated_pages, _billable_pages = await _check_page_limit_or_skip(
|
||||
page_limit_service,
|
||||
user_id,
|
||||
temp_path,
|
||||
page_multiplier=mode.page_multiplier,
|
||||
)
|
||||
except PageLimitExceededError:
|
||||
logger.warning(f"Page limit exceeded, skipping: {relative_path}")
|
||||
|
|
@ -1364,6 +1380,7 @@ async def index_uploaded_files(
|
|||
filename,
|
||||
search_space_id,
|
||||
vision_llm=vision_llm_instance,
|
||||
processing_mode=mode.value,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not read {relative_path}: {e}")
|
||||
|
|
@ -1429,8 +1446,9 @@ async def index_uploaded_files(
|
|||
final_pages = _compute_final_pages(
|
||||
page_limit_service, estimated_pages, len(content)
|
||||
)
|
||||
final_billable = final_pages * mode.page_multiplier
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, final_pages, allow_exceed=True
|
||||
user_id, final_billable, allow_exceed=True
|
||||
)
|
||||
else:
|
||||
failed_count += 1
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ class _ProcessingContext:
|
|||
connector: dict | None = None
|
||||
notification: Notification | None = None
|
||||
use_vision_llm: bool = False
|
||||
processing_mode: str = "basic"
|
||||
enable_summary: bool = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
|
@ -187,21 +188,28 @@ async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | No
|
|||
|
||||
async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
|
||||
"""Route a document file to the configured ETL service via the unified pipeline."""
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
from app.services.page_limit_service import PageLimitExceededError, PageLimitService
|
||||
|
||||
mode = ProcessingMode.coerce(ctx.processing_mode)
|
||||
page_limit_service = PageLimitService(ctx.session)
|
||||
estimated_pages = _estimate_pages_safe(page_limit_service, ctx.file_path)
|
||||
billable_pages = estimated_pages * mode.page_multiplier
|
||||
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Estimated {estimated_pages} pages for file: {ctx.filename}",
|
||||
{"estimated_pages": estimated_pages, "file_type": "document"},
|
||||
{
|
||||
"estimated_pages": estimated_pages,
|
||||
"billable_pages": billable_pages,
|
||||
"processing_mode": mode.value,
|
||||
"file_type": "document",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
await page_limit_service.check_page_limit(ctx.user_id, estimated_pages)
|
||||
await page_limit_service.check_page_limit(ctx.user_id, billable_pages)
|
||||
except PageLimitExceededError as e:
|
||||
await ctx.task_logger.log_task_failure(
|
||||
ctx.log_entry,
|
||||
|
|
@ -212,6 +220,8 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
|
|||
"pages_used": e.pages_used,
|
||||
"pages_limit": e.pages_limit,
|
||||
"estimated_pages": estimated_pages,
|
||||
"billable_pages": billable_pages,
|
||||
"processing_mode": mode.value,
|
||||
},
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
|
|
@ -225,6 +235,7 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
|
|||
file_path=ctx.file_path,
|
||||
filename=ctx.filename,
|
||||
estimated_pages=estimated_pages,
|
||||
processing_mode=mode,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -246,7 +257,7 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
|
|||
|
||||
if result:
|
||||
await page_limit_service.update_page_usage(
|
||||
ctx.user_id, estimated_pages, allow_exceed=True
|
||||
ctx.user_id, billable_pages, allow_exceed=True
|
||||
)
|
||||
if ctx.connector:
|
||||
await update_document_from_connector(result, ctx.connector, ctx.session)
|
||||
|
|
@ -259,6 +270,8 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
|
|||
"file_type": "document",
|
||||
"etl_service": etl_result.etl_service,
|
||||
"pages_processed": estimated_pages,
|
||||
"billable_pages": billable_pages,
|
||||
"processing_mode": mode.value,
|
||||
},
|
||||
)
|
||||
else:
|
||||
|
|
@ -290,6 +303,7 @@ async def process_file_in_background(
|
|||
connector: dict | None = None,
|
||||
notification: Notification | None = None,
|
||||
use_vision_llm: bool = False,
|
||||
processing_mode: str = "basic",
|
||||
) -> Document | None:
|
||||
ctx = _ProcessingContext(
|
||||
session=session,
|
||||
|
|
@ -302,6 +316,7 @@ async def process_file_in_background(
|
|||
connector=connector,
|
||||
notification=notification,
|
||||
use_vision_llm=use_vision_llm,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -353,22 +368,25 @@ async def _extract_file_content(
|
|||
log_entry: Log,
|
||||
notification: Notification | None,
|
||||
use_vision_llm: bool = False,
|
||||
) -> tuple[str, str]:
|
||||
processing_mode: str = "basic",
|
||||
) -> tuple[str, str, int]:
|
||||
"""
|
||||
Extract markdown content from a file regardless of type.
|
||||
|
||||
Returns:
|
||||
Tuple of (markdown_content, etl_service_name).
|
||||
Tuple of (markdown_content, etl_service_name, billable_pages).
|
||||
"""
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
from app.etl_pipeline.file_classifier import (
|
||||
FileCategory,
|
||||
classify_file as etl_classify,
|
||||
)
|
||||
|
||||
mode = ProcessingMode.coerce(processing_mode)
|
||||
category = etl_classify(filename)
|
||||
estimated_pages = 0
|
||||
billable_pages = 0
|
||||
|
||||
if notification:
|
||||
stage_messages = {
|
||||
|
|
@ -397,7 +415,8 @@ async def _extract_file_content(
|
|||
|
||||
page_limit_service = PageLimitService(session)
|
||||
estimated_pages = _estimate_pages_safe(page_limit_service, file_path)
|
||||
await page_limit_service.check_page_limit(user_id, estimated_pages)
|
||||
billable_pages = estimated_pages * mode.page_multiplier
|
||||
await page_limit_service.check_page_limit(user_id, billable_pages)
|
||||
|
||||
vision_llm = None
|
||||
if use_vision_llm and category == FileCategory.IMAGE:
|
||||
|
|
@ -410,21 +429,17 @@ async def _extract_file_content(
|
|||
file_path=file_path,
|
||||
filename=filename,
|
||||
estimated_pages=estimated_pages,
|
||||
processing_mode=mode,
|
||||
)
|
||||
)
|
||||
|
||||
if category == FileCategory.DOCUMENT:
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, estimated_pages, allow_exceed=True
|
||||
)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(file_path)
|
||||
|
||||
if not result.markdown_content:
|
||||
raise RuntimeError(f"Failed to extract content from file: {filename}")
|
||||
|
||||
return result.markdown_content, result.etl_service
|
||||
return result.markdown_content, result.etl_service, billable_pages
|
||||
|
||||
|
||||
async def process_file_in_background_with_document(
|
||||
|
|
@ -440,12 +455,16 @@ async def process_file_in_background_with_document(
|
|||
notification: Notification | None = None,
|
||||
should_summarize: bool = False,
|
||||
use_vision_llm: bool = False,
|
||||
processing_mode: str = "basic",
|
||||
) -> Document | None:
|
||||
"""
|
||||
Process file and update existing pending document (2-phase pattern).
|
||||
|
||||
Phase 1 (API layer): Created document with pending status.
|
||||
Phase 2 (this function): Process file and update document to ready/failed.
|
||||
|
||||
Page usage is deferred until after dedup check and successful indexing
|
||||
to avoid charging for duplicate or failed uploads.
|
||||
"""
|
||||
from app.indexing_pipeline.adapters.file_upload_adapter import (
|
||||
UploadDocumentAdapter,
|
||||
|
|
@ -458,8 +477,7 @@ async def process_file_in_background_with_document(
|
|||
doc_id = document.id
|
||||
|
||||
try:
|
||||
# Step 1: extract content
|
||||
markdown_content, etl_service = await _extract_file_content(
|
||||
markdown_content, etl_service, billable_pages = await _extract_file_content(
|
||||
file_path,
|
||||
filename,
|
||||
search_space_id,
|
||||
|
|
@ -469,12 +487,12 @@ async def process_file_in_background_with_document(
|
|||
log_entry,
|
||||
notification,
|
||||
use_vision_llm=use_vision_llm,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
|
||||
if not markdown_content:
|
||||
raise RuntimeError(f"Failed to extract content from file: {filename}")
|
||||
|
||||
# Step 2: duplicate check
|
||||
content_hash = generate_content_hash(markdown_content, search_space_id)
|
||||
existing_by_content = await check_duplicate_document(session, content_hash)
|
||||
if existing_by_content and existing_by_content.id != doc_id:
|
||||
|
|
@ -484,7 +502,6 @@ async def process_file_in_background_with_document(
|
|||
)
|
||||
return None
|
||||
|
||||
# Step 3: index via pipeline
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
|
|
@ -505,6 +522,14 @@ async def process_file_in_background_with_document(
|
|||
should_summarize=should_summarize,
|
||||
)
|
||||
|
||||
if billable_pages > 0:
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
|
||||
page_limit_service = PageLimitService(session)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, billable_pages, allow_exceed=True
|
||||
)
|
||||
|
||||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully processed file: {filename}",
|
||||
|
|
@ -512,6 +537,8 @@ async def process_file_in_background_with_document(
|
|||
"document_id": doc_id,
|
||||
"content_hash": content_hash,
|
||||
"file_type": etl_service,
|
||||
"billable_pages": billable_pages,
|
||||
"processing_mode": processing_mode,
|
||||
},
|
||||
)
|
||||
return document
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ class InlineTaskDispatcher:
|
|||
user_id: str,
|
||||
should_summarize: bool = False,
|
||||
use_vision_llm: bool = False,
|
||||
processing_mode: str = "basic",
|
||||
) -> None:
|
||||
from app.tasks.celery_tasks.document_tasks import (
|
||||
_process_file_with_document,
|
||||
|
|
@ -84,6 +85,7 @@ class InlineTaskDispatcher:
|
|||
user_id,
|
||||
should_summarize=should_summarize,
|
||||
use_vision_llm=use_vision_llm,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -321,7 +323,9 @@ def _mock_etl_parsing(monkeypatch):
|
|||
|
||||
# -- LlamaParse mock (external API) --------------------------------
|
||||
|
||||
async def _fake_llamacloud_parse(file_path: str, estimated_pages: int) -> str:
|
||||
async def _fake_llamacloud_parse(
|
||||
file_path: str, estimated_pages: int, processing_mode: str = "basic"
|
||||
) -> str:
|
||||
_reject_empty(file_path)
|
||||
return _MOCK_ETL_MARKDOWN
|
||||
|
||||
|
|
|
|||
|
|
@ -739,3 +739,187 @@ async def test_extract_image_falls_back_to_document_without_vision_llm(
|
|||
assert result.markdown_content == "# OCR text from image"
|
||||
assert result.etl_service == "DOCLING"
|
||||
assert result.content_type == "document"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Processing Mode enum tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_processing_mode_coerce_basic():
    """coerce() yields BASIC for "basic" in either case, for None, and for unknown text."""
    from app.etl_pipeline.etl_document import ProcessingMode

    # Same inputs, in the same order, as the one-assert-per-line form.
    for raw_value in ("basic", "BASIC", None, "invalid"):
        assert ProcessingMode.coerce(raw_value) == ProcessingMode.BASIC
|
||||
|
||||
|
||||
def test_processing_mode_coerce_premium():
    """coerce() yields PREMIUM for "premium" regardless of letter case."""
    from app.etl_pipeline.etl_document import ProcessingMode

    for raw_value in ("premium", "PREMIUM"):
        assert ProcessingMode.coerce(raw_value) == ProcessingMode.PREMIUM
|
||||
|
||||
|
||||
def test_processing_mode_page_multiplier():
    """BASIC has a page multiplier of 1 and PREMIUM a multiplier of 10."""
    from app.etl_pipeline.etl_document import ProcessingMode

    expected_multipliers = (
        (ProcessingMode.BASIC, 1),
        (ProcessingMode.PREMIUM, 10),
    )
    for mode, multiplier in expected_multipliers:
        assert mode.page_multiplier == multiplier
|
||||
|
||||
|
||||
def test_etl_request_default_processing_mode():
    """An EtlRequest built without a processing_mode reports BASIC."""
    from app.etl_pipeline.etl_document import ProcessingMode

    request = EtlRequest(file_path="/tmp/test.pdf", filename="test.pdf")

    assert request.processing_mode == ProcessingMode.BASIC
|
||||
|
||||
|
||||
def test_etl_request_premium_processing_mode():
    """An EtlRequest keeps an explicitly supplied PREMIUM processing mode."""
    from app.etl_pipeline.etl_document import ProcessingMode

    request_kwargs = {
        "file_path": "/tmp/test.pdf",
        "filename": "test.pdf",
        "processing_mode": ProcessingMode.PREMIUM,
    }
    request = EtlRequest(**request_kwargs)

    assert request.processing_mode == ProcessingMode.PREMIUM
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Azure DI model selection by processing mode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_azure_di_basic_uses_prebuilt_read(tmp_path, mocker):
    """Basic mode should use prebuilt-read model for Azure DI."""
    source_pdf = tmp_path / "report.pdf"
    source_pdf.write_bytes(b"%PDF-1.4 fake content " * 10)

    # Configure both LlamaCloud and Azure DI credentials.
    mocker.patch("app.config.config.ETL_SERVICE", "LLAMACLOUD")
    mocker.patch("app.config.config.LLAMA_CLOUD_API_KEY", "fake-key", create=True)
    mocker.patch(
        "app.config.config.AZURE_DI_ENDPOINT",
        "https://fake.cognitiveservices.azure.com/",
        create=True,
    )
    mocker.patch("app.config.config.AZURE_DI_KEY", "fake-key", create=True)

    azure_client = _mock_azure_di(mocker, "# Azure basic")
    _mock_llamacloud(mocker)

    from app.etl_pipeline.etl_document import ProcessingMode

    request = EtlRequest(
        file_path=str(source_pdf),
        filename="report.pdf",
        processing_mode=ProcessingMode.BASIC,
    )
    result = await EtlPipelineService().extract(request)

    assert result.markdown_content == "# Azure basic"
    # The model id is the first positional argument of the analyze call.
    analyze_call = azure_client.begin_analyze_document.call_args
    assert analyze_call.args[0] == "prebuilt-read"
|
||||
|
||||
|
||||
async def test_azure_di_premium_uses_prebuilt_layout(tmp_path, mocker):
    """Premium mode should use prebuilt-layout model for Azure DI."""
    source_pdf = tmp_path / "report.pdf"
    source_pdf.write_bytes(b"%PDF-1.4 fake content " * 10)

    # Configure both LlamaCloud and Azure DI credentials.
    mocker.patch("app.config.config.ETL_SERVICE", "LLAMACLOUD")
    mocker.patch("app.config.config.LLAMA_CLOUD_API_KEY", "fake-key", create=True)
    mocker.patch(
        "app.config.config.AZURE_DI_ENDPOINT",
        "https://fake.cognitiveservices.azure.com/",
        create=True,
    )
    mocker.patch("app.config.config.AZURE_DI_KEY", "fake-key", create=True)

    azure_client = _mock_azure_di(mocker, "# Azure premium")
    _mock_llamacloud(mocker)

    from app.etl_pipeline.etl_document import ProcessingMode

    request = EtlRequest(
        file_path=str(source_pdf),
        filename="report.pdf",
        processing_mode=ProcessingMode.PREMIUM,
    )
    result = await EtlPipelineService().extract(request)

    assert result.markdown_content == "# Azure premium"
    # The model id is the first positional argument of the analyze call.
    analyze_call = azure_client.begin_analyze_document.call_args
    assert analyze_call.args[0] == "prebuilt-layout"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LlamaCloud tier selection by processing mode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_llamacloud_basic_uses_cost_effective_tier(tmp_path, mocker):
    """Basic mode should use cost_effective tier for LlamaCloud."""
    source_pdf = tmp_path / "report.pdf"
    source_pdf.write_bytes(b"%PDF-1.4 fake content " * 10)

    # LlamaCloud configured; Azure DI settings explicitly set to None.
    mocker.patch("app.config.config.ETL_SERVICE", "LLAMACLOUD")
    mocker.patch("app.config.config.LLAMA_CLOUD_API_KEY", "fake-key", create=True)
    mocker.patch("app.config.config.AZURE_DI_ENDPOINT", None, create=True)
    mocker.patch("app.config.config.AZURE_DI_KEY", None, create=True)

    parser_stub = _mock_llamacloud(mocker, "# Llama basic")

    # Patch the constructor so the keyword arguments it receives can be inspected.
    parser_ctor = mocker.patch(
        "llama_cloud_services.LlamaParse", return_value=parser_stub
    )

    from app.etl_pipeline.etl_document import ProcessingMode

    request = EtlRequest(
        file_path=str(source_pdf),
        filename="report.pdf",
        estimated_pages=5,
        processing_mode=ProcessingMode.BASIC,
    )
    result = await EtlPipelineService().extract(request)

    assert result.markdown_content == "# Llama basic"
    ctor_kwargs = parser_ctor.call_args.kwargs
    assert ctor_kwargs["tier"] == "cost_effective"
|
||||
|
||||
|
||||
async def test_llamacloud_premium_uses_agentic_plus_tier(tmp_path, mocker):
    """Premium mode should use agentic_plus tier for LlamaCloud."""
    source_pdf = tmp_path / "report.pdf"
    source_pdf.write_bytes(b"%PDF-1.4 fake content " * 10)

    # LlamaCloud configured; Azure DI settings explicitly set to None.
    mocker.patch("app.config.config.ETL_SERVICE", "LLAMACLOUD")
    mocker.patch("app.config.config.LLAMA_CLOUD_API_KEY", "fake-key", create=True)
    mocker.patch("app.config.config.AZURE_DI_ENDPOINT", None, create=True)
    mocker.patch("app.config.config.AZURE_DI_KEY", None, create=True)

    parser_stub = _mock_llamacloud(mocker, "# Llama premium")

    # Patch the constructor so the keyword arguments it receives can be inspected.
    parser_ctor = mocker.patch(
        "llama_cloud_services.LlamaParse", return_value=parser_stub
    )

    from app.etl_pipeline.etl_document import ProcessingMode

    request = EtlRequest(
        file_path=str(source_pdf),
        filename="report.pdf",
        estimated_pages=5,
        processing_mode=ProcessingMode.PREMIUM,
    )
    result = await EtlPipelineService().extract(request)

    assert result.markdown_content == "# Llama premium"
    ctor_kwargs = parser_ctor.call_args.kwargs
    assert ctor_kwargs["tier"] == "agentic_plus"
|
||||
|
|
|
|||
286
surfsense_backend/tests/unit/test_error_contract.py
Normal file
286
surfsense_backend/tests/unit/test_error_contract.py
Normal file
|
|
@ -0,0 +1,286 @@
|
|||
"""Unit tests for the structured error response contract.
|
||||
|
||||
Validates that:
|
||||
- Global exception handlers produce the backward-compatible error envelope.
|
||||
- 5xx responses never leak raw internal exception text.
|
||||
- X-Request-ID is propagated correctly.
|
||||
- SurfSenseError, HTTPException, validation, and unhandled exceptions all
|
||||
use the same response shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from app.exceptions import (
|
||||
GENERIC_5XX_MESSAGE,
|
||||
ISSUES_URL,
|
||||
ConfigurationError,
|
||||
ConnectorError,
|
||||
DatabaseError,
|
||||
ExternalServiceError,
|
||||
ForbiddenError,
|
||||
NotFoundError,
|
||||
SurfSenseError,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers - lightweight FastAPI app that re-uses the real global handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_test_app():
    """Assemble a minimal FastAPI app that reuses the handlers from ``app.app``.

    The app installs ``RequestIDMiddleware`` and the four global exception
    handlers, then exposes one route per error scenario exercised by the
    tests in this module.

    Returns:
        The configured ``FastAPI`` instance.
    """
    from fastapi import FastAPI
    from fastapi.exceptions import RequestValidationError
    from pydantic import BaseModel

    from app.app import (
        RequestIDMiddleware,
        _http_exception_handler,
        _surfsense_error_handler,
        _unhandled_exception_handler,
        _validation_error_handler,
    )

    app = FastAPI()
    app.add_middleware(RequestIDMiddleware)

    # Registered in the same order as the original one-call-per-handler form.
    handler_table = (
        (SurfSenseError, _surfsense_error_handler),
        (RequestValidationError, _validation_error_handler),
        (HTTPException, _http_exception_handler),
        (Exception, _unhandled_exception_handler),
    )
    for exc_type, handler in handler_table:
        app.add_exception_handler(exc_type, handler)

    # -- success path ------------------------------------------------------

    @app.get("/ok")
    async def ok():
        return {"status": "ok"}

    # -- plain HTTPException routes ----------------------------------------

    @app.get("/http-400")
    async def raise_http_400():
        raise HTTPException(status_code=400, detail="Bad input")

    @app.get("/http-500")
    async def raise_http_500():
        # The detail text is deliberately sensitive-looking: the tests assert
        # that it does not appear in the response.
        raise HTTPException(status_code=500, detail="secret db password leaked")

    # -- SurfSenseError subclasses -----------------------------------------

    @app.get("/surfsense-connector")
    async def raise_connector():
        raise ConnectorError("GitHub API returned 401")

    @app.get("/surfsense-notfound")
    async def raise_notfound():
        raise NotFoundError("Document #42 was not found")

    @app.get("/surfsense-forbidden")
    async def raise_forbidden():
        raise ForbiddenError()

    @app.get("/surfsense-config")
    async def raise_config():
        raise ConfigurationError()

    @app.get("/surfsense-db")
    async def raise_db():
        raise DatabaseError()

    @app.get("/surfsense-external")
    async def raise_external():
        raise ExternalServiceError()

    @app.get("/surfsense-validation")
    async def raise_validation():
        raise ValidationError("Email is invalid")

    # -- catch-all -----------------------------------------------------------

    @app.get("/unhandled")
    async def raise_unhandled():
        raise RuntimeError("should never reach the client")

    # -- request-body validation ---------------------------------------------

    # NOTE(review): Item is local to this function while the module enables
    # postponed annotations (`from __future__ import annotations`); confirm
    # FastAPI still resolves `item: Item` as a JSON body model here.
    class Item(BaseModel):
        name: str
        count: int

    @app.post("/validated")
    async def validated(item: Item):
        return item.model_dump()

    return app
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def client():
    """Module-scoped TestClient for the app built by ``_make_test_app``.

    ``raise_server_exceptions=False`` is passed so the tests can inspect the
    error response instead of having the exception re-raised in the test.
    """
    return TestClient(_make_test_app(), raise_server_exceptions=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Envelope shape validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _assert_envelope(resp, expected_status: int):
    """Assert that *resp* carries the standard error envelope and return its JSON body.

    Checks the HTTP status, the presence of the ``error`` and legacy
    ``detail`` keys, the fields of the ``error`` object, and that ``detail``
    equals ``error["message"]``.
    """
    assert resp.status_code == expected_status

    payload = resp.json()
    assert "error" in payload, f"Missing 'error' key: {payload}"
    assert "detail" in payload, f"Missing legacy 'detail' key: {payload}"

    error = payload["error"]
    # code and message must be non-empty strings.
    for text_field in ("code", "message"):
        assert isinstance(error[text_field], str) and len(error[text_field]) > 0
    assert error["status"] == expected_status
    assert isinstance(error["request_id"], str) and len(error["request_id"]) > 0
    assert "timestamp" in error
    assert error["report_url"] == ISSUES_URL

    # Legacy compat: detail mirrors message
    assert payload["detail"] == error["message"]

    return payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# X-Request-ID propagation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRequestID:
    """X-Request-ID header generation, echo, and inclusion in error bodies."""

    def test_generated_when_missing(self, client):
        """Without an incoming header the response gets a ``req_``-prefixed id."""
        response = client.get("/ok")
        assert "X-Request-ID" in response.headers
        request_id = response.headers["X-Request-ID"]
        assert request_id.startswith("req_")

    def test_echoed_when_provided(self, client):
        """A caller-supplied id is returned unchanged in the response header."""
        sent_headers = {"X-Request-ID": "my-trace-123"}
        response = client.get("/ok", headers=sent_headers)
        assert response.headers["X-Request-ID"] == "my-trace-123"

    def test_present_in_error_response_body(self, client):
        """The caller-supplied id also appears in the error envelope."""
        sent_headers = {"X-Request-ID": "trace-abc"}
        response = client.get("/http-400", headers=sent_headers)
        payload = _assert_envelope(response, 400)
        assert payload["error"]["request_id"] == "trace-abc"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTPException handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHTTPExceptionHandler:
    """Envelope produced for plain ``HTTPException`` instances."""

    def test_400_preserves_detail(self, client):
        """A 400 keeps the original detail text and maps to BAD_REQUEST."""
        response = client.get("/http-400")
        error = _assert_envelope(response, 400)["error"]
        assert error["message"] == "Bad input"
        assert error["code"] == "BAD_REQUEST"

    def test_500_sanitizes_detail(self, client):
        """A 500 replaces the raised detail with the generic 5xx message."""
        response = client.get("/http-500")
        error = _assert_envelope(response, 500)["error"]
        assert "secret" not in error["message"]
        assert "password" not in error["message"]
        assert error["message"] == GENERIC_5XX_MESSAGE
        assert error["code"] == "INTERNAL_ERROR"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SurfSenseError hierarchy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSurfSenseErrorHandler:
    """Status and error code returned for each SurfSenseError route."""

    def test_connector_error(self, client):
        """ConnectorError -> 502 CONNECTOR_ERROR, message still mentions GitHub."""
        response = client.get("/surfsense-connector")
        error = _assert_envelope(response, 502)["error"]
        assert error["code"] == "CONNECTOR_ERROR"
        assert "GitHub" in error["message"]

    def test_not_found_error(self, client):
        """NotFoundError -> 404 NOT_FOUND."""
        response = client.get("/surfsense-notfound")
        payload = _assert_envelope(response, 404)
        assert payload["error"]["code"] == "NOT_FOUND"

    def test_forbidden_error(self, client):
        """ForbiddenError -> 403 FORBIDDEN."""
        response = client.get("/surfsense-forbidden")
        payload = _assert_envelope(response, 403)
        assert payload["error"]["code"] == "FORBIDDEN"

    def test_configuration_error(self, client):
        """ConfigurationError -> 500 CONFIGURATION_ERROR."""
        response = client.get("/surfsense-config")
        payload = _assert_envelope(response, 500)
        assert payload["error"]["code"] == "CONFIGURATION_ERROR"

    def test_database_error(self, client):
        """DatabaseError -> 500 DATABASE_ERROR."""
        response = client.get("/surfsense-db")
        payload = _assert_envelope(response, 500)
        assert payload["error"]["code"] == "DATABASE_ERROR"

    def test_external_service_error(self, client):
        """ExternalServiceError -> 502 EXTERNAL_SERVICE_ERROR."""
        response = client.get("/surfsense-external")
        payload = _assert_envelope(response, 502)
        assert payload["error"]["code"] == "EXTERNAL_SERVICE_ERROR"

    def test_validation_error_custom(self, client):
        """The project's ValidationError -> 422 VALIDATION_ERROR."""
        response = client.get("/surfsense-validation")
        payload = _assert_envelope(response, 422)
        assert payload["error"]["code"] == "VALIDATION_ERROR"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unhandled exception (catch-all)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUnhandledException:
    """Catch-all handling of exceptions no specific handler claims."""

    def test_returns_500_generic_message(self, client):
        """A raw RuntimeError yields a generic 500 with none of its text in the body."""
        response = client.get("/unhandled")
        payload = _assert_envelope(response, 500)
        error = payload["error"]
        assert error["code"] == "INTERNAL_ERROR"
        assert error["message"] == GENERIC_5XX_MESSAGE
        # Search the whole serialized body, not only the message field.
        serialized = json.dumps(payload)
        assert "should never reach" not in serialized
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RequestValidationError (pydantic / FastAPI)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidationErrorHandler:
    """Envelope produced when request validation on ``/validated`` fails."""

    def test_missing_fields(self, client):
        """An empty JSON object is rejected with a message mentioning "required"."""
        response = client.post("/validated", json={})
        error = _assert_envelope(response, 422)["error"]
        assert error["code"] == "VALIDATION_ERROR"
        assert "required" in error["message"].lower()

    def test_wrong_type(self, client):
        """A non-numeric ``count`` is rejected with VALIDATION_ERROR."""
        bad_payload = {"name": "test", "count": "not-a-number"}
        response = client.post("/validated", json=bad_payload)
        error = _assert_envelope(response, 422)["error"]
        assert error["code"] == "VALIDATION_ERROR"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SurfSenseError class hierarchy unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSurfSenseErrorClasses:
    """Attribute defaults on the exception classes themselves (no HTTP involved)."""

    def test_base_defaults(self):
        """The base class defaults to INTERNAL_ERROR / 500 and is client-safe."""
        error = SurfSenseError()
        assert error.code == "INTERNAL_ERROR"
        assert error.status_code == 500
        assert error.safe_for_client is True

    def test_connector_error(self):
        """ConnectorError carries CONNECTOR_ERROR and status 502."""
        error = ConnectorError("fail")
        assert error.code == "CONNECTOR_ERROR"
        assert error.status_code == 502

    def test_database_error(self):
        """DatabaseError maps to status 500."""
        assert DatabaseError().status_code == 500

    def test_not_found_error(self):
        """NotFoundError maps to status 404."""
        assert NotFoundError().status_code == 404

    def test_forbidden_error(self):
        """ForbiddenError maps to status 403."""
        assert ForbiddenError().status_code == 403

    def test_custom_code(self):
        """An explicit ``code`` keyword overrides the class default."""
        error = ConnectorError("x", code="GITHUB_TOKEN_EXPIRED")
        assert error.code == "GITHUB_TOKEN_EXPIRED"
|
||||
Loading…
Add table
Add a link
Reference in a new issue