refactor(agents): move checkpointer + mention_resolver to app/agents/shared (slice 5b)

Two independent leaf modules (no intra-new_chat deps, no frozen importer),
consumed only by flows/routes/tests. Flipped 8 importers across both the
dotted-path and module-style (from app.agents.new_chat import mention_resolver)
forms. No shims needed.
This commit is contained in:
CREDO23 2026-06-04 12:52:54 +02:00
parent dcdf8f776b
commit 6f488d9564
10 changed files with 11 additions and 11 deletions

View file

@ -0,0 +1,144 @@
"""
PostgreSQL-based checkpointer for LangGraph agents.
This module provides a persistent checkpointer using AsyncPostgresSaver
that stores conversation state in the PostgreSQL database.
Uses a connection pool (psycopg_pool.AsyncConnectionPool) to handle
connection lifecycle, health checks, and automatic reconnection,
preventing 'the connection is closed' errors in long-running deployments.
"""
import logging
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool
from app.config import config
logger = logging.getLogger(__name__)
# Global checkpointer instance (initialized lazily)
_checkpointer: AsyncPostgresSaver | None = None
_connection_pool: AsyncConnectionPool | None = None
_checkpointer_initialized: bool = False
def get_postgres_connection_string() -> str:
"""
Convert the async DATABASE_URL to a sync postgres connection string for psycopg3.
The DATABASE_URL is typically in format:
postgresql+asyncpg://user:pass@host:port/dbname
We need to convert it to:
postgresql://user:pass@host:port/dbname
"""
db_url = config.DATABASE_URL
# Handle asyncpg driver prefix
if db_url.startswith("postgresql+asyncpg://"):
return db_url.replace("postgresql+asyncpg://", "postgresql://")
# Handle other async prefixes
if "+asyncpg" in db_url:
return db_url.replace("+asyncpg", "")
return db_url
async def _create_checkpointer() -> AsyncPostgresSaver:
"""
Create a new AsyncPostgresSaver backed by a connection pool.
The connection pool automatically handles:
- Connection health checks before use
- Reconnection when connections die (idle timeout, DB restart, etc.)
- Connection lifecycle management (max_lifetime, max_idle)
"""
global _connection_pool
conn_string = get_postgres_connection_string()
_connection_pool = AsyncConnectionPool(
conninfo=conn_string,
min_size=2,
max_size=10,
# Connections are recycled after 30 minutes to avoid stale connections
max_lifetime=1800,
# Idle connections are closed after 5 minutes
max_idle=300,
open=False,
# Connection kwargs required by AsyncPostgresSaver:
# - autocommit: required for .setup() to commit checkpoint tables
# - prepare_threshold: disable prepared statements for compatibility
# - row_factory: checkpointer accesses rows as dicts (row["column"])
kwargs={
"autocommit": True,
"prepare_threshold": 0,
"row_factory": dict_row,
},
)
await _connection_pool.open(wait=True)
checkpointer = AsyncPostgresSaver(conn=_connection_pool)
logger.info("[Checkpointer] Created AsyncPostgresSaver with connection pool")
return checkpointer
async def get_checkpointer() -> AsyncPostgresSaver:
"""
Get or create the global AsyncPostgresSaver instance.
This function:
1. Creates the checkpointer with a connection pool if it doesn't exist
2. Sets up the required database tables on first call
3. Returns the cached instance on subsequent calls
The underlying connection pool handles reconnection automatically,
so a stale/closed connection will not cause OperationalError.
Returns:
AsyncPostgresSaver: The configured checkpointer instance
"""
global _checkpointer, _checkpointer_initialized
if _checkpointer is None:
_checkpointer = await _create_checkpointer()
_checkpointer_initialized = False
# Setup tables on first call (idempotent)
if not _checkpointer_initialized:
await _checkpointer.setup()
_checkpointer_initialized = True
return _checkpointer
async def setup_checkpointer_tables() -> None:
"""
Explicitly setup the checkpointer tables.
This can be called during application startup to ensure
tables exist before any agent calls.
"""
await get_checkpointer()
logger.info("[Checkpointer] PostgreSQL checkpoint tables ready")
async def close_checkpointer() -> None:
"""
Close the checkpointer connection pool.
This should be called during application shutdown.
"""
global _checkpointer, _connection_pool, _checkpointer_initialized
if _connection_pool is not None:
await _connection_pool.close()
logger.info("[Checkpointer] PostgreSQL connection pool closed")
_checkpointer = None
_connection_pool = None
_checkpointer_initialized = False

View file

@ -0,0 +1,277 @@
"""Resolve @-mention chips to canonical virtual paths and substitute the
user-visible ``@title`` tokens with backtick-wrapped paths in the prompt
the agent sees.
The frontend's mention seam is a single discriminated-union list of
``{kind: "doc" | "folder", id, title, document_type?}`` chips (see
``surfsense_web/atoms/chat/mentioned-documents.atom.ts``). When a turn
reaches the backend stream task we have three needs that this module
centralises:
1. Map each chip to its canonical virtual path
(``/documents/.../file.xml`` for docs, ``/documents/MyFolder/`` for
folders) so the agent sees concrete filesystem locations instead of
ambiguous ``@``-titles.
2. Substitute ``@title`` tokens in the user-typed text with backtick-
wrapped paths so the path becomes part of the ``HumanMessage`` body
the LLM consumes without rewriting the persisted user message
text (which keeps ``@title`` so chip rendering on reload is
unchanged).
3. Surface the resolved id sets (docs + folders) to the priority
middleware so it can render ``[USER-MENTIONED]`` priority entries
without re-doing path resolution.
This is intentionally one module see the architectural note in
``mention-paths-and-folders`` plan: previously the doc-resolution lived
inline in ``stream_new_chat`` and the folder mention had no resolution
at all. Centralising both behind a single ``resolve_mentions`` call
turns a leaky multi-field seam into a single deeper interface.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.path_resolver import (
DOCUMENTS_ROOT,
build_path_index,
doc_to_virtual_path,
)
from app.db import Document, Folder
from app.schemas.new_chat import MentionedDocumentInfo
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class ResolvedMention:
"""Canonical view of a single @-mention chip.
``virtual_path`` is the path the agent will see (no trailing slash
for documents, trailing ``/`` for folders to match the convention
used by ``KnowledgeTreeMiddleware``).
"""
kind: str # "doc" | "folder"
id: int
title: str
virtual_path: str
@dataclass
class ResolvedMentionSet:
"""Aggregate result of resolving a turn's mention chips.
``token_to_path`` maps ``@title`` (the literal token the user typed
and the editor emitted) to the canonical virtual path for that
chip. It is produced longest-token-first so substitution mirrors
``parseMentionSegments`` on the frontend (a longer title like
``@Project Roadmap`` is never shadowed by a shorter prefix
``@Project``).
``mentioned_document_ids`` is an ordered, deduped list consumed by
the priority middleware downstream see
``KnowledgePriorityMiddleware._compute_priority_paths``.
"""
mentions: list[ResolvedMention] = field(default_factory=list)
token_to_path: list[tuple[str, str]] = field(default_factory=list)
mentioned_document_ids: list[int] = field(default_factory=list)
mentioned_folder_ids: list[int] = field(default_factory=list)
def _folder_virtual_path(folder_id: int, folder_paths: dict[int, str]) -> str:
"""Return ``/documents/Folder/Sub/`` for a folder id.
Falls back to the documents root when the folder is missing from
the index (deleted or in a different search space). Trailing slash
matches ``KnowledgeTreeMiddleware`` (``/documents/MyFolder/``) so
the agent's ``ls`` can dispatch on it as a directory.
"""
base = folder_paths.get(folder_id, DOCUMENTS_ROOT)
return f"{base}/" if not base.endswith("/") else base
async def resolve_mentions(
session: AsyncSession,
*,
search_space_id: int,
mentioned_documents: list[MentionedDocumentInfo] | None,
mentioned_document_ids: list[int] | None = None,
mentioned_folder_ids: list[int] | None = None,
) -> ResolvedMentionSet:
"""Resolve every @-mention chip on a turn into virtual paths.
The function takes both the ``mentioned_documents`` discriminated
list (chip metadata used for substitution + persistence) and the
parallel id arrays (``mentioned_document_ids``,
``mentioned_folder_ids``) for two reasons:
* Legacy clients that haven't migrated to the unified chip list
still send the id arrays we treat the union as authoritative.
* The id arrays are the canonical input to
``KnowledgePriorityMiddleware`` (via ``SurfSenseContextSchema``);
returning the deduped, validated lists lets the route forward
them unchanged.
Resolution is best-effort: a chip whose id no longer exists (e.g.
document was deleted between mention and submit) is silently
dropped. The agent still sees the user's original text, just
without a backtick-path substitution for that chip.
"""
chip_doc_ids: list[int] = []
chip_folder_ids: list[int] = []
chip_titles_by_id: dict[tuple[str, int], str] = {}
if mentioned_documents:
for chip in mentioned_documents:
kind = chip.kind
if kind == "folder":
chip_folder_ids.append(chip.id)
elif kind == "doc":
chip_doc_ids.append(chip.id)
chip_titles_by_id[(kind, chip.id)] = chip.title
doc_id_pool: list[int] = list(
dict.fromkeys(
[
*(mentioned_document_ids or []),
*chip_doc_ids,
]
)
)
folder_id_pool: list[int] = list(
dict.fromkeys([*(mentioned_folder_ids or []), *chip_folder_ids])
)
if not doc_id_pool and not folder_id_pool:
return ResolvedMentionSet()
index = await build_path_index(session, search_space_id)
doc_rows: dict[int, Document] = {}
if doc_id_pool:
result = await session.execute(
select(Document).where(
Document.search_space_id == search_space_id,
Document.id.in_(doc_id_pool),
)
)
for row in result.scalars().all():
doc_rows[row.id] = row
folder_rows: dict[int, Folder] = {}
if folder_id_pool:
result = await session.execute(
select(Folder).where(
Folder.search_space_id == search_space_id,
Folder.id.in_(folder_id_pool),
)
)
for row in result.scalars().all():
folder_rows[row.id] = row
resolved: list[ResolvedMention] = []
accepted_doc_ids: list[int] = []
accepted_folder_ids: list[int] = []
for doc_id in doc_id_pool:
row = doc_rows.get(doc_id)
if row is None:
logger.debug(
"mention_resolver: dropping doc id=%s (not found in space=%s)",
doc_id,
search_space_id,
)
continue
title = chip_titles_by_id.get(("doc", doc_id), str(row.title or ""))
path = doc_to_virtual_path(
doc_id=row.id,
title=str(row.title or "untitled"),
folder_id=row.folder_id,
index=index,
)
resolved.append(
ResolvedMention(kind="doc", id=row.id, title=title, virtual_path=path)
)
accepted_doc_ids.append(row.id)
for folder_id in folder_id_pool:
row = folder_rows.get(folder_id)
if row is None:
logger.debug(
"mention_resolver: dropping folder id=%s (not found in space=%s)",
folder_id,
search_space_id,
)
continue
title = chip_titles_by_id.get(("folder", folder_id), str(row.name or ""))
path = _folder_virtual_path(row.id, index.folder_paths)
resolved.append(
ResolvedMention(kind="folder", id=row.id, title=title, virtual_path=path)
)
accepted_folder_ids.append(row.id)
token_to_path: list[tuple[str, str]] = []
seen_tokens: set[str] = set()
for mention in resolved:
if not mention.title:
continue
token = f"@{mention.title}"
if token in seen_tokens:
continue
seen_tokens.add(token)
token_to_path.append((token, mention.virtual_path))
token_to_path.sort(key=lambda pair: len(pair[0]), reverse=True)
return ResolvedMentionSet(
mentions=resolved,
token_to_path=token_to_path,
mentioned_document_ids=accepted_doc_ids,
mentioned_folder_ids=accepted_folder_ids,
)
def substitute_in_text(text: str, token_to_path: list[tuple[str, str]]) -> str:
"""Replace each ``@title`` token with a backtick-wrapped virtual path.
Mirrors ``parseMentionSegments`` on the frontend: longest token
first, single forward pass, no regex (titles can contain regex
metacharacters). The substitution is idempotent for already-
substituted text because the backtick-wrapped path no longer
starts with ``@``.
Empty / no-op cases short-circuit so callers can pass this through
unconditionally without paying for a scan.
"""
if not text or not token_to_path:
return text
out: list[str] = []
i = 0
n = len(text)
while i < n:
matched: tuple[str, str] | None = None
for token, path in token_to_path:
if text.startswith(token, i):
matched = (token, path)
break
if matched is None:
out.append(text[i])
i += 1
continue
token, path = matched
out.append(f"`{path}`")
i += len(token)
return "".join(out)
__all__ = [
"ResolvedMention",
"ResolvedMentionSet",
"resolve_mentions",
"substitute_in_text",
]