feat: enhance memory management and session handling in database operations

- Introduced a shielded async session context manager to ensure safe session closure during cancellations.
- Updated various database operations to utilize the new shielded session, preventing orphaned connections.
- Added environment variables to optimize glibc memory management, improving overall application performance.
- Implemented a function to trim the native heap, allowing for better memory reclamation on Linux systems.
This commit is contained in:
DESKTOP-RTLN3BA$punk 2026-02-28 23:59:28 -08:00
parent dd3da2bc36
commit ecb0a25cc8
7 changed files with 76 additions and 17 deletions

View file

@ -88,6 +88,13 @@ ENV TMPDIR=/shared_tmp
ENV PYTHONPATH=/app
ENV UVICORN_LOOP=asyncio
# Tune glibc malloc to return freed memory to the OS more aggressively.
# Without these, Python's gc.collect() frees objects but the underlying
# C heap pages stay mapped (RSS never drops) due to sbrk fragmentation.
ENV MALLOC_MMAP_THRESHOLD_=65536
ENV MALLOC_TRIM_THRESHOLD_=131072
ENV MALLOC_MMAP_MAX_=65536
# SERVICE_ROLE controls which process this container runs:
# api FastAPI backend only (runs migrations on startup)
# worker Celery worker only

View file

@ -19,7 +19,7 @@ from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from app.db import shielded_async_session
from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger
@ -98,7 +98,7 @@ async def _browse_recent_documents(
if end_date is not None:
base_conditions.append(Document.updated_at <= end_date)
async with async_session_maker() as session:
async with shielded_async_session() as session:
doc_query = (
select(Document)
.options(joinedload(Document.search_space))
@ -739,7 +739,7 @@ async def search_knowledge_base_async(
try:
t_conn = time.perf_counter()
async with semaphore, async_session_maker() as isolated_session:
async with semaphore, shielded_async_session() as isolated_session:
svc = ConnectorService(isolated_session, search_space_id)
_, chunks = await getattr(svc, method_name)(**kwargs)
perf.info(
@ -756,7 +756,7 @@ async def search_knowledge_base_async(
# --- Optimization 3: call _combined_rrf_search directly with shared embedding ---
try:
t_conn = time.perf_counter()
async with semaphore, async_session_maker() as isolated_session:
async with semaphore, shielded_async_session() as isolated_session:
svc = ConnectorService(isolated_session, search_space_id)
chunks = await svc._combined_rrf_search(
query_text=query,

View file

@ -33,7 +33,7 @@ from langchain_core.callbacks import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from app.db import Report, async_session_maker
from app.db import Report, shielded_async_session
from app.services.connector_service import ConnectorService
from app.services.llm_service import get_document_summary_llm
@ -717,7 +717,7 @@ def create_generate_report_tool(
async def _save_failed_report(error_msg: str) -> int | None:
"""Persist a failed report row using a short-lived session."""
try:
async with async_session_maker() as session:
async with shielded_async_session() as session:
failed_report = Report(
title=topic,
content=None,
@ -751,7 +751,7 @@ def create_generate_report_tool(
# ── Phase 1: READ (short-lived session) ──────────────────────
# Fetch parent report and LLM config, then close the session
# so no DB connection is held during the long LLM call.
async with async_session_maker() as read_session:
async with shielded_async_session() as read_session:
if parent_report_id:
parent_report = await read_session.get(Report, parent_report_id)
if parent_report:
@ -828,7 +828,7 @@ def create_generate_report_tool(
# Run all queries in parallel, each with its own session
async def _run_single_query(q: str) -> str:
async with async_session_maker() as kb_session:
async with shielded_async_session() as kb_session:
kb_connector_svc = ConnectorService(
kb_session, search_space_id
)
@ -1028,7 +1028,7 @@ def create_generate_report_tool(
# ── Phase 3: WRITE (short-lived session) ─────────────────────
# Save the report to the database, then close the session.
async with async_session_maker() as write_session:
async with shielded_async_session() as write_session:
report = Report(
title=topic,
content=report_content,

View file

@ -1,7 +1,9 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from enum import StrEnum
import anyio
from fastapi import Depends
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
from pgvector.sqlalchemy import Vector
@ -1867,6 +1869,26 @@ engine = create_async_engine(
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
@asynccontextmanager
async def shielded_async_session() -> AsyncGenerator:
    """Cancellation-safe async session context manager.

    Starlette's BaseHTTPMiddleware cancels the task via an anyio cancel
    scope when a client disconnects. A plain ``async with async_session_maker()``
    has its ``__aexit__`` (which awaits ``session.close()``) cancelled by the
    scope, orphaning the underlying database connection.

    This wrapper ensures ``session.close()`` always completes by running it
    inside ``anyio.CancelScope(shield=True)``.

    Yields:
        An ``AsyncSession`` created by ``async_session_maker``. The caller
        uses it exactly like ``async with async_session_maker() as s``.
    """
    # Calling the maker directly (instead of ``async with``) is equivalent
    # for session lifecycle; the close is handled in the finally below.
    session = async_session_maker()
    try:
        yield session
    finally:
        # Shield the close so a cancellation delivered while this generator
        # is being finalized cannot abort ``session.close()`` — otherwise
        # the pooled DB connection held by the session would be orphaned.
        with anyio.CancelScope(shield=True):
            await session.close()
async def setup_indexes():
async with engine.begin() as conn:
# Create indexes

View file

@ -31,8 +31,8 @@ from app.db import (
Permission,
SearchSpace,
User,
async_session_maker,
get_async_session,
shielded_async_session,
)
from app.schemas.new_chat import (
NewChatMessageAppend,
@ -1356,7 +1356,7 @@ async def regenerate_response(
# Uses a fresh session since stream_new_chat manages its own.
if streaming_completed and message_ids_to_delete:
try:
async with async_session_maker() as cleanup_session:
async with shielded_async_session() as cleanup_session:
for msg_id in message_ids_to_delete:
_res = await cleanup_session.execute(
select(NewChatMessage).filter(

View file

@ -49,6 +49,7 @@ from app.db import (
SearchSourceConnectorType,
SurfsenseDocsDocument,
async_session_maker,
shielded_async_session,
)
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
from app.services.chat_session_state_service import (
@ -58,7 +59,7 @@ from app.services.chat_session_state_service import (
from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService
from app.utils.content_utils import bootstrap_history_from_db
from app.utils.perf import get_perf_logger, log_system_snapshot
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
_perf_log = get_perf_logger()
@ -1359,6 +1360,12 @@ async def stream_new_chat(
items=initial_items,
)
# These ORM objects (with eagerly-loaded chunks) can be very large.
# They're only needed to build context strings already copied into
# final_query / langchain_messages — release them before streaming.
del mentioned_documents, mentioned_surfsense_docs, recent_reports
del langchain_messages, final_query
_t_stream_start = time.perf_counter()
_first_event_logged = False
async for sse in _stream_agent_events(
@ -1483,7 +1490,7 @@ async def stream_new_chat(
await clear_ai_responding(session, chat_id)
except Exception:
try:
async with async_session_maker() as fresh_session:
async with shielded_async_session() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
@ -1501,9 +1508,7 @@ async def stream_new_chat(
# Break circular refs held by the agent graph, tools, and LLM
# wrappers so the GC can reclaim them in a single pass.
agent = llm = connector_service = sandbox_backend = None
mentioned_documents = mentioned_surfsense_docs = None
recent_reports = langchain_messages = input_state = None
stream_result = None
input_state = stream_result = None
session = None
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
@ -1513,6 +1518,7 @@ async def stream_new_chat(
collected,
chat_id,
)
trim_native_heap()
log_system_snapshot("stream_new_chat_END")
@ -1695,7 +1701,7 @@ async def stream_resume_chat(
await clear_ai_responding(session, chat_id)
except Exception:
try:
async with async_session_maker() as fresh_session:
async with shielded_async_session() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
@ -1721,4 +1727,5 @@ async def stream_resume_chat(
collected,
chat_id,
)
trim_native_heap()
log_system_snapshot("stream_resume_chat_END")

View file

@ -149,3 +149,26 @@ def log_system_snapshot(label: str = "system_snapshot") -> None:
snap["rss_delta_mb"],
snap["rss_mb"],
)
def trim_native_heap() -> bool:
    """Ask glibc to return free heap pages to the OS via ``malloc_trim(0)``.

    On Linux (glibc), ``free()`` does not release memory back to the OS if
    it sits below the brk watermark. ``malloc_trim`` forces the allocator
    to give back as many freed pages as possible.

    The resolved libc handle is cached on the function after the first
    successful load: this runs at the end of every chat stream, and
    re-creating a ``ctypes.CDLL`` wrapper per call is avoidable overhead.

    Returns:
        True if the trim call was performed, False otherwise (non-Linux,
        libc unavailable, or no ``malloc_trim`` symbol — e.g. musl).
    """
    import ctypes
    import sys

    if sys.platform != "linux":
        return False
    libc = getattr(trim_native_heap, "_libc", None)
    if libc is None:
        try:
            libc = ctypes.CDLL("libc.so.6")
            # Probe the symbol up front; raises AttributeError on libcs
            # without malloc_trim so we never cache an unusable handle.
            libc.malloc_trim
        except (OSError, AttributeError):
            return False
        trim_native_heap._libc = libc
    try:
        libc.malloc_trim(0)
        return True
    except (OSError, AttributeError):
        return False