mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-30 21:59:46 +02:00
feat: implement connection pooling for AsyncPostgresSaver in checkpointer
This commit is contained in:
parent
f85adefe5e
commit
af16b6656c
1 changed files with 67 additions and 17 deletions
|
|
@ -3,15 +3,25 @@ PostgreSQL-based checkpointer for LangGraph agents.
|
||||||
|
|
||||||
This module provides a persistent checkpointer using AsyncPostgresSaver
|
This module provides a persistent checkpointer using AsyncPostgresSaver
|
||||||
that stores conversation state in the PostgreSQL database.
|
that stores conversation state in the PostgreSQL database.
|
||||||
|
|
||||||
|
Uses a connection pool (psycopg_pool.AsyncConnectionPool) to handle
|
||||||
|
connection lifecycle, health checks, and automatic reconnection,
|
||||||
|
preventing 'the connection is closed' errors in long-running deployments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
|
from psycopg.rows import dict_row
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Global checkpointer instance (initialized lazily)
|
# Global checkpointer instance (initialized lazily)
|
||||||
_checkpointer: AsyncPostgresSaver | None = None
|
_checkpointer: AsyncPostgresSaver | None = None
|
||||||
_checkpointer_context = None # Store the context manager for cleanup
|
_connection_pool: AsyncConnectionPool | None = None
|
||||||
_checkpointer_initialized: bool = False
|
_checkpointer_initialized: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -38,26 +48,65 @@ def get_postgres_connection_string() -> str:
|
||||||
return db_url
|
return db_url
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_checkpointer() -> AsyncPostgresSaver:
    """
    Create a new AsyncPostgresSaver backed by a connection pool.

    The pool is built and opened in a local variable and is only published
    to the module-level ``_connection_pool`` once ``open(wait=True)`` has
    returned. If opening fails (or the task is cancelled while waiting), the
    pool is closed and the error is re-raised, so the module global never
    references a pool that could not be opened.

    The pool is configured with ``max_lifetime`` and ``max_idle`` so that
    connections are recycled instead of being held indefinitely.

    Returns:
        AsyncPostgresSaver: A checkpointer that draws its connections from
        the newly opened pool.

    Raises:
        Exception: Whatever ``AsyncConnectionPool.open(wait=True)`` raises
        when the pool cannot be opened; it is propagated unchanged.
    """
    global _connection_pool

    conn_string = get_postgres_connection_string()

    pool = AsyncConnectionPool(
        conninfo=conn_string,
        min_size=2,
        max_size=10,
        # Connections are recycled after 30 minutes to avoid stale connections
        max_lifetime=1800,
        # Idle connections are closed after 5 minutes
        max_idle=300,
        # Do not open in the constructor; the pool is opened explicitly below.
        open=False,
        # Connection kwargs required by AsyncPostgresSaver:
        # - autocommit: required for .setup() to commit checkpoint tables
        # - prepare_threshold: 0 makes psycopg prepare each query the first
        #   time it is executed (None, not 0, is what disables prepared
        #   statements). NOTE(review): if this runs behind a transaction-mode
        #   pooler such as PgBouncer, None may be the intended value - confirm.
        # - row_factory: checkpointer accesses rows as dicts (row["column"])
        kwargs={
            "autocommit": True,
            "prepare_threshold": 0,
            "row_factory": dict_row,
        },
    )

    try:
        await pool.open(wait=True)
    except BaseException:
        # BaseException so that task cancellation during startup also
        # releases the pool before the error propagates.
        await pool.close()
        raise

    # Publish only a pool that opened successfully.
    _connection_pool = pool

    checkpointer = AsyncPostgresSaver(conn=pool)
    logger.info("[Checkpointer] Created AsyncPostgresSaver with connection pool")
    return checkpointer
|
||||||
|
|
||||||
|
|
||||||
async def get_checkpointer() -> AsyncPostgresSaver:
|
async def get_checkpointer() -> AsyncPostgresSaver:
|
||||||
"""
|
"""
|
||||||
Get or create the global AsyncPostgresSaver instance.
|
Get or create the global AsyncPostgresSaver instance.
|
||||||
|
|
||||||
This function:
|
This function:
|
||||||
1. Creates the checkpointer if it doesn't exist
|
1. Creates the checkpointer with a connection pool if it doesn't exist
|
||||||
2. Sets up the required database tables on first call
|
2. Sets up the required database tables on first call
|
||||||
3. Returns the cached instance on subsequent calls
|
3. Returns the cached instance on subsequent calls
|
||||||
|
|
||||||
|
The underlying connection pool handles reconnection automatically,
|
||||||
|
so a stale/closed connection will not cause OperationalError.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AsyncPostgresSaver: The configured checkpointer instance
|
AsyncPostgresSaver: The configured checkpointer instance
|
||||||
"""
|
"""
|
||||||
global _checkpointer, _checkpointer_context, _checkpointer_initialized
|
global _checkpointer, _checkpointer_initialized
|
||||||
|
|
||||||
if _checkpointer is None:
|
if _checkpointer is None:
|
||||||
conn_string = get_postgres_connection_string()
|
_checkpointer = await _create_checkpointer()
|
||||||
# from_conn_string returns an async context manager
|
_checkpointer_initialized = False
|
||||||
# We need to enter the context to get the actual checkpointer
|
|
||||||
_checkpointer_context = AsyncPostgresSaver.from_conn_string(conn_string)
|
|
||||||
_checkpointer = await _checkpointer_context.__aenter__()
|
|
||||||
|
|
||||||
# Setup tables on first call (idempotent)
|
# Setup tables on first call (idempotent)
|
||||||
if not _checkpointer_initialized:
|
if not _checkpointer_initialized:
|
||||||
|
|
@ -75,20 +124,21 @@ async def setup_checkpointer_tables() -> None:
|
||||||
tables exist before any agent calls.
|
tables exist before any agent calls.
|
||||||
"""
|
"""
|
||||||
await get_checkpointer()
|
await get_checkpointer()
|
||||||
print("[Checkpointer] PostgreSQL checkpoint tables ready")
|
logger.info("[Checkpointer] PostgreSQL checkpoint tables ready")
|
||||||
|
|
||||||
|
|
||||||
async def close_checkpointer() -> None:
    """
    Close the checkpointer connection pool.

    This should be called during application shutdown.

    The cached checkpointer, the pool reference and the "tables initialized"
    flag are reset in a ``finally`` clause, so they are cleared even when
    closing the pool raises or is cancelled. Without that, a failed close
    would leave ``_checkpointer`` pointing at a pool in an unknown state and
    later ``get_checkpointer()`` calls would keep returning it.

    Raises:
        Exception: Any error raised by ``_connection_pool.close()`` is
        propagated after the module state has been reset.
    """
    global _checkpointer, _connection_pool, _checkpointer_initialized

    try:
        if _connection_pool is not None:
            await _connection_pool.close()
            logger.info("[Checkpointer] PostgreSQL connection pool closed")
    finally:
        # Always drop the cached state so the next get_checkpointer() call
        # builds a fresh pool instead of reusing a closed one.
        _checkpointer = None
        _connection_pool = None
        _checkpointer_initialized = False
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue