diff --git a/surfsense_backend/app/agents/new_chat/checkpointer.py b/surfsense_backend/app/agents/new_chat/checkpointer.py index 637b2926f..04ecfbdea 100644 --- a/surfsense_backend/app/agents/new_chat/checkpointer.py +++ b/surfsense_backend/app/agents/new_chat/checkpointer.py @@ -3,15 +3,25 @@ PostgreSQL-based checkpointer for LangGraph agents. This module provides a persistent checkpointer using AsyncPostgresSaver that stores conversation state in the PostgreSQL database. + +Uses a connection pool (psycopg_pool.AsyncConnectionPool) to handle +connection lifecycle, health checks, and automatic reconnection, +preventing 'the connection is closed' errors in long-running deployments. """ +import logging + from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver +from psycopg.rows import dict_row +from psycopg_pool import AsyncConnectionPool from app.config import config +logger = logging.getLogger(__name__) + # Global checkpointer instance (initialized lazily) _checkpointer: AsyncPostgresSaver | None = None -_checkpointer_context = None # Store the context manager for cleanup +_connection_pool: AsyncConnectionPool | None = None _checkpointer_initialized: bool = False @@ -38,26 +48,65 @@ def get_postgres_connection_string() -> str: return db_url +async def _create_checkpointer() -> AsyncPostgresSaver: + """ + Create a new AsyncPostgresSaver backed by a connection pool. + + The connection pool automatically handles: + - Connection health checks before use + - Reconnection when connections die (idle timeout, DB restart, etc.) + - Connection lifecycle management (max_lifetime, max_idle) + """ + global _connection_pool + + conn_string = get_postgres_connection_string() + + _connection_pool = AsyncConnectionPool( + conninfo=conn_string, + min_size=2, + max_size=10, + # Connections are recycled after 30 minutes to avoid stale connections + max_lifetime=1800, + # Idle connections are closed after 5 minutes + max_idle=300, + open=False, + # Connection kwargs required by AsyncPostgresSaver: + # - autocommit: required for .setup() to commit checkpoint tables + # - prepare_threshold: disable prepared statements for compatibility + # - row_factory: checkpointer accesses rows as dicts (row["column"]) + kwargs={ + "autocommit": True, + "prepare_threshold": 0, + "row_factory": dict_row, + }, + ) + await _connection_pool.open(wait=True) + + checkpointer = AsyncPostgresSaver(conn=_connection_pool) + logger.info("[Checkpointer] Created AsyncPostgresSaver with connection pool") + return checkpointer + + async def get_checkpointer() -> AsyncPostgresSaver: """ Get or create the global AsyncPostgresSaver instance. This function: - 1. Creates the checkpointer if it doesn't exist + 1. Creates the checkpointer with a connection pool if it doesn't exist 2. Sets up the required database tables on first call 3. Returns the cached instance on subsequent calls + The underlying connection pool handles reconnection automatically, + so a stale/closed connection will not cause OperationalError. + Returns: AsyncPostgresSaver: The configured checkpointer instance """ - global _checkpointer, _checkpointer_context, _checkpointer_initialized + global _checkpointer, _checkpointer_initialized if _checkpointer is None: - conn_string = get_postgres_connection_string() - # from_conn_string returns an async context manager - # We need to enter the context to get the actual checkpointer - _checkpointer_context = AsyncPostgresSaver.from_conn_string(conn_string) - _checkpointer = await _checkpointer_context.__aenter__() + _checkpointer = await _create_checkpointer() + _checkpointer_initialized = False # Setup tables on first call (idempotent) if not _checkpointer_initialized: @@ -75,20 +124,21 @@ async def setup_checkpointer_tables() -> None: tables exist before any agent calls. """ await get_checkpointer() - print("[Checkpointer] PostgreSQL checkpoint tables ready") + logger.info("[Checkpointer] PostgreSQL checkpoint tables ready") async def close_checkpointer() -> None: """ - Close the checkpointer connection. + Close the checkpointer connection pool. This should be called during application shutdown. """ - global _checkpointer, _checkpointer_context, _checkpointer_initialized + global _checkpointer, _connection_pool, _checkpointer_initialized - if _checkpointer_context is not None: - await _checkpointer_context.__aexit__(None, None, None) - _checkpointer = None - _checkpointer_context = None - _checkpointer_initialized = False - print("[Checkpointer] PostgreSQL connection closed") + if _connection_pool is not None: + await _connection_pool.close() + logger.info("[Checkpointer] PostgreSQL connection pool closed") + + _checkpointer = None + _connection_pool = None + _checkpointer_initialized = False