feat: implement connection pooling for AsyncPostgresSaver in checkpointer

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-02-05 17:32:43 -08:00
parent f85adefe5e
commit af16b6656c

View file

@ -3,15 +3,25 @@ PostgreSQL-based checkpointer for LangGraph agents.
This module provides a persistent checkpointer using AsyncPostgresSaver This module provides a persistent checkpointer using AsyncPostgresSaver
that stores conversation state in the PostgreSQL database. that stores conversation state in the PostgreSQL database.
Uses a connection pool (psycopg_pool.AsyncConnectionPool) to handle
connection lifecycle, health checks, and automatic reconnection,
preventing 'the connection is closed' errors in long-running deployments.
""" """
import logging
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool
from app.config import config from app.config import config
logger = logging.getLogger(__name__)
# Global checkpointer instance (initialized lazily) # Global checkpointer instance (initialized lazily)
_checkpointer: AsyncPostgresSaver | None = None _checkpointer: AsyncPostgresSaver | None = None
_checkpointer_context = None # Store the context manager for cleanup _connection_pool: AsyncConnectionPool | None = None
_checkpointer_initialized: bool = False _checkpointer_initialized: bool = False
@ -38,26 +48,65 @@ def get_postgres_connection_string() -> str:
return db_url return db_url
async def _create_checkpointer() -> AsyncPostgresSaver:
    """
    Create a new AsyncPostgresSaver backed by a connection pool.

    The pool is constructed closed, opened explicitly, and only published
    to the module-level ``_connection_pool`` once it has opened
    successfully. If opening fails, the partially started pool is closed
    before the error propagates, so a later retry does not leak it.

    Pool configuration:
    - min_size=2, max_size=10
    - max_lifetime=1800: connections are recycled after 30 minutes
    - max_idle=300: idle connections are closed after 5 minutes

    NOTE(review): no ``check=`` callback is passed to the pool, so
    connections are presumably not validated when handed out; consider
    ``check=AsyncConnectionPool.check_connection`` if the installed
    psycopg_pool version supports it.

    Returns:
        AsyncPostgresSaver: A checkpointer using the newly opened pool.

    Raises:
        Exception: Whatever ``AsyncConnectionPool.open(wait=True)`` raises
            when the pool cannot be opened; the pool is closed first.
    """
    global _connection_pool

    pool = AsyncConnectionPool(
        conninfo=get_postgres_connection_string(),
        min_size=2,
        max_size=10,
        # Connections are recycled after 30 minutes to avoid stale connections
        max_lifetime=1800,
        # Idle connections are closed after 5 minutes
        max_idle=300,
        # Opened explicitly below so a failed open can be cleaned up.
        open=False,
        # Connection kwargs required by AsyncPostgresSaver:
        # - autocommit: required for .setup() to commit checkpoint tables
        # - prepare_threshold: NOTE(review): in psycopg, 0 prepares every
        #   query on its first execution and None disables prepared
        #   statements; the value is kept at 0 - confirm this is intended
        # - row_factory: checkpointer accesses rows as dicts (row["column"])
        kwargs={
            "autocommit": True,
            "prepare_threshold": 0,
            "row_factory": dict_row,
        },
    )

    try:
        await pool.open(wait=True)
    except BaseException:
        # Includes cancellation: do not leave a half-opened pool behind.
        await pool.close()
        raise

    # Publish the pool only after it is known to be usable.
    _connection_pool = pool

    checkpointer = AsyncPostgresSaver(conn=pool)
    logger.info("[Checkpointer] Created AsyncPostgresSaver with connection pool")
    return checkpointer
async def get_checkpointer() -> AsyncPostgresSaver: async def get_checkpointer() -> AsyncPostgresSaver:
""" """
Get or create the global AsyncPostgresSaver instance. Get or create the global AsyncPostgresSaver instance.
This function: This function:
1. Creates the checkpointer if it doesn't exist 1. Creates the checkpointer with a connection pool if it doesn't exist
2. Sets up the required database tables on first call 2. Sets up the required database tables on first call
3. Returns the cached instance on subsequent calls 3. Returns the cached instance on subsequent calls
The underlying connection pool handles reconnection automatically,
so a stale/closed connection will not cause OperationalError.
Returns: Returns:
AsyncPostgresSaver: The configured checkpointer instance AsyncPostgresSaver: The configured checkpointer instance
""" """
global _checkpointer, _checkpointer_context, _checkpointer_initialized global _checkpointer, _checkpointer_initialized
if _checkpointer is None: if _checkpointer is None:
conn_string = get_postgres_connection_string() _checkpointer = await _create_checkpointer()
# from_conn_string returns an async context manager _checkpointer_initialized = False
# We need to enter the context to get the actual checkpointer
_checkpointer_context = AsyncPostgresSaver.from_conn_string(conn_string)
_checkpointer = await _checkpointer_context.__aenter__()
# Setup tables on first call (idempotent) # Setup tables on first call (idempotent)
if not _checkpointer_initialized: if not _checkpointer_initialized:
@ -75,20 +124,21 @@ async def setup_checkpointer_tables() -> None:
tables exist before any agent calls. tables exist before any agent calls.
""" """
await get_checkpointer() await get_checkpointer()
print("[Checkpointer] PostgreSQL checkpoint tables ready") logger.info("[Checkpointer] PostgreSQL checkpoint tables ready")
async def close_checkpointer() -> None:
    """
    Close the checkpointer connection pool.

    This should be called during application shutdown. It is a no-op for
    the pool when none was created. Module state is reset before the pool
    is closed, so the globals never keep pointing at a closed pool even if
    ``close()`` raises.
    """
    global _checkpointer, _connection_pool, _checkpointer_initialized

    pool = _connection_pool

    # Reset state first: an exception from close() must not leave stale
    # references behind.
    _checkpointer = None
    _connection_pool = None
    _checkpointer_initialized = False

    if pool is not None:
        await pool.close()
        logger.info("[Checkpointer] PostgreSQL connection pool closed")