feat: message history and PostgreSQL checkpointer integration

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2025-12-21 03:30:10 -08:00
parent 3906ba52e0
commit 73f0f772a8
11 changed files with 434 additions and 115 deletions

View file

@ -10,6 +10,7 @@ from collections.abc import Sequence
from deepagents import create_deep_agent
from langchain_core.tools import BaseTool
from langchain_litellm import ChatLiteLLM
from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.context import SurfSenseContextSchema
@ -27,6 +28,7 @@ def create_surfsense_deep_agent(
search_space_id: int,
db_session: AsyncSession,
connector_service: ConnectorService,
checkpointer: Checkpointer,
user_instructions: str | None = None,
enable_citations: bool = True,
additional_tools: Sequence[BaseTool] | None = None,
@ -39,6 +41,8 @@ def create_surfsense_deep_agent(
search_space_id: The user's search space ID
db_session: Database session
connector_service: Initialized connector service
checkpointer: LangGraph checkpointer for conversation state persistence.
Use AsyncPostgresSaver for production or MemorySaver for testing.
user_instructions: Optional user instructions to inject into the system prompt.
These will be added to the system prompt to customize agent behavior.
enable_citations: Whether to include citation instructions in the system prompt (default: True).
@ -61,7 +65,7 @@ def create_surfsense_deep_agent(
if additional_tools:
tools.extend(additional_tools)
# Create the deep agent with user-configurable system prompt
# Create the deep agent with user-configurable system prompt and checkpointer
agent = create_deep_agent(
model=llm,
tools=tools,
@ -70,6 +74,7 @@ def create_surfsense_deep_agent(
enable_citations=enable_citations,
),
context_schema=SurfSenseContextSchema,
checkpointer=checkpointer, # Enable conversation memory via thread_id
)
return agent

View file

@ -0,0 +1,95 @@
"""
PostgreSQL-based checkpointer for LangGraph agents.
This module provides a persistent checkpointer using AsyncPostgresSaver
that stores conversation state in the PostgreSQL database.
"""
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.config import config
# Global checkpointer instance (initialized lazily)
_checkpointer: AsyncPostgresSaver | None = None
_checkpointer_context = None # Store the context manager for cleanup
_checkpointer_initialized: bool = False
def get_postgres_connection_string() -> str:
"""
Convert the async DATABASE_URL to a sync postgres connection string for psycopg3.
The DATABASE_URL is typically in format:
postgresql+asyncpg://user:pass@host:port/dbname
We need to convert it to:
postgresql://user:pass@host:port/dbname
"""
db_url = config.DATABASE_URL
# Handle asyncpg driver prefix
if db_url.startswith("postgresql+asyncpg://"):
return db_url.replace("postgresql+asyncpg://", "postgresql://")
# Handle other async prefixes
if "+asyncpg" in db_url:
return db_url.replace("+asyncpg", "")
return db_url
async def get_checkpointer() -> AsyncPostgresSaver:
"""
Get or create the global AsyncPostgresSaver instance.
This function:
1. Creates the checkpointer if it doesn't exist
2. Sets up the required database tables on first call
3. Returns the cached instance on subsequent calls
Returns:
AsyncPostgresSaver: The configured checkpointer instance
"""
global _checkpointer, _checkpointer_context, _checkpointer_initialized
if _checkpointer is None:
conn_string = get_postgres_connection_string()
# from_conn_string returns an async context manager
# We need to enter the context to get the actual checkpointer
_checkpointer_context = AsyncPostgresSaver.from_conn_string(conn_string)
_checkpointer = await _checkpointer_context.__aenter__()
# Setup tables on first call (idempotent)
if not _checkpointer_initialized:
await _checkpointer.setup()
_checkpointer_initialized = True
return _checkpointer
async def setup_checkpointer_tables() -> None:
"""
Explicitly setup the checkpointer tables.
This can be called during application startup to ensure
tables exist before any agent calls.
"""
await get_checkpointer()
print("[Checkpointer] PostgreSQL checkpoint tables ready")
async def close_checkpointer() -> None:
"""
Close the checkpointer connection.
This should be called during application shutdown.
"""
global _checkpointer, _checkpointer_context, _checkpointer_initialized
if _checkpointer_context is not None:
await _checkpointer_context.__aexit__(None, None, None)
_checkpointer = None
_checkpointer_context = None
_checkpointer_initialized = False
print("[Checkpointer] PostgreSQL connection closed")

View file

@ -5,6 +5,7 @@ from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.ext.asyncio import AsyncSession
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
from app.agents.new_chat.checkpointer import close_checkpointer, setup_checkpointer_tables
from app.config import config
from app.db import User, create_db_and_tables, get_async_session
from app.routes import router as crud_router
@ -16,7 +17,11 @@ from app.users import SECRET, auth_backend, current_active_user, fastapi_users
async def lifespan(app: FastAPI):
# Not needed if you setup a migration system like Alembic
await create_db_and_tables()
# Setup LangGraph checkpointer tables for conversation persistence
await setup_checkpointer_tables()
yield
# Cleanup: close checkpointer connection on shutdown
await close_checkpointer()
def registration_allowed():

View file

@ -226,6 +226,7 @@ async def handle_new_chat(
chat_id=request.chat_id,
session=session,
llm_config_id=llm_config_id,
messages=request.messages, # Pass message history from frontend
),
media_type="text/event-stream",
)

View file

@ -48,12 +48,20 @@ class AISDKChatRequest(BaseModel):
data: dict[str, Any] | None = None
class ChatMessage(BaseModel):
"""A single message in the chat history."""
role: str # "user" or "assistant"
content: str
class NewChatRequest(BaseModel):
"""Request schema for the new deep agent chat endpoint."""
chat_id: int
user_query: str
search_space_id: int
messages: list[ChatMessage] | None = None # Optional chat history from frontend
class ChatCreate(ChatBase):

View file

@ -8,13 +8,13 @@ Data Stream Protocol (SSE format).
from collections.abc import AsyncGenerator
from uuid import UUID
from langchain_core.messages import HumanMessage
from langchain_core.messages import AIMessage, HumanMessage
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.chat_deepagent import (
create_surfsense_deep_agent,
)
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.llm_config import create_chat_litellm_from_config, load_llm_config_from_yaml
from app.schemas.chats import ChatMessage
from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService
@ -26,13 +26,14 @@ async def stream_new_chat(
chat_id: int,
session: AsyncSession,
llm_config_id: int = -1,
messages: list[ChatMessage] | None = None,
) -> AsyncGenerator[str, None]:
"""
Stream chat responses from the new SurfSense deep agent.
This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
The chat_id is used as LangGraph's thread_id for memory/checkpointing,
so chat history is automatically managed by LangGraph.
The chat_id is used as LangGraph's thread_id for memory/checkpointing.
Message history can be passed from the frontend for context.
Args:
user_query: The user's query
@ -41,6 +42,7 @@ async def stream_new_chat(
chat_id: The chat ID (used as LangGraph thread_id for memory)
session: The database session
llm_config_id: The LLM configuration ID (default: -1 for first global config)
messages: Optional chat history from frontend (list of ChatMessage)
Yields:
str: SSE formatted response strings
@ -73,18 +75,36 @@ async def stream_new_chat(
# Create connector service
connector_service = ConnectorService(session, search_space_id=search_space_id)
# Create the deep agent
# Get the PostgreSQL checkpointer for persistent conversation memory
checkpointer = await get_checkpointer()
# Create the deep agent with checkpointer
agent = create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
)
# Build input with just the current user query
# Chat history is managed by LangGraph via thread_id
# Build input with message history from frontend
langchain_messages = []
# if messages:
# # Convert frontend messages to LangChain format
# for msg in messages:
# if msg.role == "user":
# langchain_messages.append(HumanMessage(content=msg.content))
# elif msg.role == "assistant":
# langchain_messages.append(AIMessage(content=msg.content))
# else:
# Fallback: just use the current user query
langchain_messages.append(HumanMessage(content=user_query))
input_state = {
"messages": [HumanMessage(content=user_query)],
# Lets not pass this message atm because we are using the checkpointer to manage the conversation history
# We will use this to simulate group chat functionality in the future
"messages": langchain_messages,
"search_space_id": search_space_id,
}