feat: enhance memory management and session handling in database operations

- Introduced a shielded async session context manager to ensure safe session closure during cancellations.
- Updated various database operations to utilize the new shielded session, preventing orphaned connections.
- Added environment variables to optimize glibc memory management, improving overall application performance.
- Implemented a function to trim the native heap, allowing for better memory reclamation on Linux systems.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-02-28 23:59:28 -08:00
parent dd3da2bc36
commit ecb0a25cc8
7 changed files with 76 additions and 17 deletions

View file

@ -88,6 +88,13 @@ ENV TMPDIR=/shared_tmp
ENV PYTHONPATH=/app ENV PYTHONPATH=/app
ENV UVICORN_LOOP=asyncio ENV UVICORN_LOOP=asyncio
# Tune glibc malloc to return freed memory to the OS more aggressively.
# Without these, Python's gc.collect() frees objects but the underlying
# C heap pages stay mapped (RSS never drops) due to sbrk fragmentation.
ENV MALLOC_MMAP_THRESHOLD_=65536
ENV MALLOC_TRIM_THRESHOLD_=131072
ENV MALLOC_MMAP_MAX_=65536
# SERVICE_ROLE controls which process this container runs: # SERVICE_ROLE controls which process this container runs:
# api FastAPI backend only (runs migrations on startup) # api FastAPI backend only (runs migrations on startup)
# worker Celery worker only # worker Celery worker only

View file

@ -19,7 +19,7 @@ from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker from app.db import shielded_async_session
from app.services.connector_service import ConnectorService from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger from app.utils.perf import get_perf_logger
@ -98,7 +98,7 @@ async def _browse_recent_documents(
if end_date is not None: if end_date is not None:
base_conditions.append(Document.updated_at <= end_date) base_conditions.append(Document.updated_at <= end_date)
async with async_session_maker() as session: async with shielded_async_session() as session:
doc_query = ( doc_query = (
select(Document) select(Document)
.options(joinedload(Document.search_space)) .options(joinedload(Document.search_space))
@ -739,7 +739,7 @@ async def search_knowledge_base_async(
try: try:
t_conn = time.perf_counter() t_conn = time.perf_counter()
async with semaphore, async_session_maker() as isolated_session: async with semaphore, shielded_async_session() as isolated_session:
svc = ConnectorService(isolated_session, search_space_id) svc = ConnectorService(isolated_session, search_space_id)
_, chunks = await getattr(svc, method_name)(**kwargs) _, chunks = await getattr(svc, method_name)(**kwargs)
perf.info( perf.info(
@ -756,7 +756,7 @@ async def search_knowledge_base_async(
# --- Optimization 3: call _combined_rrf_search directly with shared embedding --- # --- Optimization 3: call _combined_rrf_search directly with shared embedding ---
try: try:
t_conn = time.perf_counter() t_conn = time.perf_counter()
async with semaphore, async_session_maker() as isolated_session: async with semaphore, shielded_async_session() as isolated_session:
svc = ConnectorService(isolated_session, search_space_id) svc = ConnectorService(isolated_session, search_space_id)
chunks = await svc._combined_rrf_search( chunks = await svc._combined_rrf_search(
query_text=query, query_text=query,

View file

@ -33,7 +33,7 @@ from langchain_core.callbacks import dispatch_custom_event
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from app.db import Report, async_session_maker from app.db import Report, shielded_async_session
from app.services.connector_service import ConnectorService from app.services.connector_service import ConnectorService
from app.services.llm_service import get_document_summary_llm from app.services.llm_service import get_document_summary_llm
@ -717,7 +717,7 @@ def create_generate_report_tool(
async def _save_failed_report(error_msg: str) -> int | None: async def _save_failed_report(error_msg: str) -> int | None:
"""Persist a failed report row using a short-lived session.""" """Persist a failed report row using a short-lived session."""
try: try:
async with async_session_maker() as session: async with shielded_async_session() as session:
failed_report = Report( failed_report = Report(
title=topic, title=topic,
content=None, content=None,
@ -751,7 +751,7 @@ def create_generate_report_tool(
# ── Phase 1: READ (short-lived session) ────────────────────── # ── Phase 1: READ (short-lived session) ──────────────────────
# Fetch parent report and LLM config, then close the session # Fetch parent report and LLM config, then close the session
# so no DB connection is held during the long LLM call. # so no DB connection is held during the long LLM call.
async with async_session_maker() as read_session: async with shielded_async_session() as read_session:
if parent_report_id: if parent_report_id:
parent_report = await read_session.get(Report, parent_report_id) parent_report = await read_session.get(Report, parent_report_id)
if parent_report: if parent_report:
@ -828,7 +828,7 @@ def create_generate_report_tool(
# Run all queries in parallel, each with its own session # Run all queries in parallel, each with its own session
async def _run_single_query(q: str) -> str: async def _run_single_query(q: str) -> str:
async with async_session_maker() as kb_session: async with shielded_async_session() as kb_session:
kb_connector_svc = ConnectorService( kb_connector_svc = ConnectorService(
kb_session, search_space_id kb_session, search_space_id
) )
@ -1028,7 +1028,7 @@ def create_generate_report_tool(
# ── Phase 3: WRITE (short-lived session) ───────────────────── # ── Phase 3: WRITE (short-lived session) ─────────────────────
# Save the report to the database, then close the session. # Save the report to the database, then close the session.
async with async_session_maker() as write_session: async with shielded_async_session() as write_session:
report = Report( report = Report(
title=topic, title=topic,
content=report_content, content=report_content,

View file

@ -1,7 +1,9 @@
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import UTC, datetime from datetime import UTC, datetime
from enum import StrEnum from enum import StrEnum
import anyio
from fastapi import Depends from fastapi import Depends
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
from pgvector.sqlalchemy import Vector from pgvector.sqlalchemy import Vector
@ -1867,6 +1869,26 @@ engine = create_async_engine(
async_session_maker = async_sessionmaker(engine, expire_on_commit=False) async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
@asynccontextmanager
async def shielded_async_session():
    """Yield an ``AsyncSession`` whose close step survives task cancellation.

    When a client disconnects, Starlette's BaseHTTPMiddleware cancels the
    request task through an anyio cancel scope. With a bare
    ``async with async_session_maker()``, that cancellation can land inside
    ``__aexit__`` while it awaits ``session.close()``, orphaning the
    underlying database connection. Shielding the close guarantees the
    connection is always released back to the pool.
    """
    db_session = async_session_maker()
    try:
        yield db_session
    finally:
        # Never let a surrounding cancel scope interrupt the close.
        with anyio.CancelScope(shield=True):
            await db_session.close()
async def setup_indexes(): async def setup_indexes():
async with engine.begin() as conn: async with engine.begin() as conn:
# Create indexes # Create indexes

View file

@ -31,8 +31,8 @@ from app.db import (
Permission, Permission,
SearchSpace, SearchSpace,
User, User,
async_session_maker,
get_async_session, get_async_session,
shielded_async_session,
) )
from app.schemas.new_chat import ( from app.schemas.new_chat import (
NewChatMessageAppend, NewChatMessageAppend,
@ -1356,7 +1356,7 @@ async def regenerate_response(
# Uses a fresh session since stream_new_chat manages its own. # Uses a fresh session since stream_new_chat manages its own.
if streaming_completed and message_ids_to_delete: if streaming_completed and message_ids_to_delete:
try: try:
async with async_session_maker() as cleanup_session: async with shielded_async_session() as cleanup_session:
for msg_id in message_ids_to_delete: for msg_id in message_ids_to_delete:
_res = await cleanup_session.execute( _res = await cleanup_session.execute(
select(NewChatMessage).filter( select(NewChatMessage).filter(

View file

@ -49,6 +49,7 @@ from app.db import (
SearchSourceConnectorType, SearchSourceConnectorType,
SurfsenseDocsDocument, SurfsenseDocsDocument,
async_session_maker, async_session_maker,
shielded_async_session,
) )
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
from app.services.chat_session_state_service import ( from app.services.chat_session_state_service import (
@ -58,7 +59,7 @@ from app.services.chat_session_state_service import (
from app.services.connector_service import ConnectorService from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService from app.services.new_streaming_service import VercelStreamingService
from app.utils.content_utils import bootstrap_history_from_db from app.utils.content_utils import bootstrap_history_from_db
from app.utils.perf import get_perf_logger, log_system_snapshot from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
_perf_log = get_perf_logger() _perf_log = get_perf_logger()
@ -1359,6 +1360,12 @@ async def stream_new_chat(
items=initial_items, items=initial_items,
) )
# These ORM objects (with eagerly-loaded chunks) can be very large.
# They're only needed to build context strings already copied into
# final_query / langchain_messages — release them before streaming.
del mentioned_documents, mentioned_surfsense_docs, recent_reports
del langchain_messages, final_query
_t_stream_start = time.perf_counter() _t_stream_start = time.perf_counter()
_first_event_logged = False _first_event_logged = False
async for sse in _stream_agent_events( async for sse in _stream_agent_events(
@ -1483,7 +1490,7 @@ async def stream_new_chat(
await clear_ai_responding(session, chat_id) await clear_ai_responding(session, chat_id)
except Exception: except Exception:
try: try:
async with async_session_maker() as fresh_session: async with shielded_async_session() as fresh_session:
await clear_ai_responding(fresh_session, chat_id) await clear_ai_responding(fresh_session, chat_id)
except Exception: except Exception:
logging.getLogger(__name__).warning( logging.getLogger(__name__).warning(
@ -1501,9 +1508,7 @@ async def stream_new_chat(
# Break circular refs held by the agent graph, tools, and LLM # Break circular refs held by the agent graph, tools, and LLM
# wrappers so the GC can reclaim them in a single pass. # wrappers so the GC can reclaim them in a single pass.
agent = llm = connector_service = sandbox_backend = None agent = llm = connector_service = sandbox_backend = None
mentioned_documents = mentioned_surfsense_docs = None input_state = stream_result = None
recent_reports = langchain_messages = input_state = None
stream_result = None
session = None session = None
collected = gc.collect(0) + gc.collect(1) + gc.collect(2) collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
@ -1513,6 +1518,7 @@ async def stream_new_chat(
collected, collected,
chat_id, chat_id,
) )
trim_native_heap()
log_system_snapshot("stream_new_chat_END") log_system_snapshot("stream_new_chat_END")
@ -1695,7 +1701,7 @@ async def stream_resume_chat(
await clear_ai_responding(session, chat_id) await clear_ai_responding(session, chat_id)
except Exception: except Exception:
try: try:
async with async_session_maker() as fresh_session: async with shielded_async_session() as fresh_session:
await clear_ai_responding(fresh_session, chat_id) await clear_ai_responding(fresh_session, chat_id)
except Exception: except Exception:
logging.getLogger(__name__).warning( logging.getLogger(__name__).warning(
@ -1721,4 +1727,5 @@ async def stream_resume_chat(
collected, collected,
chat_id, chat_id,
) )
trim_native_heap()
log_system_snapshot("stream_resume_chat_END") log_system_snapshot("stream_resume_chat_END")

View file

@ -149,3 +149,26 @@ def log_system_snapshot(label: str = "system_snapshot") -> None:
snap["rss_delta_mb"], snap["rss_delta_mb"],
snap["rss_mb"], snap["rss_mb"],
) )
def trim_native_heap() -> bool:
    """Request that glibc hand freed heap pages back to the kernel.

    Under glibc on Linux, ``free()`` leaves pages below the brk watermark
    mapped, so RSS does not shrink even after Python's GC has released the
    objects. Calling ``malloc_trim(0)`` forces the allocator to return as
    many free pages as possible to the OS.

    Returns:
        True when ``malloc_trim`` was invoked, False otherwise (non-Linux
        platform, or libc unavailable / lacking the symbol).
    """
    import ctypes
    import sys

    if sys.platform != "linux":
        # malloc_trim is a glibc extension; nothing to do elsewhere.
        return False
    try:
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except (OSError, AttributeError):
        # No glibc (e.g. musl-based systems) or symbol lookup failed.
        return False
    return True