mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
feat: enhance memory management and session handling in database operations
- Introduced a shielded async session context manager to ensure safe session closure during cancellations. - Updated various database operations to utilize the new shielded session, preventing orphaned connections. - Added environment variables to optimize glibc memory management, improving overall application performance. - Implemented a function to trim the native heap, allowing for better memory reclamation on Linux systems.
This commit is contained in:
parent
dd3da2bc36
commit
ecb0a25cc8
7 changed files with 76 additions and 17 deletions
|
|
@ -88,6 +88,13 @@ ENV TMPDIR=/shared_tmp
|
|||
ENV PYTHONPATH=/app
|
||||
ENV UVICORN_LOOP=asyncio
|
||||
|
||||
# Tune glibc malloc to return freed memory to the OS more aggressively.
|
||||
# Without these, Python's gc.collect() frees objects but the underlying
|
||||
# C heap pages stay mapped (RSS never drops) due to sbrk fragmentation.
|
||||
ENV MALLOC_MMAP_THRESHOLD_=65536
|
||||
ENV MALLOC_TRIM_THRESHOLD_=131072
|
||||
ENV MALLOC_MMAP_MAX_=65536
|
||||
|
||||
# SERVICE_ROLE controls which process this container runs:
|
||||
# api – FastAPI backend only (runs migrations on startup)
|
||||
# worker – Celery worker only
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from langchain_core.tools import StructuredTool
|
|||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import async_session_maker
|
||||
from app.db import shielded_async_session
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
|
|
@ -98,7 +98,7 @@ async def _browse_recent_documents(
|
|||
if end_date is not None:
|
||||
base_conditions.append(Document.updated_at <= end_date)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
async with shielded_async_session() as session:
|
||||
doc_query = (
|
||||
select(Document)
|
||||
.options(joinedload(Document.search_space))
|
||||
|
|
@ -739,7 +739,7 @@ async def search_knowledge_base_async(
|
|||
|
||||
try:
|
||||
t_conn = time.perf_counter()
|
||||
async with semaphore, async_session_maker() as isolated_session:
|
||||
async with semaphore, shielded_async_session() as isolated_session:
|
||||
svc = ConnectorService(isolated_session, search_space_id)
|
||||
_, chunks = await getattr(svc, method_name)(**kwargs)
|
||||
perf.info(
|
||||
|
|
@ -756,7 +756,7 @@ async def search_knowledge_base_async(
|
|||
# --- Optimization 3: call _combined_rrf_search directly with shared embedding ---
|
||||
try:
|
||||
t_conn = time.perf_counter()
|
||||
async with semaphore, async_session_maker() as isolated_session:
|
||||
async with semaphore, shielded_async_session() as isolated_session:
|
||||
svc = ConnectorService(isolated_session, search_space_id)
|
||||
chunks = await svc._combined_rrf_search(
|
||||
query_text=query,
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from langchain_core.callbacks import dispatch_custom_event
|
|||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.db import Report, async_session_maker
|
||||
from app.db import Report, shielded_async_session
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.llm_service import get_document_summary_llm
|
||||
|
||||
|
|
@ -717,7 +717,7 @@ def create_generate_report_tool(
|
|||
async def _save_failed_report(error_msg: str) -> int | None:
|
||||
"""Persist a failed report row using a short-lived session."""
|
||||
try:
|
||||
async with async_session_maker() as session:
|
||||
async with shielded_async_session() as session:
|
||||
failed_report = Report(
|
||||
title=topic,
|
||||
content=None,
|
||||
|
|
@ -751,7 +751,7 @@ def create_generate_report_tool(
|
|||
# ── Phase 1: READ (short-lived session) ──────────────────────
|
||||
# Fetch parent report and LLM config, then close the session
|
||||
# so no DB connection is held during the long LLM call.
|
||||
async with async_session_maker() as read_session:
|
||||
async with shielded_async_session() as read_session:
|
||||
if parent_report_id:
|
||||
parent_report = await read_session.get(Report, parent_report_id)
|
||||
if parent_report:
|
||||
|
|
@ -828,7 +828,7 @@ def create_generate_report_tool(
|
|||
|
||||
# Run all queries in parallel, each with its own session
|
||||
async def _run_single_query(q: str) -> str:
|
||||
async with async_session_maker() as kb_session:
|
||||
async with shielded_async_session() as kb_session:
|
||||
kb_connector_svc = ConnectorService(
|
||||
kb_session, search_space_id
|
||||
)
|
||||
|
|
@ -1028,7 +1028,7 @@ def create_generate_report_tool(
|
|||
|
||||
# ── Phase 3: WRITE (short-lived session) ─────────────────────
|
||||
# Save the report to the database, then close the session.
|
||||
async with async_session_maker() as write_session:
|
||||
async with shielded_async_session() as write_session:
|
||||
report = Report(
|
||||
title=topic,
|
||||
content=report_content,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
|
||||
import anyio
|
||||
from fastapi import Depends
|
||||
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
|
@ -1867,6 +1869,26 @@ engine = create_async_engine(
|
|||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
@asynccontextmanager
async def shielded_async_session():
    """Yield an ``AsyncSession`` whose close step survives task cancellation.

    When a client disconnects, Starlette's BaseHTTPMiddleware cancels the
    request task through an anyio cancel scope. With a plain
    ``async with async_session_maker()``, that cancellation can land inside
    ``__aexit__`` while it awaits ``session.close()``, orphaning the
    underlying database connection.

    Here the close is wrapped in ``anyio.CancelScope(shield=True)`` so it
    always runs to completion, even while the enclosing task is being
    cancelled.
    """
    db_session = async_session_maker()
    try:
        yield db_session
    finally:
        # Shield the close so a pending cancellation cannot interrupt it
        # and leak the pooled connection.
        with anyio.CancelScope(shield=True):
            await db_session.close()
|
||||
|
||||
|
||||
async def setup_indexes():
|
||||
async with engine.begin() as conn:
|
||||
# Create indexes
|
||||
|
|
|
|||
|
|
@ -31,8 +31,8 @@ from app.db import (
|
|||
Permission,
|
||||
SearchSpace,
|
||||
User,
|
||||
async_session_maker,
|
||||
get_async_session,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.schemas.new_chat import (
|
||||
NewChatMessageAppend,
|
||||
|
|
@ -1356,7 +1356,7 @@ async def regenerate_response(
|
|||
# Uses a fresh session since stream_new_chat manages its own.
|
||||
if streaming_completed and message_ids_to_delete:
|
||||
try:
|
||||
async with async_session_maker() as cleanup_session:
|
||||
async with shielded_async_session() as cleanup_session:
|
||||
for msg_id in message_ids_to_delete:
|
||||
_res = await cleanup_session.execute(
|
||||
select(NewChatMessage).filter(
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ from app.db import (
|
|||
SearchSourceConnectorType,
|
||||
SurfsenseDocsDocument,
|
||||
async_session_maker,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.services.chat_session_state_service import (
|
||||
|
|
@ -58,7 +59,7 @@ from app.services.chat_session_state_service import (
|
|||
from app.services.connector_service import ConnectorService
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
|
@ -1359,6 +1360,12 @@ async def stream_new_chat(
|
|||
items=initial_items,
|
||||
)
|
||||
|
||||
# These ORM objects (with eagerly-loaded chunks) can be very large.
|
||||
# They're only needed to build context strings already copied into
|
||||
# final_query / langchain_messages — release them before streaming.
|
||||
del mentioned_documents, mentioned_surfsense_docs, recent_reports
|
||||
del langchain_messages, final_query
|
||||
|
||||
_t_stream_start = time.perf_counter()
|
||||
_first_event_logged = False
|
||||
async for sse in _stream_agent_events(
|
||||
|
|
@ -1483,7 +1490,7 @@ async def stream_new_chat(
|
|||
await clear_ai_responding(session, chat_id)
|
||||
except Exception:
|
||||
try:
|
||||
async with async_session_maker() as fresh_session:
|
||||
async with shielded_async_session() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
|
|
@ -1501,9 +1508,7 @@ async def stream_new_chat(
|
|||
# Break circular refs held by the agent graph, tools, and LLM
|
||||
# wrappers so the GC can reclaim them in a single pass.
|
||||
agent = llm = connector_service = sandbox_backend = None
|
||||
mentioned_documents = mentioned_surfsense_docs = None
|
||||
recent_reports = langchain_messages = input_state = None
|
||||
stream_result = None
|
||||
input_state = stream_result = None
|
||||
session = None
|
||||
|
||||
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
|
||||
|
|
@ -1513,6 +1518,7 @@ async def stream_new_chat(
|
|||
collected,
|
||||
chat_id,
|
||||
)
|
||||
trim_native_heap()
|
||||
log_system_snapshot("stream_new_chat_END")
|
||||
|
||||
|
||||
|
|
@ -1695,7 +1701,7 @@ async def stream_resume_chat(
|
|||
await clear_ai_responding(session, chat_id)
|
||||
except Exception:
|
||||
try:
|
||||
async with async_session_maker() as fresh_session:
|
||||
async with shielded_async_session() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
|
|
@ -1721,4 +1727,5 @@ async def stream_resume_chat(
|
|||
collected,
|
||||
chat_id,
|
||||
)
|
||||
trim_native_heap()
|
||||
log_system_snapshot("stream_resume_chat_END")
|
||||
|
|
|
|||
|
|
@ -149,3 +149,26 @@ def log_system_snapshot(label: str = "system_snapshot") -> None:
|
|||
snap["rss_delta_mb"],
|
||||
snap["rss_mb"],
|
||||
)
|
||||
|
||||
|
||||
def trim_native_heap() -> bool:
    """Ask glibc to return free heap pages to the OS via ``malloc_trim(0)``.

    On Linux (glibc), ``free()`` does not release memory back to the OS if
    it sits below the brk watermark. ``malloc_trim`` forces the allocator
    to give back as many freed pages as possible, so RSS actually drops
    after Python-level garbage collection.

    Returns:
        bool: True if ``malloc_trim`` was invoked, False otherwise
        (non-Linux platform, libc could not be loaded, or an allocator
        without ``malloc_trim`` — e.g. musl).
    """
    # Lazy imports: this helper runs rarely (end of a streaming request),
    # so keeping ctypes/sys out of module import time costs nothing.
    import ctypes
    import sys

    if sys.platform != "linux":
        return False

    # Keep the try body minimal: only the operations that can raise.
    try:
        libc = ctypes.CDLL("libc.so.6")
        trim = libc.malloc_trim
    except (OSError, AttributeError):
        # OSError: dlopen failed (e.g. musl names its libc differently).
        # AttributeError: libc loaded but exposes no malloc_trim symbol.
        return False

    # Declare the C prototype — int malloc_trim(size_t pad) — so ctypes
    # does not have to guess argument/return conversions.
    trim.argtypes = (ctypes.c_size_t,)
    trim.restype = ctypes.c_int
    trim(0)
    return True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue