feat: implement connection pooling for AsyncPostgresSaver in checkpointer

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-02-05 17:32:43 -08:00
parent f85adefe5e
commit af16b6656c

View file

@ -3,15 +3,25 @@ PostgreSQL-based checkpointer for LangGraph agents.
This module provides a persistent checkpointer using AsyncPostgresSaver
that stores conversation state in the PostgreSQL database.
Uses a connection pool (psycopg_pool.AsyncConnectionPool) to handle
connection lifecycle, health checks, and automatic reconnection,
preventing 'the connection is closed' errors in long-running deployments.
"""
import logging
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool
from app.config import config
logger = logging.getLogger(__name__)
# Global checkpointer instance (initialized lazily)
_checkpointer: AsyncPostgresSaver | None = None
_checkpointer_context = None # Store the context manager for cleanup
_connection_pool: AsyncConnectionPool | None = None
_checkpointer_initialized: bool = False
@ -38,26 +48,65 @@ def get_postgres_connection_string() -> str:
return db_url
async def _create_checkpointer() -> AsyncPostgresSaver:
"""
Create a new AsyncPostgresSaver backed by a connection pool.
The connection pool automatically handles:
- Connection health checks before use
- Reconnection when connections die (idle timeout, DB restart, etc.)
- Connection lifecycle management (max_lifetime, max_idle)
"""
global _connection_pool
conn_string = get_postgres_connection_string()
_connection_pool = AsyncConnectionPool(
conninfo=conn_string,
min_size=2,
max_size=10,
# Connections are recycled after 30 minutes to avoid stale connections
max_lifetime=1800,
# Idle connections are closed after 5 minutes
max_idle=300,
open=False,
# Connection kwargs required by AsyncPostgresSaver:
# - autocommit: required for .setup() to commit checkpoint tables
# - prepare_threshold: disable prepared statements for compatibility
# - row_factory: checkpointer accesses rows as dicts (row["column"])
kwargs={
"autocommit": True,
"prepare_threshold": 0,
"row_factory": dict_row,
},
)
await _connection_pool.open(wait=True)
checkpointer = AsyncPostgresSaver(conn=_connection_pool)
logger.info("[Checkpointer] Created AsyncPostgresSaver with connection pool")
return checkpointer
async def get_checkpointer() -> AsyncPostgresSaver:
"""
Get or create the global AsyncPostgresSaver instance.
This function:
1. Creates the checkpointer if it doesn't exist
1. Creates the checkpointer with a connection pool if it doesn't exist
2. Sets up the required database tables on first call
3. Returns the cached instance on subsequent calls
The underlying connection pool handles reconnection automatically,
so a stale/closed connection will not cause OperationalError.
Returns:
AsyncPostgresSaver: The configured checkpointer instance
"""
global _checkpointer, _checkpointer_context, _checkpointer_initialized
global _checkpointer, _checkpointer_initialized
if _checkpointer is None:
conn_string = get_postgres_connection_string()
# from_conn_string returns an async context manager
# We need to enter the context to get the actual checkpointer
_checkpointer_context = AsyncPostgresSaver.from_conn_string(conn_string)
_checkpointer = await _checkpointer_context.__aenter__()
_checkpointer = await _create_checkpointer()
_checkpointer_initialized = False
# Setup tables on first call (idempotent)
if not _checkpointer_initialized:
@ -75,20 +124,21 @@ async def setup_checkpointer_tables() -> None:
tables exist before any agent calls.
"""
await get_checkpointer()
print("[Checkpointer] PostgreSQL checkpoint tables ready")
logger.info("[Checkpointer] PostgreSQL checkpoint tables ready")
async def close_checkpointer() -> None:
"""
Close the checkpointer connection.
Close the checkpointer connection pool.
This should be called during application shutdown.
"""
global _checkpointer, _checkpointer_context, _checkpointer_initialized
global _checkpointer, _connection_pool, _checkpointer_initialized
if _checkpointer_context is not None:
await _checkpointer_context.__aexit__(None, None, None)
_checkpointer = None
_checkpointer_context = None
_checkpointer_initialized = False
print("[Checkpointer] PostgreSQL connection closed")
if _connection_pool is not None:
await _connection_pool.close()
logger.info("[Checkpointer] PostgreSQL connection pool closed")
_checkpointer = None
_connection_pool = None
_checkpointer_initialized = False