mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-25 19:15:18 +02:00
feat: message history and PostgreSQL checkpointer integration
This commit is contained in:
parent
3906ba52e0
commit
73f0f772a8
11 changed files with 434 additions and 115 deletions
|
|
@ -10,6 +10,7 @@ from collections.abc import Sequence
|
|||
from deepagents import create_deep_agent
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
from langgraph.types import Checkpointer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
|
|
@ -27,6 +28,7 @@ def create_surfsense_deep_agent(
|
|||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
connector_service: ConnectorService,
|
||||
checkpointer: Checkpointer,
|
||||
user_instructions: str | None = None,
|
||||
enable_citations: bool = True,
|
||||
additional_tools: Sequence[BaseTool] | None = None,
|
||||
|
|
@ -39,6 +41,8 @@ def create_surfsense_deep_agent(
|
|||
search_space_id: The user's search space ID
|
||||
db_session: Database session
|
||||
connector_service: Initialized connector service
|
||||
checkpointer: LangGraph checkpointer for conversation state persistence.
|
||||
Use AsyncPostgresSaver for production or MemorySaver for testing.
|
||||
user_instructions: Optional user instructions to inject into the system prompt.
|
||||
These will be added to the system prompt to customize agent behavior.
|
||||
enable_citations: Whether to include citation instructions in the system prompt (default: True).
|
||||
|
|
@ -61,7 +65,7 @@ def create_surfsense_deep_agent(
|
|||
if additional_tools:
|
||||
tools.extend(additional_tools)
|
||||
|
||||
# Create the deep agent with user-configurable system prompt
|
||||
# Create the deep agent with user-configurable system prompt and checkpointer
|
||||
agent = create_deep_agent(
|
||||
model=llm,
|
||||
tools=tools,
|
||||
|
|
@ -70,6 +74,7 @@ def create_surfsense_deep_agent(
|
|||
enable_citations=enable_citations,
|
||||
),
|
||||
context_schema=SurfSenseContextSchema,
|
||||
checkpointer=checkpointer, # Enable conversation memory via thread_id
|
||||
)
|
||||
|
||||
return agent
|
||||
|
|
|
|||
95
surfsense_backend/app/agents/new_chat/checkpointer.py
Normal file
95
surfsense_backend/app/agents/new_chat/checkpointer.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
"""
|
||||
PostgreSQL-based checkpointer for LangGraph agents.
|
||||
|
||||
This module provides a persistent checkpointer using AsyncPostgresSaver
|
||||
that stores conversation state in the PostgreSQL database.
|
||||
"""
|
||||
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
|
||||
from app.config import config
|
||||
|
||||
# Module-level checkpointer state, populated lazily by get_checkpointer()
# and cleared again by close_checkpointer().

# The saver handed out to callers; None until the first get_checkpointer() call.
_checkpointer: AsyncPostgresSaver | None = None

# The async context manager that produced the saver, kept so it can be exited on cleanup.
_checkpointer_context = None

# Whether setup() has already been awaited on the current saver.
_checkpointer_initialized: bool = False
|
||||
|
||||
|
||||
def get_postgres_connection_string() -> str:
|
||||
"""
|
||||
Convert the async DATABASE_URL to a sync postgres connection string for psycopg3.
|
||||
|
||||
The DATABASE_URL is typically in format:
|
||||
postgresql+asyncpg://user:pass@host:port/dbname
|
||||
|
||||
We need to convert it to:
|
||||
postgresql://user:pass@host:port/dbname
|
||||
"""
|
||||
db_url = config.DATABASE_URL
|
||||
|
||||
# Handle asyncpg driver prefix
|
||||
if db_url.startswith("postgresql+asyncpg://"):
|
||||
return db_url.replace("postgresql+asyncpg://", "postgresql://")
|
||||
|
||||
# Handle other async prefixes
|
||||
if "+asyncpg" in db_url:
|
||||
return db_url.replace("+asyncpg", "")
|
||||
|
||||
return db_url
|
||||
|
||||
|
||||
async def get_checkpointer() -> AsyncPostgresSaver:
    """
    Get or create the global AsyncPostgresSaver instance.

    On the first call this opens the saver from the connection string
    returned by get_postgres_connection_string() and awaits its setup()
    to create the checkpoint tables. Later calls return the cached instance.

    The module-level state is only updated after the context manager has
    been entered successfully, so a failed connection attempt leaves nothing
    behind for close_checkpointer() to exit and the next call simply retries.

    Returns:
        AsyncPostgresSaver: The configured checkpointer instance.
    """
    global _checkpointer, _checkpointer_context, _checkpointer_initialized

    if _checkpointer is None:
        # from_conn_string returns an async context manager; entering it
        # yields the actual checkpointer.
        context = AsyncPostgresSaver.from_conn_string(get_postgres_connection_string())
        saver = await context.__aenter__()

        # Publish both together, and only once __aenter__ has succeeded.
        _checkpointer_context = context
        _checkpointer = saver
        # NOTE(review): concurrent first calls are not serialized here;
        # the app is expected to warm this up during startup.

    # Setup tables on first call (idempotent). If setup() raises, the flag
    # stays False so the next call tries again.
    if not _checkpointer_initialized:
        await _checkpointer.setup()
        _checkpointer_initialized = True

    return _checkpointer
|
||||
|
||||
|
||||
async def setup_checkpointer_tables() -> None:
    """
    Ensure the checkpointer tables exist.

    Intended to be awaited during application startup so the tables are
    in place before any agent call. The work itself is done by
    get_checkpointer(); its return value is not needed here.
    """
    _ = await get_checkpointer()
    print("[Checkpointer] PostgreSQL checkpoint tables ready")
|
||||
|
||||
|
||||
async def close_checkpointer() -> None:
    """
    Close the checkpointer connection.

    This should be called during application shutdown. It exits the context
    manager stored by get_checkpointer() and clears the module-level state.
    The state is cleared even when exiting the context raises, so a failed
    shutdown never leaves a stale checkpointer cached; the exception itself
    still propagates to the caller.

    Does nothing when no checkpointer context is currently stored.
    """
    global _checkpointer, _checkpointer_context, _checkpointer_initialized

    context = _checkpointer_context
    if context is None:
        return

    try:
        await context.__aexit__(None, None, None)
    finally:
        # Always drop the cached state so a later get_checkpointer() call
        # starts from a clean slate.
        _checkpointer = None
        _checkpointer_context = None
        _checkpointer_initialized = False

    print("[Checkpointer] PostgreSQL connection closed")
|
||||
|
||||
|
|
@ -5,6 +5,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
|
||||
|
||||
from app.agents.new_chat.checkpointer import close_checkpointer, setup_checkpointer_tables
|
||||
from app.config import config
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.routes import router as crud_router
|
||||
|
|
@ -16,7 +17,11 @@ from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
|||
async def lifespan(app: FastAPI):
|
||||
# Not needed if you setup a migration system like Alembic
|
||||
await create_db_and_tables()
|
||||
# Setup LangGraph checkpointer tables for conversation persistence
|
||||
await setup_checkpointer_tables()
|
||||
yield
|
||||
# Cleanup: close checkpointer connection on shutdown
|
||||
await close_checkpointer()
|
||||
|
||||
|
||||
def registration_allowed():
|
||||
|
|
|
|||
|
|
@ -226,6 +226,7 @@ async def handle_new_chat(
|
|||
chat_id=request.chat_id,
|
||||
session=session,
|
||||
llm_config_id=llm_config_id,
|
||||
messages=request.messages, # Pass message history from frontend
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -48,12 +48,20 @@ class AISDKChatRequest(BaseModel):
|
|||
data: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    """One entry of the chat history."""

    # "user" or "assistant"
    role: str
    content: str
|
||||
|
||||
|
||||
class NewChatRequest(BaseModel):
    """Payload accepted by the new deep agent chat endpoint."""

    chat_id: int
    user_query: str
    search_space_id: int
    # Optional chat history supplied by the frontend
    messages: list[ChatMessage] | None = None
|
||||
|
||||
|
||||
class ChatCreate(ChatBase):
|
||||
|
|
|
|||
|
|
@ -8,13 +8,13 @@ Data Stream Protocol (SSE format).
|
|||
from collections.abc import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.chat_deepagent import (
|
||||
create_surfsense_deep_agent,
|
||||
)
|
||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||
from app.agents.new_chat.llm_config import create_chat_litellm_from_config, load_llm_config_from_yaml
|
||||
from app.schemas.chats import ChatMessage
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
|
|
@ -26,13 +26,14 @@ async def stream_new_chat(
|
|||
chat_id: int,
|
||||
session: AsyncSession,
|
||||
llm_config_id: int = -1,
|
||||
messages: list[ChatMessage] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream chat responses from the new SurfSense deep agent.
|
||||
|
||||
This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
|
||||
The chat_id is used as LangGraph's thread_id for memory/checkpointing,
|
||||
so chat history is automatically managed by LangGraph.
|
||||
The chat_id is used as LangGraph's thread_id for memory/checkpointing.
|
||||
Message history can be passed from the frontend for context.
|
||||
|
||||
Args:
|
||||
user_query: The user's query
|
||||
|
|
@ -41,6 +42,7 @@ async def stream_new_chat(
|
|||
chat_id: The chat ID (used as LangGraph thread_id for memory)
|
||||
session: The database session
|
||||
llm_config_id: The LLM configuration ID (default: -1 for first global config)
|
||||
messages: Optional chat history from frontend (list of ChatMessage)
|
||||
|
||||
Yields:
|
||||
str: SSE formatted response strings
|
||||
|
|
@ -73,18 +75,36 @@ async def stream_new_chat(
|
|||
# Create connector service
|
||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||
|
||||
# Create the deep agent
|
||||
# Get the PostgreSQL checkpointer for persistent conversation memory
|
||||
checkpointer = await get_checkpointer()
|
||||
|
||||
# Create the deep agent with checkpointer
|
||||
agent = create_surfsense_deep_agent(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
connector_service=connector_service,
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
|
||||
# Build input with just the current user query
|
||||
# Chat history is managed by LangGraph via thread_id
|
||||
# Build input with message history from frontend
|
||||
langchain_messages = []
|
||||
|
||||
# if messages:
|
||||
# # Convert frontend messages to LangChain format
|
||||
# for msg in messages:
|
||||
# if msg.role == "user":
|
||||
# langchain_messages.append(HumanMessage(content=msg.content))
|
||||
# elif msg.role == "assistant":
|
||||
# langchain_messages.append(AIMessage(content=msg.content))
|
||||
# else:
|
||||
# Fallback: just use the current user query
|
||||
langchain_messages.append(HumanMessage(content=user_query))
|
||||
|
||||
input_state = {
|
||||
"messages": [HumanMessage(content=user_query)],
|
||||
# Lets not pass this message atm because we are using the checkpointer to manage the conversation history
|
||||
# We will use this to simulate group chat functionality in the future
|
||||
"messages": langchain_messages,
|
||||
"search_space_id": search_space_id,
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue