feat: optimize document upload process and enhance memory management

- Increased maximum file upload limit from 10 to 50 to improve user experience.
- Implemented batch processing for document uploads to avoid proxy timeouts, splitting files into manageable chunks.
- Enhanced garbage collection in chat streaming functions to prevent memory leaks and improve performance.
- Added memory delta tracking in system snapshots for better monitoring of resource usage.
- Updated LLM router and service configurations to prevent unbounded internal accumulation and improve efficiency.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-02-28 17:22:34 -08:00
parent cc64e18501
commit d959a6a6c8
16 changed files with 219 additions and 187 deletions

View file

@ -1,4 +1,5 @@
import asyncio
import gc
import logging
import time
from collections import defaultdict
@ -212,18 +213,16 @@ def _enable_slow_callback_logging(threshold_sec: float = 0.5) -> None:
@asynccontextmanager
async def lifespan(app: FastAPI):
# Enable slow-callback detection (set PERF_DEBUG=1 env var to activate)
# Tune GC: lower gen-2 threshold so long-lived garbage is collected
# sooner (default 700/10/10 → 700/10/5). This reduces peak RSS
# with minimal CPU overhead.
gc.set_threshold(700, 10, 5)
_enable_slow_callback_logging(threshold_sec=0.5)
# Not needed if you setup a migration system like Alembic
await create_db_and_tables()
# Setup LangGraph checkpointer tables for conversation persistence
await setup_checkpointer_tables()
# Initialize LLM Router for Auto mode load balancing
initialize_llm_router()
# Initialize Image Generation Router for Auto mode load balancing
initialize_image_gen_router()
# Seed Surfsense documentation (with timeout so a slow embedding API
# doesn't block startup indefinitely and make the container unresponsive)
try:
await asyncio.wait_for(seed_surfsense_docs(), timeout=120)
except TimeoutError:
@ -231,8 +230,11 @@ async def lifespan(app: FastAPI):
"Surfsense docs seeding timed out after 120s — skipping. "
"Docs will be indexed on the next restart."
)
log_system_snapshot("startup_complete")
yield
# Cleanup: close checkpointer connection on shutdown
await close_checkpointer()

View file

@ -1856,7 +1856,14 @@ class RefreshToken(Base, TimestampMixin):
return not self.is_expired and not self.is_revoked
engine = create_async_engine(DATABASE_URL)
engine = create_async_engine(
DATABASE_URL,
pool_size=30,
max_overflow=150,
pool_recycle=1800,
pool_pre_ping=True,
pool_timeout=30,
)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)

View file

@ -133,6 +133,8 @@ async def create_documents_file_upload(
Requires DOCUMENTS_CREATE permission.
"""
import os
import tempfile
from datetime import datetime
from app.db import DocumentStatus
@ -143,7 +145,6 @@ async def create_documents_file_upload(
from app.utils.document_converters import generate_unique_identifier_hash
try:
# Check permission
await check_permission(
session,
user,
@ -179,69 +180,64 @@ async def create_documents_file_upload(
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
)
created_documents: list[Document] = []
files_to_process: list[
tuple[Document, str, str]
] = [] # (document, temp_path, filename)
skipped_duplicates = 0
duplicate_document_ids: list[int] = []
actual_total_size = 0
# ===== Read all files concurrently to avoid blocking the event loop =====
async def _read_and_save(file: UploadFile) -> tuple[str, str, int]:
"""Read upload content and write to temp file off the event loop."""
content = await file.read()
file_size = len(content)
filename = file.filename or "unknown"
# ===== PHASE 1: Create pending documents for all files =====
# This makes ALL documents visible in the UI immediately with pending status
for file in files:
try:
import os
import tempfile
# Save file to temp location
with tempfile.NamedTemporaryFile(
delete=False, suffix=os.path.splitext(file.filename or "")[1]
) as temp_file:
temp_path = temp_file.name
content = await file.read()
file_size = len(content)
if file_size > MAX_FILE_SIZE_BYTES:
os.unlink(temp_path)
raise HTTPException(
status_code=413,
detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
)
actual_total_size += file_size
if actual_total_size > MAX_TOTAL_SIZE_BYTES:
os.unlink(temp_path)
raise HTTPException(
status_code=413,
detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
)
with open(temp_path, "wb") as f:
f.write(content)
# Generate unique identifier for deduplication check
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.FILE, file.filename or "unknown", search_space_id
if file_size > MAX_FILE_SIZE_BYTES:
raise HTTPException(
status_code=413,
detail=f"File '{filename}' ({file_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
)
def _write_temp() -> str:
with tempfile.NamedTemporaryFile(
delete=False, suffix=os.path.splitext(filename)[1]
) as tmp:
tmp.write(content)
return tmp.name
temp_path = await asyncio.to_thread(_write_temp)
return temp_path, filename, file_size
saved_files = await asyncio.gather(*(_read_and_save(f) for f in files))
actual_total_size = sum(size for _, _, size in saved_files)
if actual_total_size > MAX_TOTAL_SIZE_BYTES:
for temp_path, _, _ in saved_files:
os.unlink(temp_path)
raise HTTPException(
status_code=413,
detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
)
# ===== PHASE 1: Create pending documents for all files =====
created_documents: list[Document] = []
files_to_process: list[tuple[Document, str, str]] = []
skipped_duplicates = 0
duplicate_document_ids: list[int] = []
for temp_path, filename, file_size in saved_files:
try:
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.FILE, filename, search_space_id
)
# Check if document already exists (by unique identifier)
existing = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
if existing:
if DocumentStatus.is_state(existing.status, DocumentStatus.READY):
# True duplicate — content already indexed, skip
os.unlink(temp_path)
skipped_duplicates += 1
duplicate_document_ids.append(existing.id)
continue
# Existing document is stuck (failed/pending/processing)
# Reset it to pending and re-dispatch for processing
existing.status = DocumentStatus.pending()
existing.content = "Processing..."
existing.document_metadata = {
@ -251,50 +247,45 @@ async def create_documents_file_upload(
}
existing.updated_at = get_current_timestamp()
created_documents.append(existing)
files_to_process.append(
(existing, temp_path, file.filename or "unknown")
)
files_to_process.append((existing, temp_path, filename))
continue
# Create pending document (visible immediately in UI via ElectricSQL)
document = Document(
search_space_id=search_space_id,
title=file.filename or "Uploaded File",
title=filename if filename != "unknown" else "Uploaded File",
document_type=DocumentType.FILE,
document_metadata={
"FILE_NAME": file.filename,
"FILE_NAME": filename,
"file_size": file_size,
"upload_time": datetime.now().isoformat(),
},
content="Processing...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary, updated when ready
content="Processing...",
content_hash=unique_identifier_hash,
unique_identifier_hash=unique_identifier_hash,
embedding=None,
status=DocumentStatus.pending(), # Shows "pending" in UI
status=DocumentStatus.pending(),
updated_at=get_current_timestamp(),
created_by_id=str(user.id),
)
session.add(document)
created_documents.append(document)
files_to_process.append(
(document, temp_path, file.filename or "unknown")
)
files_to_process.append((document, temp_path, filename))
except HTTPException:
raise
except Exception as e:
os.unlink(temp_path)
raise HTTPException(
status_code=422,
detail=f"Failed to process file {file.filename}: {e!s}",
detail=f"Failed to process file {filename}: {e!s}",
) from e
# Commit all pending documents - they appear in UI immediately via ElectricSQL
if created_documents:
await session.commit()
# Refresh to get generated IDs
for doc in created_documents:
await session.refresh(doc)
# ===== PHASE 2: Dispatch tasks for each file =====
# Each task will update document status: pending → processing → ready/failed
for document, temp_path, filename in files_to_process:
await dispatcher.dispatch_file_processing(
document_id=document.id,

View file

@ -16,6 +16,7 @@ import re
import time
from typing import Any
import litellm
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.exceptions import ContextOverflowError
from langchain_core.language_models import BaseChatModel
@ -29,6 +30,9 @@ from litellm.exceptions import (
from app.utils.perf import get_perf_logger
litellm.json_logs = False
litellm.store_audit_logs = False
logger = logging.getLogger(__name__)
_CONTEXT_OVERFLOW_PATTERNS = re.compile(

View file

@ -19,6 +19,13 @@ from app.services.llm_router_service import (
# Configure litellm to automatically drop unsupported parameters
litellm.drop_params = True
# Memory controls: prevent unbounded internal accumulation
litellm.telemetry = False
litellm.cache = None
litellm.success_callback = []
litellm.failure_callback = []
litellm.input_callback = []
logger = logging.getLogger(__name__)

View file

@ -1 +1,28 @@
"""Celery tasks package."""
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.config import config
_celery_engine = None
_celery_session_maker = None
def get_celery_session_maker() -> async_sessionmaker:
"""Return a shared async session maker for Celery tasks.
A single NullPool engine is created per worker process and reused
across all task invocations to avoid leaking engine objects.
"""
global _celery_engine, _celery_session_maker
if _celery_session_maker is None:
_celery_engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool,
echo=False,
)
_celery_session_maker = async_sessionmaker(
_celery_engine, expire_on_commit=False
)
return _celery_session_maker

View file

@ -3,11 +3,8 @@
import logging
import traceback
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
@ -42,20 +39,6 @@ def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> N
)
def get_celery_session_maker():
"""
Create a new async session maker for Celery tasks.
This is necessary because Celery tasks run in a new event loop,
and the default session maker is bound to the main app's event loop.
"""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool, # Don't use connection pooling for Celery tasks
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="index_slack_messages", bind=True)
def index_slack_messages_task(
self,

View file

@ -4,15 +4,13 @@ import logging
from sqlalchemy import delete, select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import selectinload
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.db import Document
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.tasks.celery_tasks import get_celery_session_maker
from app.utils.document_converters import (
create_document_chunks,
generate_document_summary,
@ -21,16 +19,6 @@ from app.utils.document_converters import (
logger = logging.getLogger(__name__)
def get_celery_session_maker():
"""Create async session maker for Celery tasks."""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool,
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="reindex_document", bind=True)
def reindex_document_task(self, document_id: int, user_id: str):
"""

View file

@ -5,13 +5,11 @@ import logging
import os
from uuid import UUID
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.services.notification_service import NotificationService
from app.services.task_logging_service import TaskLoggingService
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.document_processors import (
add_extension_received_document,
add_youtube_video_document,
@ -91,20 +89,6 @@ async def _run_heartbeat_loop(notification_id: int):
pass # Normal cancellation when task completes
def get_celery_session_maker():
"""
Create a new async session maker for Celery tasks.
This is necessary because Celery tasks run in a new event loop,
and the default session maker is bound to the main app's event loop.
"""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool, # Don't use connection pooling for Celery tasks
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="process_extension_document", bind=True)
def process_extension_document_task(
self, individual_document_dict, search_space_id: int, user_id: str

View file

@ -5,14 +5,13 @@ import logging
import sys
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app
from app.config import config
from app.db import Podcast, PodcastStatus
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
@ -25,20 +24,6 @@ if sys.platform.startswith("win"):
)
def get_celery_session_maker():
"""
Create a new async session maker for Celery tasks.
This is necessary because Celery tasks run in a new event loop,
and the default session maker is bound to the main app's event loop.
"""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool, # Don't use connection pooling for Celery tasks
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
# =============================================================================
# Content-based podcast generation (for new-chat)
# =============================================================================

View file

@ -3,28 +3,16 @@
import logging
from datetime import UTC, datetime
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.future import select
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
from app.tasks.celery_tasks import get_celery_session_maker
from app.utils.indexing_locks import is_connector_indexing_locked
logger = logging.getLogger(__name__)
def get_celery_session_maker():
"""Create async session maker for Celery tasks."""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool,
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="check_periodic_schedules")
def check_periodic_schedules_task():
"""

View file

@ -29,20 +29,17 @@ from datetime import UTC, datetime
import redis
from sqlalchemy import and_, or_, text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.future import select
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.db import Document, DocumentStatus, Notification
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
# Redis client for checking heartbeats
_redis_client: redis.Redis | None = None
# Error messages shown to users when tasks are interrupted
STALE_SYNC_ERROR_MESSAGE = "Sync was interrupted unexpectedly. Please retry."
STALE_PROCESSING_ERROR_MESSAGE = "Syncing was interrupted unexpectedly. Please retry."
@ -60,16 +57,6 @@ def _get_heartbeat_key(notification_id: int) -> str:
return f"indexing:heartbeat:{notification_id}"
def get_celery_session_maker():
"""Create async session maker for Celery tasks."""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool,
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="cleanup_stale_indexing_notifications")
def cleanup_stale_indexing_notifications_task():
"""

View file

@ -1477,15 +1477,21 @@ async def stream_new_chat(
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
# Trigger a GC pass so LangGraph agent graphs, tool closures, and
# LLM wrappers with potential circular refs are reclaimed promptly.
collected = gc.collect()
# Break circular refs held by the agent graph, tools, and LLM
# wrappers so the GC can reclaim them in a single pass.
agent = llm = connector_service = sandbox_backend = None
mentioned_documents = mentioned_surfsense_docs = None
recent_reports = langchain_messages = input_state = None
stream_result = None
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
if collected:
_perf_log.info(
"[stream_new_chat] gc.collect() reclaimed %d objects (chat_id=%s)",
collected,
chat_id,
)
log_system_snapshot("stream_new_chat_END")
async def stream_resume_chat(
@ -1673,10 +1679,15 @@ async def stream_resume_chat(
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
collected = gc.collect()
agent = llm = connector_service = sandbox_backend = None
stream_result = None
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
if collected:
_perf_log.info(
"[stream_resume] gc.collect() reclaimed %d objects (chat_id=%s)",
collected,
chat_id,
)
log_system_snapshot("stream_resume_chat_END")

View file

@ -9,6 +9,7 @@ Provides:
- RequestPerfMiddleware for per-request timing
"""
import gc
import logging
import os
import time
@ -16,6 +17,7 @@ from contextlib import asynccontextmanager, contextmanager
from typing import Any
_perf_log: logging.Logger | None = None
_last_rss_mb: float = 0.0
def get_perf_logger() -> logging.Logger:
@ -73,20 +75,29 @@ def system_snapshot() -> dict[str, Any]:
Returns a dict with:
- rss_mb: Resident Set Size in MB
- rss_delta_mb: Change in RSS since the last snapshot
- cpu_percent: CPU usage % since last call (per-process)
- threads: number of active threads
- open_fds: number of open file descriptors (Linux only)
- asyncio_tasks: number of asyncio tasks currently alive
- gc_counts: tuple of object counts per gc generation
"""
import asyncio
global _last_rss_mb
snapshot: dict[str, Any] = {}
try:
import psutil
proc = psutil.Process(os.getpid())
mem = proc.memory_info()
snapshot["rss_mb"] = round(mem.rss / 1024 / 1024, 1)
rss_mb = round(mem.rss / 1024 / 1024, 1)
snapshot["rss_mb"] = rss_mb
snapshot["rss_delta_mb"] = (
round(rss_mb - _last_rss_mb, 1) if _last_rss_mb else 0.0
)
_last_rss_mb = rss_mb
snapshot["cpu_percent"] = proc.cpu_percent(interval=None)
snapshot["threads"] = proc.num_threads()
try:
@ -95,6 +106,7 @@ def system_snapshot() -> dict[str, Any]:
snapshot["open_fds"] = -1
except ImportError:
snapshot["rss_mb"] = -1
snapshot["rss_delta_mb"] = 0.0
snapshot["cpu_percent"] = -1
snapshot["threads"] = -1
snapshot["open_fds"] = -1
@ -105,18 +117,35 @@ def system_snapshot() -> dict[str, Any]:
except RuntimeError:
snapshot["asyncio_tasks"] = -1
snapshot["gc_counts"] = gc.get_count()
return snapshot
def log_system_snapshot(label: str = "system_snapshot") -> None:
"""Capture and log a system snapshot."""
"""Capture and log a system snapshot with memory delta tracking."""
snap = system_snapshot()
delta_str = ""
if snap["rss_delta_mb"]:
sign = "+" if snap["rss_delta_mb"] > 0 else ""
delta_str = f" delta={sign}{snap['rss_delta_mb']}MB"
get_perf_logger().info(
"[%s] rss=%.1fMB cpu=%.1f%% threads=%d fds=%d asyncio_tasks=%d",
"[%s] rss=%.1fMB%s cpu=%.1f%% threads=%d fds=%d asyncio_tasks=%d gc=%s",
label,
snap["rss_mb"],
delta_str,
snap["cpu_percent"],
snap["threads"],
snap["open_fds"],
snap["asyncio_tasks"],
snap["gc_counts"],
)
if snap["rss_mb"] > 0 and snap["rss_delta_mb"] > 500:
get_perf_logger().warning(
"[MEMORY_SPIKE] %s: RSS jumped by %.1fMB (now %.1fMB). "
"Possible leak — check recent operations.",
label,
snap["rss_delta_mb"],
snap["rss_mb"],
)

View file

@ -111,8 +111,8 @@ const FILE_TYPE_CONFIG: Record<string, Record<string, string[]>> = {
const cardClass = "border border-border bg-slate-400/5 dark:bg-white/5";
// Upload limits
const MAX_FILES = 10;
// Upload limits — files are sent in batches of 5 to avoid proxy timeouts
const MAX_FILES = 50;
const MAX_TOTAL_SIZE_MB = 200;
const MAX_TOTAL_SIZE_BYTES = MAX_TOTAL_SIZE_MB * 1024 * 1024;

View file

@ -109,7 +109,9 @@ class DocumentsApiService {
};
/**
* Upload document files
* Upload document files in batches to avoid proxy/LB timeouts.
* Files are split into chunks of UPLOAD_BATCH_SIZE and sent as separate
* requests. Results are aggregated into a single response.
*/
uploadDocument = async (request: UploadDocumentRequest) => {
const parsedRequest = uploadDocumentRequest.safeParse(request);
@ -121,17 +123,54 @@ class DocumentsApiService {
throw new ValidationError(`Invalid request: ${errorMessage}`);
}
// Create FormData for file upload
const formData = new FormData();
parsedRequest.data.files.forEach((file) => {
formData.append("files", file);
});
formData.append("search_space_id", String(parsedRequest.data.search_space_id));
formData.append("should_summarize", String(parsedRequest.data.should_summarize));
const { files, search_space_id, should_summarize } = parsedRequest.data;
const UPLOAD_BATCH_SIZE = 5;
return baseApiService.postFormData(`/api/v1/documents/fileupload`, uploadDocumentResponse, {
body: formData,
});
const batches: File[][] = [];
for (let i = 0; i < files.length; i += UPLOAD_BATCH_SIZE) {
batches.push(files.slice(i, i + UPLOAD_BATCH_SIZE));
}
const allDocumentIds: number[] = [];
const allDuplicateIds: number[] = [];
let totalFiles = 0;
let pendingFiles = 0;
let skippedDuplicates = 0;
for (const batch of batches) {
const formData = new FormData();
batch.forEach((file) => formData.append("files", file));
formData.append("search_space_id", String(search_space_id));
formData.append("should_summarize", String(should_summarize));
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), 120_000);
try {
const result = await baseApiService.postFormData(
`/api/v1/documents/fileupload`,
uploadDocumentResponse,
{ body: formData, signal: controller.signal }
);
allDocumentIds.push(...(result.document_ids ?? []));
allDuplicateIds.push(...(result.duplicate_document_ids ?? []));
totalFiles += result.total_files ?? batch.length;
pendingFiles += result.pending_files ?? 0;
skippedDuplicates += result.skipped_duplicates ?? 0;
} finally {
clearTimeout(timeoutId);
}
}
return {
message: "Files uploaded for processing" as const,
document_ids: allDocumentIds,
duplicate_document_ids: allDuplicateIds,
total_files: totalFiles,
pending_files: pendingFiles,
skipped_duplicates: skippedDuplicates,
};
};
/**