diff --git a/surfsense_backend/Dockerfile b/surfsense_backend/Dockerfile
index 4f24d2b05..1222b36b6 100644
--- a/surfsense_backend/Dockerfile
+++ b/surfsense_backend/Dockerfile
@@ -88,6 +88,13 @@ ENV TMPDIR=/shared_tmp
ENV PYTHONPATH=/app
ENV UVICORN_LOOP=asyncio
+# Tune glibc malloc to return freed memory to the OS more aggressively.
+# Without these, Python's gc.collect() frees objects but the underlying
+# C heap pages stay mapped (RSS never drops) due to sbrk fragmentation.
+ENV MALLOC_MMAP_THRESHOLD_=65536
+ENV MALLOC_TRIM_THRESHOLD_=131072
+ENV MALLOC_MMAP_MAX_=65536
+
# SERVICE_ROLE controls which process this container runs:
# api – FastAPI backend only (runs migrations on startup)
# worker – Celery worker only
diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py
index 3843b1687..af0d6bdc5 100644
--- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py
+++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py
@@ -28,8 +28,9 @@ from app.agents.new_chat.system_prompt import (
from app.agents.new_chat.tools.registry import build_tools_async
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
+from app.utils.perf import get_perf_logger
-_perf_log = logging.getLogger("surfsense.perf")
+_perf_log = get_perf_logger()
# =============================================================================
# Connector Type Mapping
diff --git a/surfsense_backend/app/agents/new_chat/llm_config.py b/surfsense_backend/app/agents/new_chat/llm_config.py
index bf16b2fe9..2b1c07cda 100644
--- a/surfsense_backend/app/agents/new_chat/llm_config.py
+++ b/surfsense_backend/app/agents/new_chat/llm_config.py
@@ -22,6 +22,7 @@ from app.services.llm_router_service import (
AUTO_MODE_ID,
ChatLiteLLMRouter,
LLMRouterService,
+ get_auto_mode_llm,
is_auto_mode,
)
@@ -389,7 +390,7 @@ def create_chat_litellm_from_agent_config(
print("Error: Auto mode requested but LLM Router not initialized")
return None
try:
- return ChatLiteLLMRouter()
+ return get_auto_mode_llm()
except Exception as e:
print(f"Error creating ChatLiteLLMRouter: {e}")
return None
diff --git a/surfsense_backend/app/agents/new_chat/sandbox.py b/surfsense_backend/app/agents/new_chat/sandbox.py
index 7696f67f2..8b634993b 100644
--- a/surfsense_backend/app/agents/new_chat/sandbox.py
+++ b/surfsense_backend/app/agents/new_chat/sandbox.py
@@ -58,6 +58,7 @@ class _TimeoutAwareSandbox(DaytonaSandbox):
_daytona_client: Daytona | None = None
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
+_SANDBOX_CACHE_MAX_SIZE = 20
THREAD_LABEL_KEY = "surfsense_thread"
@@ -144,6 +145,12 @@ async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
return cached
sandbox = await asyncio.to_thread(_find_or_create, key)
_sandbox_cache[key] = sandbox
+
+ if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
+ oldest_key = next(iter(_sandbox_cache))
+ _sandbox_cache.pop(oldest_key, None)
+ logger.debug("Evicted oldest sandbox cache entry: %s", oldest_key)
+
return sandbox
diff --git a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py
index 6989a1aa2..f1d3d16b8 100644
--- a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py
+++ b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py
@@ -10,6 +10,8 @@ This module provides:
import asyncio
import json
+import re
+import time
from datetime import datetime
from typing import Any
@@ -17,8 +19,152 @@ from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
-from app.db import async_session_maker
+from app.db import shielded_async_session
from app.services.connector_service import ConnectorService
+from app.utils.perf import get_perf_logger
+
+# Connectors that call external live-search APIs (no local DB / embedding needed).
+# These are never filtered by available_document_types.
+_LIVE_SEARCH_CONNECTORS: set[str] = {
+ "TAVILY_API",
+ "SEARXNG_API",
+ "LINKUP_API",
+ "BAIDU_SEARCH_API",
+}
+
+# Patterns that indicate the query has no meaningful search signal.
+# plainto_tsquery('english', '*') produces an empty tsquery and an embedding
+# of '*' is random noise, so both keyword and semantic search degrade to
+# arbitrary ordering — large documents (many chunks) dominate by chance.
+_DEGENERATE_QUERY_RE = re.compile(
+ r"^[\s*?_.#@!\-/\\]+$" # only wildcards, punctuation, whitespace
+)
+
+# Max chunks per document when doing a recency-based browse instead of
+# a real search. We want breadth (many docs) over depth (many chunks).
+_BROWSE_MAX_CHUNKS_PER_DOC = 5
+
+
+def _is_degenerate_query(query: str) -> bool:
+    """Tell whether *query* is too empty to drive a meaningful search.
+
+    A query is degenerate when, after trimming surrounding whitespace, it is
+    empty or made up solely of characters matched by ``_DEGENERATE_QUERY_RE``
+    (wildcards such as ``*`` / ``**`` and a fixed set of punctuation). Such
+    queries make keyword search (empty tsquery) and semantic search
+    (meaningless embedding) return effectively random results.
+    """
+    candidate = query.strip()
+    # The pattern needs at least one character, so test emptiness explicitly.
+    return not candidate or _DEGENERATE_QUERY_RE.match(candidate) is not None
+
+
+async def _browse_recent_documents(
+    search_space_id: int,
+    document_type: str | None,
+    top_k: int,
+    start_date: datetime | None,
+    end_date: datetime | None,
+) -> list[dict[str, Any]]:
+    """Return the most recently updated documents, without search ranking.
+
+    Used as a fallback when the search query is degenerate (e.g. ``*``) and
+    semantic / keyword search would produce arbitrary results. Returns
+    document-grouped dicts in the same shape as ``_combined_rrf_search``
+    so the rest of the pipeline works unchanged.
+
+    Args:
+        search_space_id: Search space whose documents are browsed.
+        document_type: ``DocumentType`` member name to filter on, or ``None``
+            for every type. An unknown name yields an empty result.
+        top_k: Maximum number of documents to return.
+        start_date: Inclusive lower bound on ``Document.updated_at``, if set.
+        end_date: Inclusive upper bound on ``Document.updated_at``, if set.
+
+    Returns:
+        One dict per document, most recently updated first, each carrying at
+        most ``_BROWSE_MAX_CHUNKS_PER_DOC`` chunks (lowest chunk ids first)
+        and a constant ``score`` of ``0.0``.
+    """
+    from sqlalchemy import func, select
+
+    from app.db import Chunk, Document, DocumentType
+
+    perf = get_perf_logger()
+    t0 = time.perf_counter()
+
+    base_conditions = [Document.search_space_id == search_space_id]
+
+    if document_type is not None:
+        if isinstance(document_type, str):
+            try:
+                doc_type_enum = DocumentType[document_type]
+            except KeyError:
+                # Unknown type name: nothing can match.
+                return []
+            base_conditions.append(Document.document_type == doc_type_enum)
+        else:
+            base_conditions.append(Document.document_type == document_type)
+
+    if start_date is not None:
+        base_conditions.append(Document.updated_at >= start_date)
+    if end_date is not None:
+        base_conditions.append(Document.updated_at <= end_date)
+
+    async with shielded_async_session() as session:
+        # No eager load of Document.search_space: nothing below reads it.
+        doc_query = (
+            select(Document)
+            .where(*base_conditions)
+            .order_by(Document.updated_at.desc())
+            .limit(top_k)
+        )
+        result = await session.execute(doc_query)
+        documents = result.scalars().unique().all()
+
+        if not documents:
+            return []
+
+        doc_ids = [d.id for d in documents]
+
+        # Rank each document's chunks by id inside the database and fetch only
+        # the first few, selecting just the columns used below. Loading whole
+        # Chunk rows for every chunk of a large document would defeat the
+        # purpose of the per-document cap.
+        ranked_chunks = (
+            select(
+                Chunk.id.label("chunk_id"),
+                Chunk.document_id.label("document_id"),
+                Chunk.content.label("content"),
+                func.row_number()
+                .over(partition_by=Chunk.document_id, order_by=Chunk.id)
+                .label("chunk_rank"),
+            )
+            .where(Chunk.document_id.in_(doc_ids))
+            .subquery()
+        )
+        chunk_query = (
+            select(
+                ranked_chunks.c.chunk_id,
+                ranked_chunks.c.document_id,
+                ranked_chunks.c.content,
+            )
+            .where(ranked_chunks.c.chunk_rank <= _BROWSE_MAX_CHUNKS_PER_DOC)
+            .order_by(ranked_chunks.c.document_id, ranked_chunks.c.chunk_id)
+        )
+        chunk_rows = (await session.execute(chunk_query)).all()
+
+    doc_chunks: dict[int, list[dict]] = {d.id: [] for d in documents}
+    for row in chunk_rows:
+        doc_chunks[row.document_id].append(
+            {"chunk_id": row.chunk_id, "content": row.content}
+        )
+
+    results: list[dict[str, Any]] = []
+    for doc in documents:
+        chunks_list = doc_chunks.get(doc.id, [])
+        doc_type = getattr(doc, "document_type", None)
+        doc_type_value = doc_type.value if doc_type else None
+        results.append(
+            {
+                "document_id": doc.id,
+                "content": "\n\n".join(
+                    c["content"] for c in chunks_list if c.get("content")
+                ),
+                "score": 0.0,
+                "chunks": chunks_list,
+                "document": {
+                    "id": doc.id,
+                    "title": doc.title,
+                    "document_type": doc_type_value,
+                    "metadata": doc.document_metadata or {},
+                },
+                "source": doc_type_value,
+            }
+        )
+
+    perf.info(
+        "[kb_browse] recency browse in %.3fs docs=%d space=%d type=%s",
+        time.perf_counter() - t0,
+        len(results),
+        search_space_id,
+        document_type,
+    )
+    return results
+
# =============================================================================
# Connector Constants and Normalization
@@ -182,9 +328,23 @@ _CHARS_PER_TOKEN = 4
# Hard-floor / ceiling so the budget is always sensible regardless of what
# the model reports.
_MIN_TOOL_OUTPUT_CHARS = 20_000 # ~5K tokens
-_MAX_TOOL_OUTPUT_CHARS = 400_000 # ~100K tokens
+_MAX_TOOL_OUTPUT_CHARS = 200_000 # ~50K tokens
_MAX_CHUNK_CHARS = 8_000
+# Rank-adaptive per-document budget allocation.
+# Top-ranked (most relevant) documents get a larger share of the budget so
+# we pack as much high-quality context as possible.
+#
+# fraction(rank) = _TOP_DOC_BUDGET_FRACTION / (1 + rank * _RANK_DECAY)
+#
+# Examples (128K budget, 8K chunk cap):
+# rank 0 → 40% → 6 chunks | rank 3 → 20% → 3 chunks
+# rank 1 → 30% → 4 chunks | rank 10 → 9% → 3 chunks (floor)
+# rank 2 → 24% → 3 chunks |
+_TOP_DOC_BUDGET_FRACTION = 0.40
+_RANK_DECAY = 0.35
+_MIN_CHUNKS_PER_DOC = 3
+
def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
"""Derive a character budget from the model's context window.
@@ -206,18 +366,24 @@ def format_documents_for_context(
*,
max_chars: int = _MAX_TOOL_OUTPUT_CHARS,
max_chunk_chars: int = _MAX_CHUNK_CHARS,
+ max_chunks_per_doc: int = 0,
) -> str:
"""
Format retrieved documents into a readable context string for the LLM.
Documents are added in order (highest relevance first) until the character
- budget is reached. Individual chunks are capped at ``max_chunk_chars`` so
- a single oversized chunk cannot monopolize the output.
+ budget is reached. Individual chunks are capped at ``max_chunk_chars`` and
+ each document is limited to a dynamically computed chunk cap so a single
+ large document cannot monopolize the output while still maximising the use
+ of available context space.
Args:
documents: List of document dictionaries from connector search
max_chars: Approximate character budget for the entire output.
max_chunk_chars: Per-chunk character cap (content is tail-truncated).
+ max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means
+ auto-compute per document using a rank-adaptive formula so
+ higher-ranked documents receive more chunks.
Returns:
Formatted string with document contents and metadata
@@ -340,7 +506,23 @@ def format_documents_for_context(
"",
]
- for ch in g["chunks"]:
+ # Rank-adaptive per-document chunk cap: top results get more chunks.
+ if max_chunks_per_doc > 0:
+ chunks_allowed = max_chunks_per_doc
+ else:
+ doc_fraction = _TOP_DOC_BUDGET_FRACTION / (1 + doc_idx * _RANK_DECAY)
+ max_doc_chars = int(max_chars * doc_fraction)
+ xml_overhead = 500
+ chunks_allowed = max(
+ (max_doc_chars - xml_overhead) // max(max_chunk_chars, 1),
+ _MIN_CHUNKS_PER_DOC,
+ )
+
+ chunks = g["chunks"]
+ if len(chunks) > chunks_allowed:
+ chunks = chunks[:chunks_allowed]
+
+ for ch in chunks:
ch_content = ch["content"]
if max_chunk_chars and len(ch_content) > max_chunk_chars:
ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)"
@@ -357,9 +539,11 @@ def format_documents_for_context(
doc_xml = "\n".join(doc_lines)
doc_len = len(doc_xml)
- # Always include at least the first document; afterwards enforce budget.
- if doc_idx > 0 and total_chars + doc_len > max_chars:
+ if total_chars + doc_len > max_chars:
remaining = total_docs - doc_idx
+ if doc_idx == 0:
+ parts.append(doc_xml)
+ total_chars += doc_len
parts.append(
f""
+ result = result[: max_chars - len(truncation_msg)] + truncation_msg
+
+ return result
# =============================================================================
@@ -388,6 +580,7 @@ async def search_knowledge_base_async(
start_date: datetime | None = None,
end_date: datetime | None = None,
available_connectors: list[str] | None = None,
+ available_document_types: list[str] | None = None,
max_input_tokens: int | None = None,
) -> str:
"""
@@ -406,12 +599,18 @@ async def search_knowledge_base_async(
end_date: Optional end datetime (UTC) for filtering documents
available_connectors: Optional list of connectors actually available in the search space.
If provided, only these connectors will be searched.
+ available_document_types: Optional list of document types that actually have indexed
+ data. When provided, local connectors whose document type is
+ absent are skipped entirely (no embedding / DB round-trip).
max_input_tokens: Model context window size (tokens). Used to dynamically
size the output so it fits within the model's limits.
Returns:
Formatted string with search results
"""
+ perf = get_perf_logger()
+ t0 = time.perf_counter()
+
all_documents: list[dict[str, Any]] = []
# Resolve date range (default last 2 years)
@@ -424,88 +623,169 @@ async def search_knowledge_base_async(
connectors = _normalize_connectors(connectors_to_search, available_connectors)
- connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
- "YOUTUBE_VIDEO": ("search_youtube", True, True, {}),
- "EXTENSION": ("search_extension", True, True, {}),
- "CRAWLED_URL": ("search_crawled_urls", True, True, {}),
- "FILE": ("search_files", True, True, {}),
- "SLACK_CONNECTOR": ("search_slack", True, True, {}),
- "TEAMS_CONNECTOR": ("search_teams", True, True, {}),
- "NOTION_CONNECTOR": ("search_notion", True, True, {}),
- "GITHUB_CONNECTOR": ("search_github", True, True, {}),
- "LINEAR_CONNECTOR": ("search_linear", True, True, {}),
+ # --- Optimization 1: skip local connectors that have zero indexed documents ---
+ if available_document_types:
+ doc_types_set = set(available_document_types)
+ before_count = len(connectors)
+ connectors = [
+ c for c in connectors if c in _LIVE_SEARCH_CONNECTORS or c in doc_types_set
+ ]
+ skipped = before_count - len(connectors)
+ if skipped:
+ perf.info(
+ "[kb_search] skipped %d empty connectors (had %d, now %d)",
+ skipped,
+ before_count,
+ len(connectors),
+ )
+
+ perf.info(
+ "[kb_search] searching %d connectors: %s (space=%d, top_k=%d)",
+ len(connectors),
+ connectors[:5],
+ search_space_id,
+ top_k,
+ )
+
+ # --- Fast-path: degenerate queries (*, **, empty, etc.) ---
+ # Semantic embedding of '*' is noise and plainto_tsquery('english', '*')
+ # yields an empty tsquery, so both retrieval signals are useless.
+ # Fall back to a recency-ordered browse that returns diverse results.
+ if _is_degenerate_query(query):
+ perf.info(
+ "[kb_search] degenerate query %r detected - falling back to recency browse",
+ query,
+ )
+ local_connectors = [c for c in connectors if c not in _LIVE_SEARCH_CONNECTORS]
+ if not local_connectors:
+ local_connectors = [None] # type: ignore[list-item]
+
+ browse_results = await asyncio.gather(
+ *[
+ _browse_recent_documents(
+ search_space_id=search_space_id,
+ document_type=c,
+ top_k=top_k,
+ start_date=resolved_start_date,
+ end_date=resolved_end_date,
+ )
+ for c in local_connectors
+ ]
+ )
+ for docs in browse_results:
+ all_documents.extend(docs)
+
+        # Skip the dedup step below (browse already returns unique docs)
+        # but still format the results within the output budget.
+ output_budget = _compute_tool_output_budget(max_input_tokens)
+ result = format_documents_for_context(
+ all_documents,
+ max_chars=output_budget,
+ max_chunks_per_doc=_BROWSE_MAX_CHUNKS_PER_DOC,
+ )
+ perf.info(
+ "[kb_search] TOTAL (browse) in %.3fs total_docs=%d output_chars=%d "
+ "budget=%d space=%d",
+ time.perf_counter() - t0,
+ len(all_documents),
+ len(result),
+ output_budget,
+ search_space_id,
+ )
+ return result
+
+ # Specs for live-search connectors (external APIs, no local DB/embedding).
+ live_connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
"TAVILY_API": ("search_tavily", False, True, {}),
"SEARXNG_API": ("search_searxng", False, True, {}),
"LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}),
"BAIDU_SEARCH_API": ("search_baidu", False, True, {}),
- "DISCORD_CONNECTOR": ("search_discord", True, True, {}),
- "JIRA_CONNECTOR": ("search_jira", True, True, {}),
- "GOOGLE_CALENDAR_CONNECTOR": ("search_google_calendar", True, True, {}),
- "AIRTABLE_CONNECTOR": ("search_airtable", True, True, {}),
- "GOOGLE_GMAIL_CONNECTOR": ("search_google_gmail", True, True, {}),
- "GOOGLE_DRIVE_FILE": ("search_google_drive", True, True, {}),
- "CONFLUENCE_CONNECTOR": ("search_confluence", True, True, {}),
- "CLICKUP_CONNECTOR": ("search_clickup", True, True, {}),
- "LUMA_CONNECTOR": ("search_luma", True, True, {}),
- "ELASTICSEARCH_CONNECTOR": ("search_elasticsearch", True, True, {}),
- "NOTE": ("search_notes", True, True, {}),
- "BOOKSTACK_CONNECTOR": ("search_bookstack", True, True, {}),
- "CIRCLEBACK": ("search_circleback", True, True, {}),
- "OBSIDIAN_CONNECTOR": ("search_obsidian", True, True, {}),
- # Composio connectors
- "COMPOSIO_GOOGLE_DRIVE_CONNECTOR": (
- "search_composio_google_drive",
- True,
- True,
- {},
- ),
- "COMPOSIO_GMAIL_CONNECTOR": ("search_composio_gmail", True, True, {}),
- "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": (
- "search_composio_google_calendar",
- True,
- True,
- {},
- ),
}
- # Keep a conservative cap to avoid overloading DB/external services.
+    # --- Optimization 2: compute the query embedding once, share across all local searches ---
+    precomputed_embedding: list[float] | None = None
+    has_local_connectors = any(c not in _LIVE_SEARCH_CONNECTORS for c in connectors)
+    if has_local_connectors:
+        from app.config import config as app_config
+
+        t_embed = time.perf_counter()
+        # embed() is invoked synchronously; run it in a worker thread so a slow
+        # embedding does not stall the event loop for every other request.
+        precomputed_embedding = await asyncio.to_thread(
+            app_config.embedding_model_instance.embed, query
+        )
+        perf.info(
+            "[kb_search] shared embedding computed in %.3fs",
+            time.perf_counter() - t_embed,
+        )
+
max_parallel_searches = 4
semaphore = asyncio.Semaphore(max_parallel_searches)
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
- spec = connector_specs.get(connector)
- if spec is None:
- return []
+ is_live = connector in _LIVE_SEARCH_CONNECTORS
- method_name, includes_date_range, includes_top_k, extra_kwargs = spec
- kwargs: dict[str, Any] = {
- "user_query": query,
- "search_space_id": search_space_id,
- **extra_kwargs,
- }
- if includes_top_k:
- kwargs["top_k"] = top_k
- if includes_date_range:
- kwargs["start_date"] = resolved_start_date
- kwargs["end_date"] = resolved_end_date
+ if is_live:
+ spec = live_connector_specs.get(connector)
+ if spec is None:
+ return []
+ method_name, includes_date_range, includes_top_k, extra_kwargs = spec
+ kwargs: dict[str, Any] = {
+ "user_query": query,
+ "search_space_id": search_space_id,
+ **extra_kwargs,
+ }
+ if includes_top_k:
+ kwargs["top_k"] = top_k
+ if includes_date_range:
+ kwargs["start_date"] = resolved_start_date
+ kwargs["end_date"] = resolved_end_date
+ try:
+ t_conn = time.perf_counter()
+ async with semaphore, shielded_async_session() as isolated_session:
+ svc = ConnectorService(isolated_session, search_space_id)
+ _, chunks = await getattr(svc, method_name)(**kwargs)
+ perf.info(
+ "[kb_search] connector=%s results=%d in %.3fs",
+ connector,
+ len(chunks),
+ time.perf_counter() - t_conn,
+ )
+ return chunks
+ except Exception as e:
+ perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
+ return []
+
+ # --- Optimization 3: call _combined_rrf_search directly with shared embedding ---
try:
- # Use isolated session per connector. Shared AsyncSession cannot safely
- # run concurrent DB operations.
- async with semaphore, async_session_maker() as isolated_session:
- isolated_connector_service = ConnectorService(
- isolated_session, search_space_id
+ t_conn = time.perf_counter()
+ async with semaphore, shielded_async_session() as isolated_session:
+ svc = ConnectorService(isolated_session, search_space_id)
+ chunks = await svc._combined_rrf_search(
+ query_text=query,
+ search_space_id=search_space_id,
+ document_type=connector,
+ top_k=top_k,
+ start_date=resolved_start_date,
+ end_date=resolved_end_date,
+ query_embedding=precomputed_embedding,
+ )
+ perf.info(
+ "[kb_search] connector=%s results=%d in %.3fs",
+ connector,
+ len(chunks),
+ time.perf_counter() - t_conn,
)
- connector_method = getattr(isolated_connector_service, method_name)
- _, chunks = await connector_method(**kwargs)
return chunks
except Exception as e:
- print(f"Error searching connector {connector}: {e}")
+ perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
return []
+ t_gather = time.perf_counter()
connector_results = await asyncio.gather(
*[_search_one_connector(connector) for connector in connectors]
)
+ perf.info(
+ "[kb_search] all connectors gathered in %.3fs",
+ time.perf_counter() - t_gather,
+ )
for chunks in connector_results:
all_documents.extend(chunks)
@@ -552,7 +832,28 @@ async def search_knowledge_base_async(
deduplicated.append(doc)
output_budget = _compute_tool_output_budget(max_input_tokens)
- return format_documents_for_context(deduplicated, max_chars=output_budget)
+ result = format_documents_for_context(deduplicated, max_chars=output_budget)
+
+ if len(result) > output_budget:
+ perf.warning(
+ "[kb_search] output STILL exceeds budget after format (%d > %d), "
+ "hard truncation should have fired",
+ len(result),
+ output_budget,
+ )
+
+ perf.info(
+ "[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d "
+ "budget=%d max_input_tokens=%s space=%d",
+ time.perf_counter() - t0,
+ len(all_documents),
+ len(deduplicated),
+ len(result),
+ output_budget,
+ max_input_tokens,
+ search_space_id,
+ )
+ return result
def _build_connector_docstring(available_connectors: list[str] | None) -> str:
@@ -590,11 +891,15 @@ class SearchKnowledgeBaseInput(BaseModel):
"""Input schema for the search_knowledge_base tool."""
query: str = Field(
- description="The search query - be specific and include key terms"
+ description=(
+ "The search query - use specific natural language terms. "
+ "NEVER use wildcards like '*' or '**'; instead describe what you want "
+ "(e.g. 'recent meeting notes' or 'project architecture overview')."
+ ),
)
top_k: int = Field(
default=10,
- description="Number of results to retrieve (default: 10)",
+ description="Number of results to retrieve (default: 10). Keep ≤20 for focused searches.",
)
start_date: str | None = Field(
default=None,
@@ -657,6 +962,10 @@ Focus searches on these types for best results."""
Use this tool to find documents, notes, files, web pages, and other content that may help answer the user's question.
IMPORTANT:
+- Always craft specific, descriptive search queries using natural language keywords.
+ Good: "quarterly sales report Q3", "Python API authentication design".
+ Bad: "*", "**", "everything", single characters. Wildcard/empty queries yield poor results.
+- Prefer multiple focused searches over a single broad one with high top_k.
- If the user requests a specific source type (e.g. "my notes", "Slack messages"), pass `connectors_to_search=[...]` using the enums below.
- If `connectors_to_search` is omitted/empty, the system will search broadly.
- Only connectors that are enabled/configured for this search space are available.{doc_types_info}
@@ -672,6 +981,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
# Capture for closure
_available_connectors = available_connectors
+ _available_document_types = available_document_types
async def _search_knowledge_base_impl(
query: str,
@@ -701,6 +1011,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
start_date=parsed_start,
end_date=parsed_end,
available_connectors=_available_connectors,
+ available_document_types=_available_document_types,
max_input_tokens=max_input_tokens,
)
diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py
index 20cf3ec33..2fb7ffb06 100644
--- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py
+++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py
@@ -27,9 +27,24 @@ from app.db import SearchSourceConnector, SearchSourceConnectorType
logger = logging.getLogger(__name__)
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
+_MCP_CACHE_MAX_SIZE = 50
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
+def _evict_expired_mcp_cache() -> None:
+    """Drop entries whose age has reached the TTL so the MCP cache stays bounded."""
+    current = time.monotonic()
+    stale_keys = [
+        space_id
+        for space_id, entry in _mcp_tools_cache.items()
+        if current - entry[0] >= _MCP_CACHE_TTL_SECONDS
+    ]
+    for space_id in stale_keys:
+        del _mcp_tools_cache[space_id]
+    if stale_keys:
+        logger.debug("Evicted %d expired MCP cache entries", len(stale_keys))
+
+
def _create_dynamic_input_model_from_schema(
tool_name: str,
input_schema: dict[str, Any],
@@ -392,6 +407,8 @@ async def load_mcp_tools(
List of LangChain StructuredTool instances
"""
+ _evict_expired_mcp_cache()
+
now = time.monotonic()
cached = _mcp_tools_cache.get(search_space_id)
if cached is not None:
@@ -445,6 +462,11 @@ async def load_mcp_tools(
)
_mcp_tools_cache[search_space_id] = (now, tools)
+
+ if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
+ oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
+ del _mcp_tools_cache[oldest_key]
+
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
return tools
diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py
index f36f0de13..99cb09b38 100644
--- a/surfsense_backend/app/agents/new_chat/tools/registry.py
+++ b/surfsense_backend/app/agents/new_chat/tools/registry.py
@@ -145,10 +145,12 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
thread_id=deps["thread_id"],
connector_service=deps.get("connector_service"),
available_connectors=deps.get("available_connectors"),
+ available_document_types=deps.get("available_document_types"),
),
requires=["search_space_id", "thread_id"],
- # connector_service and available_connectors are optional —
- # when missing, source_strategy="kb_search" degrades gracefully to "provided"
+ # connector_service, available_connectors, and available_document_types
+ # are optional — when missing, source_strategy="kb_search" degrades
+ # gracefully to "provided"
),
# Link preview tool - fetches Open Graph metadata for URLs
ToolDefinition(
diff --git a/surfsense_backend/app/agents/new_chat/tools/report.py b/surfsense_backend/app/agents/new_chat/tools/report.py
index 0896fea4b..fe5181f54 100644
--- a/surfsense_backend/app/agents/new_chat/tools/report.py
+++ b/surfsense_backend/app/agents/new_chat/tools/report.py
@@ -33,7 +33,7 @@ from langchain_core.callbacks import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
-from app.db import Report, async_session_maker
+from app.db import Report, shielded_async_session
from app.services.connector_service import ConnectorService
from app.services.llm_service import get_document_summary_llm
@@ -559,6 +559,7 @@ def create_generate_report_tool(
thread_id: int | None = None,
connector_service: ConnectorService | None = None,
available_connectors: list[str] | None = None,
+ available_document_types: list[str] | None = None,
):
"""
Factory function to create the generate_report tool with injected dependencies.
@@ -716,7 +717,7 @@ def create_generate_report_tool(
async def _save_failed_report(error_msg: str) -> int | None:
"""Persist a failed report row using a short-lived session."""
try:
- async with async_session_maker() as session:
+ async with shielded_async_session() as session:
failed_report = Report(
title=topic,
content=None,
@@ -750,7 +751,7 @@ def create_generate_report_tool(
# ── Phase 1: READ (short-lived session) ──────────────────────
# Fetch parent report and LLM config, then close the session
# so no DB connection is held during the long LLM call.
- async with async_session_maker() as read_session:
+ async with shielded_async_session() as read_session:
if parent_report_id:
parent_report = await read_session.get(Report, parent_report_id)
if parent_report:
@@ -827,7 +828,7 @@ def create_generate_report_tool(
# Run all queries in parallel, each with its own session
async def _run_single_query(q: str) -> str:
- async with async_session_maker() as kb_session:
+ async with shielded_async_session() as kb_session:
kb_connector_svc = ConnectorService(
kb_session, search_space_id
)
@@ -838,6 +839,7 @@ def create_generate_report_tool(
connector_service=kb_connector_svc,
top_k=10,
available_connectors=available_connectors,
+ available_document_types=available_document_types,
)
kb_results = await asyncio.gather(
@@ -1026,7 +1028,7 @@ def create_generate_report_tool(
# ── Phase 3: WRITE (short-lived session) ─────────────────────
# Save the report to the database, then close the session.
- async with async_session_maker() as write_session:
+ async with shielded_async_session() as write_session:
report = Report(
title=topic,
content=report_content,
diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py
index 0a549abe5..e6db5670e 100644
--- a/surfsense_backend/app/app.py
+++ b/surfsense_backend/app/app.py
@@ -1,4 +1,5 @@
import asyncio
+import gc
import logging
+import os
import time
from collections import defaultdict
@@ -15,6 +16,9 @@ from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
+from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
+from starlette.requests import Request as StarletteRequest
+from starlette.responses import Response as StarletteResponse
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
from app.agents.new_chat.checkpointer import (
@@ -28,6 +32,7 @@ from app.routes.auth_routes import router as auth_router
from app.schemas import UserCreate, UserRead, UserUpdate
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
+from app.utils.perf import get_perf_logger, log_system_snapshot
rate_limit_logger = logging.getLogger("surfsense.rate_limit")
@@ -99,22 +104,24 @@ def _check_rate_limit_memory(
now = time.monotonic()
with _memory_lock:
- # Evict timestamps outside the current window
- _memory_rate_limits[key] = [
- t for t in _memory_rate_limits[key] if now - t < window_seconds
- ]
+ timestamps = [t for t in _memory_rate_limits[key] if now - t < window_seconds]
- if len(_memory_rate_limits[key]) >= max_requests:
+ if not timestamps:
+ _memory_rate_limits.pop(key, None)
+ else:
+ _memory_rate_limits[key] = timestamps
+
+ if len(timestamps) >= max_requests:
rate_limit_logger.warning(
f"Rate limit exceeded (in-memory fallback) on {scope} for IP {client_ip} "
- f"({len(_memory_rate_limits[key])}/{max_requests} in {window_seconds}s)"
+ f"({len(timestamps)}/{max_requests} in {window_seconds}s)"
)
raise HTTPException(
status_code=429,
detail="RATE_LIMIT_EXCEEDED",
)
- _memory_rate_limits[key].append(now)
+ _memory_rate_limits[key] = [*timestamps, now]
def _check_rate_limit(
@@ -206,18 +213,16 @@ def _enable_slow_callback_logging(threshold_sec: float = 0.5) -> None:
@asynccontextmanager
async def lifespan(app: FastAPI):
- # Enable slow-callback detection (set PERF_DEBUG=1 env var to activate)
+ # Tune GC: lower gen-2 threshold so long-lived garbage is collected
+ # sooner (default 700/10/10 → 700/10/5). This reduces peak RSS
+ # with minimal CPU overhead.
+ gc.set_threshold(700, 10, 5)
+
_enable_slow_callback_logging(threshold_sec=0.5)
- # Not needed if you setup a migration system like Alembic
await create_db_and_tables()
- # Setup LangGraph checkpointer tables for conversation persistence
await setup_checkpointer_tables()
- # Initialize LLM Router for Auto mode load balancing
initialize_llm_router()
- # Initialize Image Generation Router for Auto mode load balancing
initialize_image_gen_router()
- # Seed Surfsense documentation (with timeout so a slow embedding API
- # doesn't block startup indefinitely and make the container unresponsive)
try:
await asyncio.wait_for(seed_surfsense_docs(), timeout=120)
except TimeoutError:
@@ -225,8 +230,11 @@ async def lifespan(app: FastAPI):
"Surfsense docs seeding timed out after 120s — skipping. "
"Docs will be indexed on the next restart."
)
+
+ log_system_snapshot("startup_complete")
+
yield
- # Cleanup: close checkpointer connection on shutdown
+
await close_checkpointer()
@@ -244,6 +252,63 @@ app = FastAPI(lifespan=lifespan)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
+
+# ---------------------------------------------------------------------------
+# Request-level performance middleware
+# ---------------------------------------------------------------------------
+# Logs wall-clock time, method, path, and status for every request so we can
+# spot slow endpoints in production logs.
+
+_PERF_SLOW_REQUEST_THRESHOLD = float(
+ __import__("os").environ.get("PERF_SLOW_REQUEST_MS", "2000")
+)
+
+
+class RequestPerfMiddleware(BaseHTTPMiddleware):
+ """Middleware that logs per-request wall-clock time.
+
+ - ALL requests are logged at DEBUG level.
+ - Requests exceeding PERF_SLOW_REQUEST_MS (default 2000ms) are logged at
+ WARNING level with a system snapshot so we can correlate slow responses
+ with CPU/memory usage at that moment.
+ """
+
+ async def dispatch(
+ self, request: StarletteRequest, call_next: RequestResponseEndpoint
+ ) -> StarletteResponse:
+ perf = get_perf_logger()
+ t0 = time.perf_counter()
+ response = await call_next(request)
+ elapsed_ms = (time.perf_counter() - t0) * 1000
+
+ path = request.url.path
+ method = request.method
+ status = response.status_code
+
+ perf.debug(
+ "[request] %s %s -> %d in %.1fms",
+ method,
+ path,
+ status,
+ elapsed_ms,
+ )
+
+ if elapsed_ms > _PERF_SLOW_REQUEST_THRESHOLD:
+ perf.warning(
+ "[SLOW_REQUEST] %s %s -> %d in %.1fms (threshold=%.0fms)",
+ method,
+ path,
+ status,
+ elapsed_ms,
+ _PERF_SLOW_REQUEST_THRESHOLD,
+ )
+ log_system_snapshot("slow_request")
+
+ return response
+
+
+app.add_middleware(RequestPerfMiddleware)
+
# Add SlowAPI middleware for automatic rate limiting
# Uses Starlette BaseHTTPMiddleware (not the raw ASGI variant) to avoid
# corrupting StreamingResponse — SlowAPIASGIMiddleware re-sends
diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py
index 771689a13..510f64cc3 100644
--- a/surfsense_backend/app/db.py
+++ b/surfsense_backend/app/db.py
@@ -1,7 +1,9 @@
from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
from datetime import UTC, datetime
from enum import StrEnum
+import anyio
from fastapi import Depends
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
from pgvector.sqlalchemy import Vector
@@ -1856,10 +1858,37 @@ class RefreshToken(Base, TimestampMixin):
return not self.is_expired and not self.is_revoked
-engine = create_async_engine(DATABASE_URL)
+engine = create_async_engine(
+ DATABASE_URL,
+ pool_size=30,
+ max_overflow=150,
+ pool_recycle=1800,
+ pool_pre_ping=True,
+ pool_timeout=30,
+)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
+@asynccontextmanager
+async def shielded_async_session():
+ """Cancellation-safe async session context manager.
+
+ Starlette's BaseHTTPMiddleware cancels the task via an anyio cancel
+ scope when a client disconnects. A plain ``async with async_session_maker()``
+ has its ``__aexit__`` (which awaits ``session.close()``) cancelled by the
+ scope, orphaning the underlying database connection.
+
+ This wrapper ensures ``session.close()`` always completes by running it
+ inside ``anyio.CancelScope(shield=True)``.
+ """
+ session = async_session_maker()
+ try:
+ yield session
+ finally:
+ with anyio.CancelScope(shield=True):
+ await session.close()
+
+
async def setup_indexes():
async with engine.begin() as conn:
# Create indexes
diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py
index eea3d6e25..9460f900c 100644
--- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py
+++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py
@@ -1,4 +1,5 @@
import contextlib
+import time
from datetime import UTC, datetime
from sqlalchemy import delete, select
@@ -44,6 +45,7 @@ from app.indexing_pipeline.pipeline_logger import (
log_retryable_llm_error,
log_unexpected_error,
)
+from app.utils.perf import get_perf_logger
class IndexingPipelineService:
@@ -58,6 +60,9 @@ class IndexingPipelineService:
"""
Persist new documents and detect changes, returning only those that need indexing.
"""
+ perf = get_perf_logger()
+ t0 = time.perf_counter()
+
documents = []
seen_hashes: set[str] = set()
batch_ctx = PipelineLogContext(
@@ -140,11 +145,14 @@ class IndexingPipelineService:
try:
await self.session.commit()
+ perf.info(
+ "[indexing] prepare_for_indexing in %.3fs input=%d output=%d",
+ time.perf_counter() - t0,
+ len(connector_docs),
+ len(documents),
+ )
return documents
except IntegrityError:
- # A concurrent worker committed a document with the same content_hash
- # or unique_identifier_hash between our check and our INSERT.
- # The document already exists — roll back and let the next sync run handle it.
log_race_condition(batch_ctx)
await self.session.rollback()
return []
@@ -165,26 +173,41 @@ class IndexingPipelineService:
unique_id=connector_doc.unique_id,
doc_id=document.id,
)
+ perf = get_perf_logger()
+ t_index = time.perf_counter()
try:
log_index_started(ctx)
document.status = DocumentStatus.processing()
await self.session.commit()
+ t_step = time.perf_counter()
if connector_doc.should_summarize and llm is not None:
content = await summarize_document(
connector_doc.source_markdown, llm, connector_doc.metadata
)
+ perf.info(
+ "[indexing] summarize_document doc=%d in %.3fs",
+ document.id,
+ time.perf_counter() - t_step,
+ )
elif connector_doc.should_summarize and connector_doc.fallback_summary:
content = connector_doc.fallback_summary
else:
content = connector_doc.source_markdown
+ t_step = time.perf_counter()
embedding = embed_text(content)
+ perf.debug(
+ "[indexing] embed_text (summary) doc=%d in %.3fs",
+ document.id,
+ time.perf_counter() - t_step,
+ )
await self.session.execute(
delete(Chunk).where(Chunk.document_id == document.id)
)
+ t_step = time.perf_counter()
chunks = [
Chunk(content=text, embedding=embed_text(text))
for text in chunk_text(
@@ -192,6 +215,12 @@ class IndexingPipelineService:
use_code_chunker=connector_doc.should_use_code_chunker,
)
]
+ perf.info(
+ "[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
+ document.id,
+ len(chunks),
+ time.perf_counter() - t_step,
+ )
document.content = content
document.embedding = embedding
@@ -199,6 +228,12 @@ class IndexingPipelineService:
document.updated_at = datetime.now(UTC)
document.status = DocumentStatus.ready()
await self.session.commit()
+ perf.info(
+ "[indexing] index TOTAL doc=%d chunks=%d in %.3fs",
+ document.id,
+ len(chunks),
+ time.perf_counter() - t_index,
+ )
log_index_success(ctx, chunk_count=len(chunks))
except RETRYABLE_LLM_ERRORS as e:
diff --git a/surfsense_backend/app/retriever/chunks_hybrid_search.py b/surfsense_backend/app/retriever/chunks_hybrid_search.py
index 9aa301386..4787e8147 100644
--- a/surfsense_backend/app/retriever/chunks_hybrid_search.py
+++ b/surfsense_backend/app/retriever/chunks_hybrid_search.py
@@ -1,5 +1,10 @@
+import time
from datetime import datetime
+from app.utils.perf import get_perf_logger
+
+_MAX_FETCH_CHUNKS_PER_DOC = 30
+
class ChucksHybridSearchRetriever:
def __init__(self, db_session):
@@ -38,9 +43,17 @@ class ChucksHybridSearchRetriever:
from app.config import config
from app.db import Chunk, Document
+ perf = get_perf_logger()
+ t0 = time.perf_counter()
+
# Get embedding for the query
embedding_model = config.embedding_model_instance
+ t_embed = time.perf_counter()
query_embedding = embedding_model.embed(query_text)
+ perf.debug(
+ "[chunk_search] vector_search embedding in %.3fs",
+ time.perf_counter() - t_embed,
+ )
# Build the query filtered by search space
query = (
@@ -60,8 +73,16 @@ class ChucksHybridSearchRetriever:
query = query.order_by(Chunk.embedding.op("<=>")(query_embedding)).limit(top_k)
# Execute the query
+ t_db = time.perf_counter()
result = await self.db_session.execute(query)
chunks = result.scalars().all()
+ perf.info(
+ "[chunk_search] vector_search DB query in %.3fs results=%d (total %.3fs) space=%d",
+ time.perf_counter() - t_db,
+ len(chunks),
+ time.perf_counter() - t0,
+ search_space_id,
+ )
return chunks
@@ -91,6 +112,9 @@ class ChucksHybridSearchRetriever:
from app.db import Chunk, Document
+ perf = get_perf_logger()
+ t0 = time.perf_counter()
+
# Create tsvector and tsquery for PostgreSQL full-text search
tsvector = func.to_tsvector("english", Chunk.content)
tsquery = func.plainto_tsquery("english", query_text)
@@ -118,6 +142,12 @@ class ChucksHybridSearchRetriever:
# Execute the query
result = await self.db_session.execute(query)
chunks = result.scalars().all()
+ perf.info(
+ "[chunk_search] full_text_search in %.3fs results=%d space=%d",
+ time.perf_counter() - t0,
+ len(chunks),
+ search_space_id,
+ )
return chunks
@@ -129,6 +159,7 @@ class ChucksHybridSearchRetriever:
document_type: str | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
+ query_embedding: list | None = None,
) -> list:
"""
Hybrid search that returns **documents** (not individual chunks).
@@ -143,6 +174,7 @@ class ChucksHybridSearchRetriever:
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
start_date: Optional start date for filtering documents by updated_at
end_date: Optional end date for filtering documents by updated_at
+ query_embedding: Pre-computed embedding vector. If None, will be computed here.
Returns:
List of dictionaries containing document data and relevance scores. Each dict contains:
@@ -157,9 +189,17 @@ class ChucksHybridSearchRetriever:
from app.config import config
from app.db import Chunk, Document, DocumentType
- # Get embedding for the query
- embedding_model = config.embedding_model_instance
- query_embedding = embedding_model.embed(query_text)
+ perf = get_perf_logger()
+ t0 = time.perf_counter()
+
+ if query_embedding is None:
+ embedding_model = config.embedding_model_instance
+ t_embed = time.perf_counter()
+ query_embedding = embedding_model.embed(query_text)
+ perf.debug(
+ "[chunk_search] hybrid_search embedding in %.3fs",
+ time.perf_counter() - t_embed,
+ )
# RRF constants
k = 60
@@ -254,9 +294,17 @@ class ChucksHybridSearchRetriever:
.limit(top_k)
)
- # Execute the query
+ # Execute the RRF query
+ t_rrf = time.perf_counter()
result = await self.db_session.execute(final_query)
chunks_with_scores = result.all()
+ perf.info(
+ "[chunk_search] hybrid_search RRF query in %.3fs results=%d space=%d type=%s",
+ time.perf_counter() - t_rrf,
+ len(chunks_with_scores),
+ search_space_id,
+ document_type,
+ )
# If no results were found, return an empty list
if not chunks_with_scores:
@@ -300,8 +348,9 @@ class ChucksHybridSearchRetriever:
if not doc_ids:
return []
- # Fetch ALL chunks for selected documents in a single query so the final prompt can cite
- # any chunk from those documents.
+ # Fetch chunks for the selected documents. The query still loads every
+ # chunk; the per-document cap is applied afterwards in Python, and chunks
+ # that matched the RRF query are always kept regardless of the cap.
chunk_query = (
select(Chunk)
.options(joinedload(Chunk.document))
@@ -311,7 +360,20 @@ class ChucksHybridSearchRetriever:
.order_by(Chunk.document_id, Chunk.id)
)
chunks_result = await self.db_session.execute(chunk_query)
- all_chunks = chunks_result.scalars().all()
+ raw_chunks = chunks_result.scalars().all()
+
+ matched_chunk_ids: set[int] = {
+ item["chunk_id"] for item in serialized_chunk_results
+ }
+
+ doc_chunk_counts: dict[int, int] = {}
+ all_chunks: list = []
+ for chunk in raw_chunks:
+ did = chunk.document_id
+ count = doc_chunk_counts.get(did, 0)
+ if chunk.id in matched_chunk_ids or count < _MAX_FETCH_CHUNKS_PER_DOC:
+ all_chunks.append(chunk)
+ doc_chunk_counts[did] = count + 1
# Assemble final doc-grouped results in the same order as doc_ids
doc_map: dict[int, dict] = {
@@ -354,4 +416,11 @@ class ChucksHybridSearchRetriever:
)
final_docs.append(entry)
+ perf.info(
+ "[chunk_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
+ time.perf_counter() - t0,
+ len(final_docs),
+ search_space_id,
+ document_type,
+ )
return final_docs
diff --git a/surfsense_backend/app/retriever/documents_hybrid_search.py b/surfsense_backend/app/retriever/documents_hybrid_search.py
index 9ff104ff0..69e97384f 100644
--- a/surfsense_backend/app/retriever/documents_hybrid_search.py
+++ b/surfsense_backend/app/retriever/documents_hybrid_search.py
@@ -1,5 +1,10 @@
+import time
from datetime import datetime
+from app.utils.perf import get_perf_logger
+
+_MAX_FETCH_CHUNKS_PER_DOC = 30
+
class DocumentHybridSearchRetriever:
def __init__(self, db_session):
@@ -38,6 +43,9 @@ class DocumentHybridSearchRetriever:
from app.config import config
from app.db import Document
+ perf = get_perf_logger()
+ t0 = time.perf_counter()
+
# Get embedding for the query
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
@@ -63,6 +71,12 @@ class DocumentHybridSearchRetriever:
# Execute the query
result = await self.db_session.execute(query)
documents = result.scalars().all()
+ perf.info(
+ "[doc_search] vector_search in %.3fs results=%d space=%d",
+ time.perf_counter() - t0,
+ len(documents),
+ search_space_id,
+ )
return documents
@@ -92,6 +106,9 @@ class DocumentHybridSearchRetriever:
from app.db import Document
+ perf = get_perf_logger()
+ t0 = time.perf_counter()
+
# Create tsvector and tsquery for PostgreSQL full-text search
tsvector = func.to_tsvector("english", Document.content)
tsquery = func.plainto_tsquery("english", query_text)
@@ -118,6 +135,12 @@ class DocumentHybridSearchRetriever:
# Execute the query
result = await self.db_session.execute(query)
documents = result.scalars().all()
+ perf.info(
+ "[doc_search] full_text_search in %.3fs results=%d space=%d",
+ time.perf_counter() - t0,
+ len(documents),
+ search_space_id,
+ )
return documents
@@ -129,6 +152,7 @@ class DocumentHybridSearchRetriever:
document_type: str | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
+ query_embedding: list | None = None,
) -> list:
"""
Hybrid search that returns **documents** (not individual chunks).
@@ -143,7 +167,7 @@ class DocumentHybridSearchRetriever:
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
start_date: Optional start date for filtering documents by updated_at
end_date: Optional end date for filtering documents by updated_at
-
+ query_embedding: Pre-computed embedding vector. If None, will be computed here.
"""
from sqlalchemy import func, select, text
from sqlalchemy.orm import joinedload
@@ -151,9 +175,12 @@ class DocumentHybridSearchRetriever:
from app.config import config
from app.db import Chunk, Document, DocumentType
- # Get embedding for the query
- embedding_model = config.embedding_model_instance
- query_embedding = embedding_model.embed(query_text)
+ perf = get_perf_logger()
+ t0 = time.perf_counter()
+
+ if query_embedding is None:
+ embedding_model = config.embedding_model_instance
+ query_embedding = embedding_model.embed(query_text)
# RRF constants
k = 60
@@ -254,7 +281,8 @@ class DocumentHybridSearchRetriever:
# Collect document IDs for chunk fetching
doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores]
- # Fetch ALL chunks for these documents in a single query
+ # Fetch chunks for these documents. All rows are still loaded; the
+ # per-document cap below only limits how many are kept in the results.
chunks_query = (
select(Chunk)
.options(joinedload(Chunk.document))
@@ -262,7 +290,16 @@ class DocumentHybridSearchRetriever:
.order_by(Chunk.document_id, Chunk.id)
)
chunks_result = await self.db_session.execute(chunks_query)
- chunks = chunks_result.scalars().all()
+ raw_chunks = chunks_result.scalars().all()
+
+ doc_chunk_counts: dict[int, int] = {}
+ chunks: list = []
+ for chunk in raw_chunks:
+ did = chunk.document_id
+ count = doc_chunk_counts.get(did, 0)
+ if count < _MAX_FETCH_CHUNKS_PER_DOC:
+ chunks.append(chunk)
+ doc_chunk_counts[did] = count + 1
# Assemble doc-grouped results
doc_map: dict[int, dict] = {
@@ -303,4 +340,11 @@ class DocumentHybridSearchRetriever:
)
final_docs.append(entry)
+ perf.info(
+ "[doc_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
+ time.perf_counter() - t0,
+ len(final_docs),
+ search_space_id,
+ document_type,
+ )
return final_docs
diff --git a/surfsense_backend/app/routes/chat_comments_routes.py b/surfsense_backend/app/routes/chat_comments_routes.py
index 1c21c0f4a..f5a8fd0af 100644
--- a/surfsense_backend/app/routes/chat_comments_routes.py
+++ b/surfsense_backend/app/routes/chat_comments_routes.py
@@ -7,6 +7,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, get_async_session
from app.schemas.chat_comments import (
+ CommentBatchRequest,
+ CommentBatchResponse,
CommentCreateRequest,
CommentListResponse,
CommentReplyResponse,
@@ -19,6 +21,7 @@ from app.services.chat_comments_service import (
create_reply,
delete_comment,
get_comments_for_message,
+ get_comments_for_messages_batch,
get_user_mentions,
update_comment,
)
@@ -27,6 +30,16 @@ from app.users import current_active_user
router = APIRouter()
+@router.post("/messages/comments/batch", response_model=CommentBatchResponse)
+async def batch_list_comments(
+ request: CommentBatchRequest,
+ session: AsyncSession = Depends(get_async_session),
+ user: User = Depends(current_active_user),
+):
+ """Batch-fetch comments for multiple messages in one request."""
+ return await get_comments_for_messages_batch(session, request.message_ids, user)
+
+
@router.get("/messages/{message_id}/comments", response_model=CommentListResponse)
async def list_comments(
message_id: int,
diff --git a/surfsense_backend/app/routes/documents_routes.py b/surfsense_backend/app/routes/documents_routes.py
index 1ce5082ca..865fdf7b3 100644
--- a/surfsense_backend/app/routes/documents_routes.py
+++ b/surfsense_backend/app/routes/documents_routes.py
@@ -133,6 +133,8 @@ async def create_documents_file_upload(
Requires DOCUMENTS_CREATE permission.
"""
+ import os
+ import tempfile
from datetime import datetime
from app.db import DocumentStatus
@@ -143,7 +145,6 @@ async def create_documents_file_upload(
from app.utils.document_converters import generate_unique_identifier_hash
try:
- # Check permission
await check_permission(
session,
user,
@@ -179,69 +180,64 @@ async def create_documents_file_upload(
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
)
- created_documents: list[Document] = []
- files_to_process: list[
- tuple[Document, str, str]
- ] = [] # (document, temp_path, filename)
- skipped_duplicates = 0
- duplicate_document_ids: list[int] = []
- actual_total_size = 0
+ # ===== Read all files concurrently to avoid blocking the event loop =====
+ async def _read_and_save(file: UploadFile) -> tuple[str, str, int]:
+ """Read the upload into memory, then write it to a temp file in a worker thread."""
+ content = await file.read()
+ file_size = len(content)
+ filename = file.filename or "unknown"
- # ===== PHASE 1: Create pending documents for all files =====
- # This makes ALL documents visible in the UI immediately with pending status
- for file in files:
- try:
- import os
- import tempfile
-
- # Save file to temp location
- with tempfile.NamedTemporaryFile(
- delete=False, suffix=os.path.splitext(file.filename or "")[1]
- ) as temp_file:
- temp_path = temp_file.name
-
- content = await file.read()
- file_size = len(content)
-
- if file_size > MAX_FILE_SIZE_BYTES:
- os.unlink(temp_path)
- raise HTTPException(
- status_code=413,
- detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) "
- f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
- )
-
- actual_total_size += file_size
- if actual_total_size > MAX_TOTAL_SIZE_BYTES:
- os.unlink(temp_path)
- raise HTTPException(
- status_code=413,
- detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) "
- f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
- )
-
- with open(temp_path, "wb") as f:
- f.write(content)
-
- # Generate unique identifier for deduplication check
- unique_identifier_hash = generate_unique_identifier_hash(
- DocumentType.FILE, file.filename or "unknown", search_space_id
+ if file_size > MAX_FILE_SIZE_BYTES:
+ raise HTTPException(
+ status_code=413,
+ detail=f"File '{filename}' ({file_size / (1024 * 1024):.1f} MB) "
+ f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
+ )
+
+ def _write_temp() -> str:
+ with tempfile.NamedTemporaryFile(
+ delete=False, suffix=os.path.splitext(filename)[1]
+ ) as tmp:
+ tmp.write(content)
+ return tmp.name
+
+ temp_path = await asyncio.to_thread(_write_temp)
+ return temp_path, filename, file_size
+
+ saved_files = await asyncio.gather(*(_read_and_save(f) for f in files))
+
+ actual_total_size = sum(size for _, _, size in saved_files)
+ if actual_total_size > MAX_TOTAL_SIZE_BYTES:
+ for temp_path, _, _ in saved_files:
+ os.unlink(temp_path)
+ raise HTTPException(
+ status_code=413,
+ detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) "
+ f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
+ )
+
+ # ===== PHASE 1: Create pending documents for all files =====
+ created_documents: list[Document] = []
+ files_to_process: list[tuple[Document, str, str]] = []
+ skipped_duplicates = 0
+ duplicate_document_ids: list[int] = []
+
+ for temp_path, filename, file_size in saved_files:
+ try:
+ unique_identifier_hash = generate_unique_identifier_hash(
+ DocumentType.FILE, filename, search_space_id
)
- # Check if document already exists (by unique identifier)
existing = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
if existing:
if DocumentStatus.is_state(existing.status, DocumentStatus.READY):
- # True duplicate — content already indexed, skip
os.unlink(temp_path)
skipped_duplicates += 1
duplicate_document_ids.append(existing.id)
continue
- # Existing document is stuck (failed/pending/processing)
- # Reset it to pending and re-dispatch for processing
existing.status = DocumentStatus.pending()
existing.content = "Processing..."
existing.document_metadata = {
@@ -251,50 +247,45 @@ async def create_documents_file_upload(
}
existing.updated_at = get_current_timestamp()
created_documents.append(existing)
- files_to_process.append(
- (existing, temp_path, file.filename or "unknown")
- )
+ files_to_process.append((existing, temp_path, filename))
continue
- # Create pending document (visible immediately in UI via ElectricSQL)
document = Document(
search_space_id=search_space_id,
- title=file.filename or "Uploaded File",
+ title=filename if filename != "unknown" else "Uploaded File",
document_type=DocumentType.FILE,
document_metadata={
- "FILE_NAME": file.filename,
+ "FILE_NAME": filename,
"file_size": file_size,
"upload_time": datetime.now().isoformat(),
},
- content="Processing...", # Placeholder until processed
- content_hash=unique_identifier_hash, # Temporary, updated when ready
+ content="Processing...",
+ content_hash=unique_identifier_hash,
unique_identifier_hash=unique_identifier_hash,
embedding=None,
- status=DocumentStatus.pending(), # Shows "pending" in UI
+ status=DocumentStatus.pending(),
updated_at=get_current_timestamp(),
created_by_id=str(user.id),
)
session.add(document)
created_documents.append(document)
- files_to_process.append(
- (document, temp_path, file.filename or "unknown")
- )
+ files_to_process.append((document, temp_path, filename))
+ except HTTPException:
+ raise
except Exception as e:
+ os.unlink(temp_path)
raise HTTPException(
status_code=422,
- detail=f"Failed to process file {file.filename}: {e!s}",
+ detail=f"Failed to process file {filename}: {e!s}",
) from e
- # Commit all pending documents - they appear in UI immediately via ElectricSQL
if created_documents:
await session.commit()
- # Refresh to get generated IDs
for doc in created_documents:
await session.refresh(doc)
# ===== PHASE 2: Dispatch tasks for each file =====
- # Each task will update document status: pending → processing → ready/failed
for document, temp_path, filename in files_to_process:
await dispatcher.dispatch_file_processing(
document_id=document.id,
diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py
index c997cba68..e0d78696f 100644
--- a/surfsense_backend/app/routes/new_chat_routes.py
+++ b/surfsense_backend/app/routes/new_chat_routes.py
@@ -32,6 +32,7 @@ from app.db import (
SearchSpace,
User,
get_async_session,
+ shielded_async_session,
)
from app.schemas.new_chat import (
NewChatMessageAppend,
@@ -1092,13 +1093,18 @@ async def handle_new_chat(
# on searchspaces/documents for the entire duration of the stream.
# expire_on_commit=False keeps loaded ORM attrs usable.
await session.commit()
+ # Close the dependency session now so its connection returns to
+ # the pool before streaming begins. Without this, Starlette's
+ # BaseHTTPMiddleware cancels the scope on client disconnect and
+ # the dependency generator's __aexit__ never runs, orphaning the
+ # connection (the "Exception terminating connection" errors).
+ await session.close()
return StreamingResponse(
stream_new_chat(
user_query=request.user_query,
search_space_id=request.search_space_id,
chat_id=request.chat_id,
- session=session,
user_id=str(user.id),
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
@@ -1323,6 +1329,7 @@ async def regenerate_response(
# on searchspaces/documents for the entire duration of the stream.
# expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
await session.commit()
+ await session.close()
# Create a wrapper generator that deletes messages only AFTER streaming succeeds
# This prevents data loss if streaming fails (network error, LLM error, etc.)
@@ -1333,7 +1340,6 @@ async def regenerate_response(
user_query=user_query_to_use,
search_space_id=request.search_space_id,
chat_id=thread_id,
- session=session,
user_id=str(user.id),
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
@@ -1344,29 +1350,35 @@ async def regenerate_response(
current_user_display_name=user.display_name or "A team member",
):
yield chunk
- # If we get here, streaming completed successfully
streaming_completed = True
finally:
- # Only delete old messages if streaming completed successfully
- # This ensures we don't lose data on streaming failures
- if streaming_completed and messages_to_delete:
+ # Only delete old messages if streaming completed successfully.
+ # Uses a fresh session since stream_new_chat manages its own.
+ if streaming_completed and message_ids_to_delete:
try:
- for msg in messages_to_delete:
- await session.delete(msg)
- await session.commit()
+ async with shielded_async_session() as cleanup_session:
+ for msg_id in message_ids_to_delete:
+ _res = await cleanup_session.execute(
+ select(NewChatMessage).filter(
+ NewChatMessage.id == msg_id
+ )
+ )
+ _msg = _res.scalars().first()
+ if _msg:
+ await cleanup_session.delete(_msg)
+ await cleanup_session.commit()
- # Delete any public snapshots that contain the modified messages
- from app.services.public_chat_service import (
- delete_affected_snapshots,
- )
+ from app.services.public_chat_service import (
+ delete_affected_snapshots,
+ )
- await delete_affected_snapshots(
- session, thread_id, message_ids_to_delete
- )
+ await delete_affected_snapshots(
+ cleanup_session, thread_id, message_ids_to_delete
+ )
except Exception as cleanup_error:
- # Log but don't fail - the new messages are already streamed
- print(
- f"[regenerate] Warning: Failed to delete old messages: {cleanup_error}"
+ _logger.warning(
+ "[regenerate] Failed to delete old messages: %s",
+ cleanup_error,
)
# Return streaming response with checkpoint_id for rewinding
@@ -1440,13 +1452,13 @@ async def resume_chat(
# Release the read-transaction so we don't hold ACCESS SHARE locks
# on searchspaces/documents for the entire duration of the stream.
await session.commit()
+ await session.close()
return StreamingResponse(
stream_resume_chat(
chat_id=thread_id,
search_space_id=request.search_space_id,
decisions=decisions,
- session=session,
user_id=str(user.id),
llm_config_id=llm_config_id,
thread_visibility=thread.visibility,
diff --git a/surfsense_backend/app/schemas/chat_comments.py b/surfsense_backend/app/schemas/chat_comments.py
index b87ee58a4..984e8b812 100644
--- a/surfsense_backend/app/schemas/chat_comments.py
+++ b/surfsense_backend/app/schemas/chat_comments.py
@@ -87,6 +87,18 @@ class CommentListResponse(BaseModel):
total_count: int
+class CommentBatchRequest(BaseModel):
+ """Request for batch-fetching comments for multiple messages."""
+
+ message_ids: list[int] = Field(..., min_length=1, max_length=200)
+
+
+class CommentBatchResponse(BaseModel):
+ """Batch response keyed by message_id."""
+
+ comments_by_message: dict[int, CommentListResponse]
+
+
# =============================================================================
# Mention Schemas
# =============================================================================
diff --git a/surfsense_backend/app/services/chat_comments_service.py b/surfsense_backend/app/services/chat_comments_service.py
index c9ca920f6..c2bb65aee 100644
--- a/surfsense_backend/app/services/chat_comments_service.py
+++ b/surfsense_backend/app/services/chat_comments_service.py
@@ -22,6 +22,7 @@ from app.db import (
)
from app.schemas.chat_comments import (
AuthorResponse,
+ CommentBatchResponse,
CommentListResponse,
CommentReplyResponse,
CommentResponse,
@@ -264,6 +265,146 @@ async def get_comments_for_message(
)
+async def get_comments_for_messages_batch(
+ session: AsyncSession,
+ message_ids: list[int],
+ user: User,
+) -> CommentBatchResponse:
+ """
+ Batch-fetch comments for multiple messages without a per-message query.
+
+ Checks that the user can read comments in every search space the found
+ messages belong to (IDs with no matching message get an empty list),
+ then loads top-level comments with eager-loaded authors and replies.
+ """
+ if not message_ids:
+ return CommentBatchResponse(comments_by_message={})
+
+ unique_ids = list(set(message_ids))
+
+ result = await session.execute(
+ select(NewChatMessage)
+ .options(selectinload(NewChatMessage.thread))
+ .filter(NewChatMessage.id.in_(unique_ids))
+ )
+ messages = result.scalars().all()
+ msg_map = {m.id: m for m in messages}
+
+ search_space_ids = {m.thread.search_space_id for m in messages}
+ permissions_cache: dict[int, set] = {}
+ for ss_id in search_space_ids:
+ await check_permission(
+ session,
+ user,
+ ss_id,
+ Permission.COMMENTS_READ.value,
+ "You don't have permission to read comments in this search space",
+ )
+ permissions_cache[ss_id] = await get_user_permissions(session, user.id, ss_id)
+
+ result = await session.execute(
+ select(ChatComment)
+ .options(
+ selectinload(ChatComment.author),
+ selectinload(ChatComment.replies).selectinload(ChatComment.author),
+ )
+ .filter(
+ ChatComment.message_id.in_(unique_ids),
+ ChatComment.parent_id.is_(None),
+ )
+ .order_by(ChatComment.created_at)
+ )
+ top_level_comments = result.scalars().all()
+
+ all_mentioned_uuids: set[UUID] = set()
+ for comment in top_level_comments:
+ all_mentioned_uuids.update(parse_mentions(comment.content))
+ for reply in comment.replies:
+ all_mentioned_uuids.update(parse_mentions(reply.content))
+
+ user_names = await get_user_names_for_mentions(session, all_mentioned_uuids)
+
+ comments_by_msg: dict[int, list[ChatComment]] = {mid: [] for mid in unique_ids}
+ for comment in top_level_comments:
+ comments_by_msg.setdefault(comment.message_id, []).append(comment)
+
+ comments_by_message: dict[int, CommentListResponse] = {}
+ for mid in unique_ids:
+ msg = msg_map.get(mid)
+ if msg is None:
+ comments_by_message[mid] = CommentListResponse(comments=[], total_count=0)
+ continue
+
+ ss_id = msg.thread.search_space_id
+ user_perms = permissions_cache.get(ss_id, set())
+ can_delete_any = has_permission(user_perms, Permission.COMMENTS_DELETE.value)
+
+ comment_responses = []
+ for comment in comments_by_msg.get(mid, []):
+ author = None
+ if comment.author:
+ author = AuthorResponse(
+ id=comment.author.id,
+ display_name=comment.author.display_name,
+ avatar_url=comment.author.avatar_url,
+ email=comment.author.email,
+ )
+
+ replies = []
+ for reply in sorted(comment.replies, key=lambda r: r.created_at):
+ reply_author = None
+ if reply.author:
+ reply_author = AuthorResponse(
+ id=reply.author.id,
+ display_name=reply.author.display_name,
+ avatar_url=reply.author.avatar_url,
+ email=reply.author.email,
+ )
+ is_reply_author = (
+ reply.author_id == user.id if reply.author_id else False
+ )
+ replies.append(
+ CommentReplyResponse(
+ id=reply.id,
+ content=reply.content,
+ content_rendered=render_mentions(reply.content, user_names),
+ author=reply_author,
+ created_at=reply.created_at,
+ updated_at=reply.updated_at,
+ is_edited=reply.updated_at > reply.created_at,
+ can_edit=is_reply_author,
+ can_delete=is_reply_author or can_delete_any,
+ )
+ )
+
+ is_comment_author = (
+ comment.author_id == user.id if comment.author_id else False
+ )
+ comment_responses.append(
+ CommentResponse(
+ id=comment.id,
+ message_id=comment.message_id,
+ content=comment.content,
+ content_rendered=render_mentions(comment.content, user_names),
+ author=author,
+ created_at=comment.created_at,
+ updated_at=comment.updated_at,
+ is_edited=comment.updated_at > comment.created_at,
+ can_edit=is_comment_author,
+ can_delete=is_comment_author or can_delete_any,
+ reply_count=len(replies),
+ replies=replies,
+ )
+ )
+
+ comments_by_message[mid] = CommentListResponse(
+ comments=comment_responses,
+ total_count=len(comment_responses),
+ )
+
+ return CommentBatchResponse(comments_by_message=comments_by_message)
+
+
async def create_comment(
session: AsyncSession,
message_id: int,
diff --git a/surfsense_backend/app/services/connector_service.py b/surfsense_backend/app/services/connector_service.py
index 3bd9a4421..0aa48eccd 100644
--- a/surfsense_backend/app/services/connector_service.py
+++ b/surfsense_backend/app/services/connector_service.py
@@ -1,4 +1,5 @@
import asyncio
+import time
from datetime import datetime
from typing import Any
from urllib.parse import urljoin
@@ -15,9 +16,11 @@ from app.db import (
Document,
SearchSourceConnector,
SearchSourceConnectorType,
+ async_session_maker,
)
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.retriever.documents_hybrid_search import DocumentHybridSearchRetriever
+from app.utils.perf import get_perf_logger
class ConnectorService:
@@ -221,6 +224,7 @@ class ConnectorService:
top_k: int = 20,
start_date: datetime | None = None,
end_date: datetime | None = None,
+ query_embedding: list[float] | None = None,
) -> list[dict[str, Any]]:
"""
Perform combined search using both chunk-based and document-based hybrid search,
@@ -246,34 +250,60 @@ class ConnectorService:
Returns:
List of combined and deduplicated document results
"""
+ from app.config import config
+
+ perf = get_perf_logger()
+ t0 = time.perf_counter()
+
# RRF constant
k = 60
# Get more results from each retriever for better fusion
retriever_top_k = top_k * 2
- # IMPORTANT:
- # These retrievers share the same AsyncSession. AsyncSession does not permit
- # concurrent awaits that require DB IO on the same session/connection.
- # Running these in parallel can raise:
- # "This session is provisioning a new connection; concurrent operations are not permitted"
- #
- # So we run them sequentially.
- chunk_results = await self.chunk_retriever.hybrid_search(
- query_text=query_text,
- top_k=retriever_top_k,
- search_space_id=search_space_id,
- document_type=document_type,
- start_date=start_date,
- end_date=end_date,
+ # Reuse caller-provided embedding or compute once for both retrievers.
+ if query_embedding is None:
+ t_embed = time.perf_counter()
+ query_embedding = config.embedding_model_instance.embed(query_text)
+ perf.info(
+ "[connector_svc] _combined_rrf embedding in %.3fs type=%s",
+ time.perf_counter() - t_embed,
+ document_type,
+ )
+
+ search_kwargs = {
+ "query_text": query_text,
+ "top_k": retriever_top_k,
+ "search_space_id": search_space_id,
+ "document_type": document_type,
+ "start_date": start_date,
+ "end_date": end_date,
+ "query_embedding": query_embedding,
+ }
+
+ # Run chunk and document retrievers in parallel using separate DB sessions
+ # so they don't contend on a shared AsyncSession connection.
+ async def _run_chunk_search() -> list[dict[str, Any]]:
+ async with async_session_maker() as session:
+ retriever = ChucksHybridSearchRetriever(session)
+ return await retriever.hybrid_search(**search_kwargs)
+
+ async def _run_doc_search() -> list[dict[str, Any]]:
+ async with async_session_maker() as session:
+ retriever = DocumentHybridSearchRetriever(session)
+ return await retriever.hybrid_search(**search_kwargs)
+
+ t_parallel = time.perf_counter()
+ chunk_results, doc_results = await asyncio.gather(
+ _run_chunk_search(), _run_doc_search()
)
- doc_results = await self.document_retriever.hybrid_search(
- query_text=query_text,
- top_k=retriever_top_k,
- search_space_id=search_space_id,
- document_type=document_type,
- start_date=start_date,
- end_date=end_date,
+ perf.info(
+ "[connector_svc] _combined_rrf parallel retrievers in %.3fs "
+ "chunk_results=%d doc_results=%d type=%s",
+ time.perf_counter() - t_parallel,
+ len(chunk_results),
+ len(doc_results),
+ document_type,
)
# Helper to extract document_id from our doc-grouped result
@@ -335,6 +365,13 @@ class ConnectorService:
result["chunks"] = doc_data[did]["chunks"]
combined_results.append(result)
+ perf.info(
+ "[connector_svc] _combined_rrf_search TOTAL in %.3fs results=%d type=%s space=%d",
+ time.perf_counter() - t0,
+ len(combined_results),
+ document_type,
+ search_space_id,
+ )
return combined_results
def _get_doc_url(self, metadata: dict[str, Any]) -> str:
diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py
index 2e517f0ba..e8c0d2d47 100644
--- a/surfsense_backend/app/services/llm_router_service.py
+++ b/surfsense_backend/app/services/llm_router_service.py
@@ -13,8 +13,10 @@ synchronous ChatLiteLLM-like interface and async methods.
import logging
import re
+import time
from typing import Any
+import litellm
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.exceptions import ContextOverflowError
from langchain_core.language_models import BaseChatModel
@@ -26,6 +28,11 @@ from litellm.exceptions import (
ContextWindowExceededError,
)
+from app.utils.perf import get_perf_logger
+
+litellm.json_logs = False
+litellm.store_audit_logs = False
+
logger = logging.getLogger(__name__)
_CONTEXT_OVERFLOW_PATTERNS = re.compile(
@@ -152,26 +159,95 @@ class LLMRouterService:
# Merge with provided settings
final_settings = {**default_settings, **instance._router_settings}
+ # Build an "auto-large" fallback group from deployments whose context
+ # window is larger than the smallest one in the pool. This lets the router
+ # automatically fall back to a bigger-context model when a smaller one
+ # (e.g. gpt-4o, 128K) hits ContextWindowExceededError.
+ full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list)
+
try:
- instance._router = Router(
- model_list=model_list,
- routing_strategy=final_settings.get(
+ router_kwargs: dict[str, Any] = {
+ "model_list": full_model_list,
+ "routing_strategy": final_settings.get(
"routing_strategy", "usage-based-routing"
),
- num_retries=final_settings.get("num_retries", 3),
- allowed_fails=final_settings.get("allowed_fails", 3),
- cooldown_time=final_settings.get("cooldown_time", 60),
- set_verbose=False, # Disable verbose logging in production
- )
+ "num_retries": final_settings.get("num_retries", 3),
+ "allowed_fails": final_settings.get("allowed_fails", 3),
+ "cooldown_time": final_settings.get("cooldown_time", 60),
+ "set_verbose": False,
+ }
+ if ctx_fallbacks:
+ router_kwargs["context_window_fallbacks"] = ctx_fallbacks
+
+ instance._router = Router(**router_kwargs)
instance._initialized = True
logger.info(
- f"LLM Router initialized with {len(model_list)} deployments, "
- f"strategy: {final_settings.get('routing_strategy')}"
+ "LLM Router initialized with %d deployments, "
+ "strategy: %s, context_window_fallbacks: %s",
+ len(model_list),
+ final_settings.get("routing_strategy"),
+ ctx_fallbacks or "none",
)
except Exception as e:
logger.error(f"Failed to initialize LLM Router: {e}")
instance._router = None
+ @classmethod
+ def _build_context_fallback_groups(
+ cls, model_list: list[dict]
+ ) -> tuple[list[dict], list[dict[str, list[str]]] | None]:
+ """Create an ``auto-large`` model group for context-window fallbacks.
+
+ Uses ``litellm.get_model_info`` to discover the context window of each
+ deployment. Deployments whose ``max_input_tokens`` exceeds the smallest
+ window are duplicated into an ``auto-large`` group. The returned
+ fallback config tells the Router: on ``ContextWindowExceededError`` for
+ ``auto``, retry with ``auto-large``.
+
+ Returns:
+ (full_model_list, context_window_fallbacks) — ``full_model_list``
+ contains the original entries plus any ``auto-large`` duplicates.
+ ``context_window_fallbacks`` is ``None`` when no context size is known
+ or every known deployment has the same context size (no useful fallback).
+ """
+ from litellm import get_model_info
+
+ ctx_map: dict[str, int] = {}
+ for dep in model_list:
+ params = dep.get("litellm_params", {})
+ base_model = params.get("base_model") or params.get("model", "")
+ try:
+ info = get_model_info(base_model)
+ ctx = info.get("max_input_tokens")
+ if isinstance(ctx, int) and ctx > 0:
+ ctx_map[base_model] = ctx
+ except Exception:
+ continue
+
+ if not ctx_map:
+ return model_list, None
+
+ min_ctx = min(ctx_map.values())
+
+ large_deployments: list[dict] = []
+ for dep in model_list:
+ params = dep.get("litellm_params", {})
+ base_model = params.get("base_model") or params.get("model", "")
+ if ctx_map.get(base_model, 0) > min_ctx:
+ dup = {**dep, "model_name": "auto-large"}
+ large_deployments.append(dup)
+
+ if not large_deployments:
+ return model_list, None
+
+ logger.info(
+ "Context-window fallback: %d large-context deployments "
+ "(min_ctx=%d) added to 'auto-large' group",
+ len(large_deployments),
+ min_ctx,
+ )
+ return model_list + large_deployments, [{"auto": ["auto-large"]}]
+
@classmethod
def _config_to_deployment(cls, config: dict) -> dict | None:
"""
@@ -247,6 +323,48 @@ class LLMRouterService:
return len(instance._model_list)
+_cached_context_profile: dict | None = None
+_cached_context_profile_computed: bool = False
+
+# Cached singleton instances keyed by the ``streaming`` flag to avoid re-creating on every call
+_router_instance_cache: dict[bool, "ChatLiteLLMRouter"] = {}
+
+
+def _get_cached_context_profile(router: Router) -> dict | None:
+ """Compute and cache the min context profile across all router deployments.
+
+ Computed once, on the first ChatLiteLLMRouter creation; subsequent calls return
+ the cached value. This avoids calling litellm.get_model_info() for every
+ deployment on every request.
+ """
+ global _cached_context_profile, _cached_context_profile_computed
+ if _cached_context_profile_computed:
+ return _cached_context_profile
+
+ from litellm import get_model_info
+
+ min_ctx: int | None = None
+ for deployment in router.model_list:
+ params = deployment.get("litellm_params", {})
+ base_model = params.get("base_model") or params.get("model", "")
+ try:
+ info = get_model_info(base_model)
+ ctx = info.get("max_input_tokens")
+ if isinstance(ctx, int) and ctx > 0 and (min_ctx is None or ctx < min_ctx):
+ min_ctx = ctx
+ except Exception:
+ continue
+
+ if min_ctx is not None:
+ logger.info("ChatLiteLLMRouter profile: max_input_tokens=%d", min_ctx)
+ _cached_context_profile = {"max_input_tokens": min_ctx}
+ else:
+ _cached_context_profile = None
+
+ _cached_context_profile_computed = True
+ return _cached_context_profile
+
+
class ChatLiteLLMRouter(BaseChatModel):
"""
A LangChain-compatible chat model that uses LiteLLM Router for load balancing.
@@ -257,6 +375,10 @@ class ChatLiteLLMRouter(BaseChatModel):
Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context
window across all router deployments so that deepagents
SummarizationMiddleware can use fraction-based triggers.
+
+ **Singleton-ish**: Prefer ``get_auto_mode_llm()`` over calling
+ ``ChatLiteLLMRouter()`` directly — it caches instances without bound tools
+ per streaming flag, avoiding per-request re-initialization and memory growth.
"""
# Use model_config for Pydantic v2 compatibility
@@ -278,14 +400,6 @@ class ChatLiteLLMRouter(BaseChatModel):
tool_choice: str | dict | None = None,
**kwargs,
):
- """
- Initialize the ChatLiteLLMRouter.
-
- Args:
- router: LiteLLM Router instance. If None, uses the global singleton.
- bound_tools: Pre-bound tools for tool calling
- tool_choice: Tool choice configuration
- """
try:
super().__init__(**kwargs)
resolved_router = router or LLMRouterService.get_router()
@@ -297,51 +411,20 @@ class ChatLiteLLMRouter(BaseChatModel):
"LLM Router not initialized. Call LLMRouterService.initialize() first."
)
- # Set profile so deepagents SummarizationMiddleware gets fraction-based triggers
- computed_profile = self._compute_min_context_profile()
+ computed_profile = _get_cached_context_profile(self._router)
if computed_profile is not None:
object.__setattr__(self, "profile", computed_profile)
- logger.info(
- f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models"
+ logger.debug(
+ "ChatLiteLLMRouter ready (models=%d, streaming=%s, has_tools=%s)",
+ LLMRouterService.get_model_count(),
+ self.streaming,
+ bound_tools is not None,
)
except Exception as e:
logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
raise
- def _compute_min_context_profile(self) -> dict | None:
- """Derive a profile dict with max_input_tokens from router deployments.
-
- Uses litellm.get_model_info to look up each deployment's context window
- and picks the *minimum* so that summarization triggers before ANY model
- in the pool overflows.
- """
- from litellm import get_model_info
-
- if not self._router:
- return None
-
- min_ctx: int | None = None
- for deployment in self._router.model_list:
- params = deployment.get("litellm_params", {})
- base_model = params.get("base_model") or params.get("model", "")
- try:
- info = get_model_info(base_model)
- ctx = info.get("max_input_tokens")
- if (
- isinstance(ctx, int)
- and ctx > 0
- and (min_ctx is None or ctx < min_ctx)
- ):
- min_ctx = ctx
- except Exception:
- continue
-
- if min_ctx is not None:
- logger.info(f"ChatLiteLLMRouter profile: max_input_tokens={min_ctx}")
- return {"max_input_tokens": min_ctx}
- return None
-
@property
def _llm_type(self) -> str:
return "litellm-router"
@@ -410,6 +493,10 @@ class ChatLiteLLMRouter(BaseChatModel):
if not self._router:
raise ValueError("Router not initialized")
+ perf = get_perf_logger()
+ t0 = time.perf_counter()
+ msg_count = len(messages)
+
# Convert LangChain messages to OpenAI format
formatted_messages = self._convert_messages(messages)
@@ -428,12 +515,30 @@ class ChatLiteLLMRouter(BaseChatModel):
**call_kwargs,
)
except ContextWindowExceededError as e:
+ perf.warning(
+ "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
+ msg_count,
+ time.perf_counter() - t0,
+ )
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
+ perf.warning(
+ "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
+ msg_count,
+ time.perf_counter() - t0,
+ )
raise ContextOverflowError(str(e)) from e
raise
+ elapsed = time.perf_counter() - t0
+ perf.info(
+ "[llm_router] _generate completed msgs=%d tools=%d in %.3fs",
+ msg_count,
+ len(self._bound_tools) if self._bound_tools else 0,
+ elapsed,
+ )
+
# Convert response to ChatResult with potential tool calls
message = self._convert_response_to_message(response.choices[0].message)
generation = ChatGeneration(message=message)
@@ -453,6 +558,10 @@ class ChatLiteLLMRouter(BaseChatModel):
if not self._router:
raise ValueError("Router not initialized")
+ perf = get_perf_logger()
+ t0 = time.perf_counter()
+ msg_count = len(messages)
+
# Convert LangChain messages to OpenAI format
formatted_messages = self._convert_messages(messages)
@@ -471,12 +580,30 @@ class ChatLiteLLMRouter(BaseChatModel):
**call_kwargs,
)
except ContextWindowExceededError as e:
+ perf.warning(
+ "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
+ msg_count,
+ time.perf_counter() - t0,
+ )
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
+ perf.warning(
+ "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
+ msg_count,
+ time.perf_counter() - t0,
+ )
raise ContextOverflowError(str(e)) from e
raise
+ elapsed = time.perf_counter() - t0
+ perf.info(
+ "[llm_router] _agenerate completed msgs=%d tools=%d in %.3fs",
+ msg_count,
+ len(self._bound_tools) if self._bound_tools else 0,
+ elapsed,
+ )
+
# Convert response to ChatResult with potential tool calls
message = self._convert_response_to_message(response.choices[0].message)
generation = ChatGeneration(message=message)
@@ -541,6 +668,10 @@ class ChatLiteLLMRouter(BaseChatModel):
if not self._router:
raise ValueError("Router not initialized")
+ perf = get_perf_logger()
+ t0 = time.perf_counter()
+ msg_count = len(messages)
+
formatted_messages = self._convert_messages(messages)
# Add tools if bound
@@ -559,20 +690,52 @@ class ChatLiteLLMRouter(BaseChatModel):
**call_kwargs,
)
except ContextWindowExceededError as e:
+ perf.warning(
+ "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
+ msg_count,
+ time.perf_counter() - t0,
+ )
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
+ perf.warning(
+ "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
+ msg_count,
+ time.perf_counter() - t0,
+ )
raise ContextOverflowError(str(e)) from e
raise
- # Yield chunks asynchronously
+ t_first_chunk = time.perf_counter()
+ perf.info(
+ "[llm_router] _astream connection established msgs=%d in %.3fs",
+ msg_count,
+ t_first_chunk - t0,
+ )
+
+ chunk_count = 0
+ first_chunk_logged = False
async for chunk in response:
if hasattr(chunk, "choices") and chunk.choices:
delta = chunk.choices[0].delta
chunk_msg = self._convert_delta_to_chunk(delta)
if chunk_msg:
+ chunk_count += 1
+ if not first_chunk_logged:
+ perf.info(
+ "[llm_router] _astream first chunk in %.3fs (total %.3fs from start)",
+ time.perf_counter() - t_first_chunk,
+ time.perf_counter() - t0,
+ )
+ first_chunk_logged = True
yield ChatGenerationChunk(message=chunk_msg)
+ perf.info(
+ "[llm_router] _astream completed chunks=%d total=%.3fs",
+ chunk_count,
+ time.perf_counter() - t0,
+ )
+
def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]:
"""Convert LangChain messages to OpenAI format."""
from langchain_core.messages import (
@@ -687,19 +850,28 @@ class ChatLiteLLMRouter(BaseChatModel):
return None
-def get_auto_mode_llm() -> ChatLiteLLMRouter | None:
- """
- Get a ChatLiteLLMRouter instance for auto mode.
+def get_auto_mode_llm(
+ *,
+ streaming: bool = True,
+) -> ChatLiteLLMRouter | None:
+ """Return a cached ChatLiteLLMRouter for auto mode.
- Returns:
- ChatLiteLLMRouter instance or None if router not initialized
+ Base (no tools) instances are cached per ``streaming`` flag so we
+ avoid re-constructing them on every request. ``bind_tools()`` still
+ returns a fresh instance because bound tools differ per agent.
"""
if not LLMRouterService.is_initialized():
logger.warning("LLM Router not initialized for auto mode")
return None
+ cached = _router_instance_cache.get(streaming)
+ if cached is not None:
+ return cached
+
try:
- return ChatLiteLLMRouter()
+ instance = ChatLiteLLMRouter(streaming=streaming)
+ _router_instance_cache[streaming] = instance
+ return instance
except Exception as e:
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
return None
diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py
index 4833e62a6..fc28f477f 100644
--- a/surfsense_backend/app/services/llm_service.py
+++ b/surfsense_backend/app/services/llm_service.py
@@ -12,12 +12,20 @@ from app.services.llm_router_service import (
AUTO_MODE_ID,
ChatLiteLLMRouter,
LLMRouterService,
+ get_auto_mode_llm,
is_auto_mode,
)
# Configure litellm to automatically drop unsupported parameters
litellm.drop_params = True
+# Memory controls: prevent unbounded internal accumulation
+litellm.telemetry = False
+litellm.cache = None
+litellm.success_callback = []
+litellm.failure_callback = []
+litellm.input_callback = []
+
logger = logging.getLogger(__name__)
@@ -221,7 +229,7 @@ async def get_search_space_llm_instance(
logger.debug(
f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}"
)
- return ChatLiteLLMRouter(disable_streaming=disable_streaming)
+ return get_auto_mode_llm(streaming=not disable_streaming)
except Exception as e:
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
return None
diff --git a/surfsense_backend/app/tasks/celery_tasks/__init__.py b/surfsense_backend/app/tasks/celery_tasks/__init__.py
index 9abc472fe..5b1f2cd13 100644
--- a/surfsense_backend/app/tasks/celery_tasks/__init__.py
+++ b/surfsense_backend/app/tasks/celery_tasks/__init__.py
@@ -1 +1,28 @@
"""Celery tasks package."""
+
+from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
+from sqlalchemy.pool import NullPool
+
+from app.config import config
+
+_celery_engine = None
+_celery_session_maker = None
+
+
+def get_celery_session_maker() -> async_sessionmaker:
+ """Return a shared async session maker for Celery tasks.
+
+ A single NullPool engine is created per worker process and reused
+ across all task invocations to avoid leaking engine objects.
+ """
+ global _celery_engine, _celery_session_maker
+ if _celery_session_maker is None:
+ _celery_engine = create_async_engine(
+ config.DATABASE_URL,
+ poolclass=NullPool,
+ echo=False,
+ )
+ _celery_session_maker = async_sessionmaker(
+ _celery_engine, expire_on_commit=False
+ )
+ return _celery_session_maker
diff --git a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py
index a35528a93..9d52add9c 100644
--- a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py
@@ -3,11 +3,8 @@
import logging
import traceback
-from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
-from sqlalchemy.pool import NullPool
-
from app.celery_app import celery_app
-from app.config import config
+from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
@@ -42,20 +39,6 @@ def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> N
)
-def get_celery_session_maker():
- """
- Create a new async session maker for Celery tasks.
- This is necessary because Celery tasks run in a new event loop,
- and the default session maker is bound to the main app's event loop.
- """
- engine = create_async_engine(
- config.DATABASE_URL,
- poolclass=NullPool, # Don't use connection pooling for Celery tasks
- echo=False,
- )
- return async_sessionmaker(engine, expire_on_commit=False)
-
-
@celery_app.task(name="index_slack_messages", bind=True)
def index_slack_messages_task(
self,
diff --git a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py
index fd099684d..c2dbe7700 100644
--- a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py
@@ -4,30 +4,18 @@ import logging
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import selectinload
-from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
-from app.config import config
from app.db import Document
from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
+from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
-def get_celery_session_maker():
- """Create async session maker for Celery tasks."""
- engine = create_async_engine(
- config.DATABASE_URL,
- poolclass=NullPool,
- echo=False,
- )
- return async_sessionmaker(engine, expire_on_commit=False)
-
-
@celery_app.task(name="reindex_document", bind=True)
def reindex_document_task(self, document_id: int, user_id: str):
"""
diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py
index 60cd21f97..dcb791d3b 100644
--- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py
@@ -5,13 +5,11 @@ import logging
import os
from uuid import UUID
-from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
-from sqlalchemy.pool import NullPool
-
from app.celery_app import celery_app
from app.config import config
from app.services.notification_service import NotificationService
from app.services.task_logging_service import TaskLoggingService
+from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.document_processors import (
add_extension_received_document,
add_youtube_video_document,
@@ -91,20 +89,6 @@ async def _run_heartbeat_loop(notification_id: int):
pass # Normal cancellation when task completes
-def get_celery_session_maker():
- """
- Create a new async session maker for Celery tasks.
- This is necessary because Celery tasks run in a new event loop,
- and the default session maker is bound to the main app's event loop.
- """
- engine = create_async_engine(
- config.DATABASE_URL,
- poolclass=NullPool, # Don't use connection pooling for Celery tasks
- echo=False,
- )
- return async_sessionmaker(engine, expire_on_commit=False)
-
-
@celery_app.task(name="process_extension_document", bind=True)
def process_extension_document_task(
self, individual_document_dict, search_space_id: int, user_id: str
diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
index 973e7e750..42378fe5e 100644
--- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
+++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
@@ -5,14 +5,13 @@ import logging
import sys
from sqlalchemy import select
-from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
-from sqlalchemy.pool import NullPool
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app
from app.config import config
from app.db import Podcast, PodcastStatus
+from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
@@ -25,20 +24,6 @@ if sys.platform.startswith("win"):
)
-def get_celery_session_maker():
- """
- Create a new async session maker for Celery tasks.
- This is necessary because Celery tasks run in a new event loop,
- and the default session maker is bound to the main app's event loop.
- """
- engine = create_async_engine(
- config.DATABASE_URL,
- poolclass=NullPool, # Don't use connection pooling for Celery tasks
- echo=False,
- )
- return async_sessionmaker(engine, expire_on_commit=False)
-
-
# =============================================================================
# Content-based podcast generation (for new-chat)
# =============================================================================
diff --git a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py
index 80d271aaa..0ba8bc80a 100644
--- a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py
+++ b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py
@@ -3,28 +3,16 @@
import logging
from datetime import UTC, datetime
-from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.future import select
-from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
-from app.config import config
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
+from app.tasks.celery_tasks import get_celery_session_maker
from app.utils.indexing_locks import is_connector_indexing_locked
logger = logging.getLogger(__name__)
-def get_celery_session_maker():
- """Create async session maker for Celery tasks."""
- engine = create_async_engine(
- config.DATABASE_URL,
- poolclass=NullPool,
- echo=False,
- )
- return async_sessionmaker(engine, expire_on_commit=False)
-
-
@celery_app.task(name="check_periodic_schedules")
def check_periodic_schedules_task():
"""
diff --git a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py
index c2c82dd2c..e05ae9435 100644
--- a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py
+++ b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py
@@ -29,20 +29,17 @@ from datetime import UTC, datetime
import redis
from sqlalchemy import and_, or_, text
-from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.future import select
-from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.db import Document, DocumentStatus, Notification
+from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
-# Redis client for checking heartbeats
_redis_client: redis.Redis | None = None
-# Error messages shown to users when tasks are interrupted
STALE_SYNC_ERROR_MESSAGE = "Sync was interrupted unexpectedly. Please retry."
STALE_PROCESSING_ERROR_MESSAGE = "Syncing was interrupted unexpectedly. Please retry."
@@ -60,16 +57,6 @@ def _get_heartbeat_key(notification_id: int) -> str:
return f"indexing:heartbeat:{notification_id}"
-def get_celery_session_maker():
- """Create async session maker for Celery tasks."""
- engine = create_async_engine(
- config.DATABASE_URL,
- poolclass=NullPool,
- echo=False,
- )
- return async_sessionmaker(engine, expire_on_commit=False)
-
-
@celery_app.task(name="cleanup_stale_indexing_notifications")
def cleanup_stale_indexing_notifications_task():
"""
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
index ddadbc48b..8d09ff387 100644
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -10,6 +10,8 @@ Supports loading LLM configurations from:
"""
import asyncio
+import contextlib
+import gc
import json
import logging
import re
@@ -19,9 +21,9 @@ from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
+import anyio
from langchain_core.messages import HumanMessage
from sqlalchemy import func
-from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
@@ -47,6 +49,7 @@ from app.db import (
SearchSourceConnectorType,
SurfsenseDocsDocument,
async_session_maker,
+ shielded_async_session,
)
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
from app.services.chat_session_state_service import (
@@ -56,14 +59,9 @@ from app.services.chat_session_state_service import (
from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService
from app.utils.content_utils import bootstrap_history_from_db
+from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
-_perf_log = logging.getLogger("surfsense.perf")
-_perf_log.setLevel(logging.DEBUG)
-if not _perf_log.handlers:
- _h = logging.StreamHandler()
- _h.setFormatter(logging.Formatter("%(asctime)s [PERF] %(message)s"))
- _perf_log.addHandler(_h)
- _perf_log.propagate = False
+_perf_log = get_perf_logger()
_background_tasks: set[asyncio.Task] = set()
@@ -1016,7 +1014,6 @@ async def stream_new_chat(
user_query: str,
search_space_id: int,
chat_id: int,
- session: AsyncSession,
user_id: str | None = None,
llm_config_id: int = -1,
mentioned_document_ids: list[int] | None = None,
@@ -1032,11 +1029,13 @@ async def stream_new_chat(
This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
The chat_id is used as LangGraph's thread_id for memory/checkpointing.
+ The function creates and manages its own database session to guarantee proper
+ cleanup even when Starlette's middleware cancels the task on client disconnect.
+
Args:
user_query: The user's query
search_space_id: The search space ID
chat_id: The chat ID (used as LangGraph thread_id for memory)
- session: The database session
user_id: The current user's UUID string (for memory tools and session state)
llm_config_id: The LLM configuration ID (default: -1 for first global config)
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
@@ -1050,7 +1049,9 @@ async def stream_new_chat(
streaming_service = VercelStreamingService()
stream_result = StreamResult()
_t_total = time.perf_counter()
+ log_system_snapshot("stream_new_chat_START")
+ session = async_session_maker()
try:
# Mark AI as responding to this user for live collaboration
if user_id:
@@ -1286,6 +1287,12 @@ async def stream_new_chat(
# short-lived transactions (or use isolated sessions).
await session.commit()
+ # Detach heavy ORM objects (documents with chunks, reports, etc.)
+ # from the session identity map now that we've extracted the data
+ # we need. This prevents them from accumulating in memory for the
+ # entire duration of LLM streaming (which can be several minutes).
+ session.expunge_all()
+
_perf_log.info(
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
time.perf_counter() - _t_total,
@@ -1353,6 +1360,12 @@ async def stream_new_chat(
items=initial_items,
)
+ # These ORM objects (with eagerly-loaded chunks) can be very large.
+ # They're only needed to build context strings already copied into
+ # final_query / langchain_messages — release them before streaming.
+ del mentioned_documents, mentioned_surfsense_docs, recent_reports
+ del langchain_messages, final_query
+
_t_stream_start = time.perf_counter()
_first_event_logged = False
async for sse in _stream_agent_events(
@@ -1382,6 +1395,7 @@ async def stream_new_chat(
time.perf_counter() - _t_stream_start,
chat_id,
)
+ log_system_snapshot("stream_new_chat_END")
if stream_result.is_interrupted:
yield streaming_service.format_finish_step()
@@ -1461,30 +1475,57 @@ async def stream_new_chat(
yield streaming_service.format_done()
finally:
- # Clear AI responding state for live collaboration.
- # The original session may be broken (client disconnect / CancelledError
- # can corrupt the underlying DB connection), so we try a rollback first
- # and fall back to a fresh session if the original is unusable.
- try:
- await session.rollback()
- await clear_ai_responding(session, chat_id)
- except Exception:
+ # Shield the ENTIRE async cleanup from anyio cancel-scope
+ # cancellation. Starlette's BaseHTTPMiddleware uses anyio task
+ # groups; on client disconnect, it cancels the scope with
+ # level-triggered cancellation — every unshielded `await` inside
+ # the cancelled scope raises CancelledError immediately. Without
+ # this shield the very first `await` (session.rollback) would
+ # raise CancelledError, `except Exception` wouldn't catch it
+ # (CancelledError is a BaseException), and the rest of the
+ # finally block — including session.close() — would never run.
+ with anyio.CancelScope(shield=True):
try:
- async with async_session_maker() as fresh_session:
- await clear_ai_responding(fresh_session, chat_id)
+ await session.rollback()
+ await clear_ai_responding(session, chat_id)
except Exception:
- logging.getLogger(__name__).warning(
- "Failed to clear AI responding state for thread %s", chat_id
- )
+ try:
+ async with shielded_async_session() as fresh_session:
+ await clear_ai_responding(fresh_session, chat_id)
+ except Exception:
+ logging.getLogger(__name__).warning(
+ "Failed to clear AI responding state for thread %s", chat_id
+ )
- _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
+ _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
+
+ with contextlib.suppress(Exception):
+ session.expunge_all()
+
+ with contextlib.suppress(Exception):
+ await session.close()
+
+ # Break circular refs held by the agent graph, tools, and LLM
+ # wrappers so the GC can reclaim them in a single pass.
+ agent = llm = connector_service = sandbox_backend = None
+ input_state = stream_result = None
+ session = None
+
+ collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
+ if collected:
+ _perf_log.info(
+ "[stream_new_chat] gc.collect() reclaimed %d objects (chat_id=%s)",
+ collected,
+ chat_id,
+ )
+ trim_native_heap()
+ log_system_snapshot("stream_new_chat_END")
async def stream_resume_chat(
chat_id: int,
search_space_id: int,
decisions: list[dict],
- session: AsyncSession,
user_id: str | None = None,
llm_config_id: int = -1,
thread_visibility: ChatVisibility | None = None,
@@ -1493,6 +1534,7 @@ async def stream_resume_chat(
stream_result = StreamResult()
_t_total = time.perf_counter()
+ session = async_session_maker()
try:
if user_id:
await set_ai_responding(session, chat_id, UUID(user_id))
@@ -1589,6 +1631,7 @@ async def stream_resume_chat(
# Release the transaction before streaming (same rationale as stream_new_chat).
await session.commit()
+ session.expunge_all()
_perf_log.info(
"[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
@@ -1652,16 +1695,37 @@ async def stream_resume_chat(
yield streaming_service.format_done()
finally:
- try:
- await session.rollback()
- await clear_ai_responding(session, chat_id)
- except Exception:
+ with anyio.CancelScope(shield=True):
try:
- async with async_session_maker() as fresh_session:
- await clear_ai_responding(fresh_session, chat_id)
+ await session.rollback()
+ await clear_ai_responding(session, chat_id)
except Exception:
- logging.getLogger(__name__).warning(
- "Failed to clear AI responding state for thread %s", chat_id
- )
+ try:
+ async with shielded_async_session() as fresh_session:
+ await clear_ai_responding(fresh_session, chat_id)
+ except Exception:
+ logging.getLogger(__name__).warning(
+ "Failed to clear AI responding state for thread %s", chat_id
+ )
- _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
+ _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
+
+ with contextlib.suppress(Exception):
+ session.expunge_all()
+
+ with contextlib.suppress(Exception):
+ await session.close()
+
+ agent = llm = connector_service = sandbox_backend = None
+ stream_result = None
+ session = None
+
+ collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
+ if collected:
+ _perf_log.info(
+ "[stream_resume] gc.collect() reclaimed %d objects (chat_id=%s)",
+ collected,
+ chat_id,
+ )
+ trim_native_heap()
+ log_system_snapshot("stream_resume_chat_END")
diff --git a/surfsense_backend/app/utils/perf.py b/surfsense_backend/app/utils/perf.py
new file mode 100644
index 000000000..b2b26897c
--- /dev/null
+++ b/surfsense_backend/app/utils/perf.py
@@ -0,0 +1,174 @@
+"""
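+Centralized performance monitoring for SurfSense backend.
+
+Provides:
+- get_perf_logger(): the shared [PERF] logger used across all modules
+- perf_timer context manager for timing synchronous code blocks
+- perf_async_timer for async code blocks
+- system_snapshot() / log_system_snapshot() for CPU/memory profiling
+- trim_native_heap() for asking glibc to return freed heap pages to the OS
+
+NOTE: RequestPerfMiddleware (per-request timing) is not defined in this file.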
+"""
+
+import gc
+import logging
+import os
+import time
+from contextlib import asynccontextmanager, contextmanager
+from typing import Any
+
+_perf_log: logging.Logger | None = None
+_last_rss_mb: float = 0.0
+
+
+def get_perf_logger() -> logging.Logger:
+    """Return the shared ``surfsense.perf`` logger.
+
+    The logger is set up the first time this function runs; every later call
+    hands back the same object without touching its configuration again.
+    """
+    global _perf_log
+    if _perf_log is not None:
+        return _perf_log
+
+    _perf_log = logging.getLogger("surfsense.perf")
+    _perf_log.setLevel(logging.DEBUG)
+    # Attach the [PERF] stream handler only when the logger has no handler yet.
+    if not _perf_log.handlers:
+        stream_handler = logging.StreamHandler()
+        stream_handler.setFormatter(
+            logging.Formatter("%(asctime)s [PERF] %(message)s")
+        )
+        _perf_log.addHandler(stream_handler)
+        _perf_log.propagate = False
+    return _perf_log
+
+
+@contextmanager
+def perf_timer(label: str, *, extra: dict[str, Any] | None = None):
+    """Log the wall-clock time spent inside a synchronous ``with`` block.
+
+    The timing line is written when the block exits, whether it finished
+    normally or raised. An exception raised by the block still propagates to
+    the caller after the line has been logged.
+
+    Args:
+        label: Text placed at the start of the log line.
+        extra: Optional mapping appended to the line as space-separated
+            ``key=value`` pairs.
+
+    Usage:
+        with perf_timer("[my_func] heavy computation"):
+            ...
+    """
+    log = get_perf_logger()
+    t0 = time.perf_counter()
+    try:
+        yield
+    finally:
+        # Runs on success and on failure alike, so operations that are slow
+        # *and* failing still show up in the [PERF] log.
+        elapsed = time.perf_counter() - t0
+        suffix = ""
+        if extra:
+            suffix = " " + " ".join(f"{k}={v}" for k, v in extra.items())
+        log.info("%s in %.3fs%s", label, elapsed, suffix)
+
+
+@asynccontextmanager
+async def perf_async_timer(label: str, *, extra: dict[str, Any] | None = None):
+    """Log the wall-clock time spent inside an ``async with`` block.
+
+    The timing line is written when the block exits, whether it finished
+    normally or raised (task cancellation included). The exception still
+    propagates to the caller after the line has been logged.
+
+    Args:
+        label: Text placed at the start of the log line.
+        extra: Optional mapping appended to the line as space-separated
+            ``key=value`` pairs.
+
+    Usage:
+        async with perf_async_timer("[search] vector search"):
+            ...
+    """
+    log = get_perf_logger()
+    t0 = time.perf_counter()
+    try:
+        yield
+    finally:
+        # Runs on success and on failure alike, so operations that are slow
+        # *and* failing still show up in the [PERF] log.
+        elapsed = time.perf_counter() - t0
+        suffix = ""
+        if extra:
+            suffix = " " + " ".join(f"{k}={v}" for k, v in extra.items())
+        log.info("%s in %.3fs%s", label, elapsed, suffix)
+
+
+# psutil handle reused between snapshots (see system_snapshot). Stays None
+# until the first snapshot that finds psutil installed.
+_snapshot_proc: Any = None
+
+
+def system_snapshot() -> dict[str, Any]:
+    """Capture a lightweight CPU + memory snapshot of the current process.
+
+    Returns a dict with:
+    - rss_mb: Resident Set Size in MB
+    - rss_delta_mb: Change in RSS since the last snapshot (0.0 on the first)
+    - cpu_percent: per-process CPU usage % since the previous snapshot
+      (psutil reports 0.0 for the first measurement on a handle)
+    - threads: number of active threads
+    - open_fds: number of open file descriptors, or -1 where psutil offers
+      no ``num_fds()``
+    - asyncio_tasks: number of asyncio tasks currently alive, or -1 when no
+      event loop is running
+    - gc_counts: tuple of object counts per gc generation
+
+    When psutil is not installed, rss_mb, cpu_percent, threads and open_fds
+    are -1 and rss_delta_mb is 0.0.
+    """
+    import asyncio
+
+    global _last_rss_mb, _snapshot_proc
+
+    snapshot: dict[str, Any] = {}
+    try:
+        import psutil
+
+        # cpu_percent(interval=None) compares against the previous call made
+        # on the *same* Process object; a freshly built object has no previous
+        # sample and always answers 0.0. Keep one handle across snapshots and
+        # rebuild it only when the pid changed (e.g. after a fork), so the
+        # numbers never describe a parent process.
+        pid = os.getpid()
+        if _snapshot_proc is None or _snapshot_proc.pid != pid:
+            _snapshot_proc = psutil.Process(pid)
+        proc = _snapshot_proc
+
+        mem = proc.memory_info()
+        rss_mb = round(mem.rss / 1024 / 1024, 1)
+        snapshot["rss_mb"] = rss_mb
+        snapshot["rss_delta_mb"] = (
+            round(rss_mb - _last_rss_mb, 1) if _last_rss_mb else 0.0
+        )
+        _last_rss_mb = rss_mb
+        snapshot["cpu_percent"] = proc.cpu_percent(interval=None)
+        snapshot["threads"] = proc.num_threads()
+        try:
+            snapshot["open_fds"] = proc.num_fds()
+        except AttributeError:
+            # num_fds() is missing on this platform's Process class.
+            snapshot["open_fds"] = -1
+    except ImportError:
+        # psutil is optional: fall back to sentinel values.
+        snapshot["rss_mb"] = -1
+        snapshot["rss_delta_mb"] = 0.0
+        snapshot["cpu_percent"] = -1
+        snapshot["threads"] = -1
+        snapshot["open_fds"] = -1
+
+    try:
+        all_tasks = asyncio.all_tasks()
+        snapshot["asyncio_tasks"] = len(all_tasks)
+    except RuntimeError:
+        # all_tasks() raises when called without a running event loop.
+        snapshot["asyncio_tasks"] = -1
+
+    snapshot["gc_counts"] = gc.get_count()
+
+    return snapshot
+
+
+def log_system_snapshot(label: str = "system_snapshot") -> None:
+    """Take a ``system_snapshot()`` and write it to the [PERF] log.
+
+    One INFO line carries the snapshot fields, prefixed with ``label``. A
+    WARNING line follows when ``rss_mb`` is positive and ``rss_delta_mb``
+    exceeds 500.
+    """
+    snap = system_snapshot()
+    perf_log = get_perf_logger()
+    rss_mb = snap["rss_mb"]
+    rss_delta_mb = snap["rss_delta_mb"]
+
+    # The delta is shown only when it is non-zero; growth gets an explicit "+"
+    # (a negative value already prints its own "-").
+    if rss_delta_mb:
+        sign = "+" if rss_delta_mb > 0 else ""
+        delta_str = f" delta={sign}{rss_delta_mb}MB"
+    else:
+        delta_str = ""
+
+    perf_log.info(
+        "[%s] rss=%.1fMB%s cpu=%.1f%% threads=%d fds=%d asyncio_tasks=%d gc=%s",
+        label,
+        rss_mb,
+        delta_str,
+        snap["cpu_percent"],
+        snap["threads"],
+        snap["open_fds"],
+        snap["asyncio_tasks"],
+        snap["gc_counts"],
+    )
+
+    spiked = rss_mb > 0 and rss_delta_mb > 500
+    if not spiked:
+        return
+    perf_log.warning(
+        "[MEMORY_SPIKE] %s: RSS jumped by %.1fMB (now %.1fMB). "
+        "Possible leak — check recent operations.",
+        label,
+        rss_delta_mb,
+        rss_mb,
+    )
+
+
+def trim_native_heap() -> bool:
+    """Call glibc's ``malloc_trim(0)`` so free heap pages go back to the OS.
+
+    On Linux (glibc), ``free()`` does not release memory back to the OS if
+    it sits below the brk watermark; ``malloc_trim`` asks the allocator to
+    give back as many freed pages as it can.
+
+    Returns:
+        True when ``malloc_trim`` was called. False when the platform is not
+        Linux, when ``libc.so.6`` cannot be loaded, or when the loaded
+        library has no ``malloc_trim`` symbol.
+    """
+    import ctypes
+    import sys
+
+    if sys.platform == "linux":
+        try:
+            ctypes.CDLL("libc.so.6").malloc_trim(0)
+        except (OSError, AttributeError):
+            # OSError: the library could not be loaded.
+            # AttributeError: the library does not export malloc_trim.
+            return False
+        return True
+    return False
diff --git a/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx
index f2e9fb731..77ca38c38 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx
@@ -181,7 +181,7 @@ export default function MorePagesPage() {
- eric@surfsense.com
+ rohan@surfsense.com
diff --git a/surfsense_web/app/globals.css b/surfsense_web/app/globals.css
index 11d7d7a94..c192a27be 100644
--- a/surfsense_web/app/globals.css
+++ b/surfsense_web/app/globals.css
@@ -235,4 +235,3 @@ button {
@source '../node_modules/streamdown/dist/*.js';
@source '../node_modules/@streamdown/code/dist/*.js';
@source '../node_modules/@streamdown/math/dist/*.js';
-
diff --git a/surfsense_web/components/assistant-ui/connector-popup.tsx b/surfsense_web/components/assistant-ui/connector-popup.tsx
index 98964013d..332694676 100644
--- a/surfsense_web/components/assistant-ui/connector-popup.tsx
+++ b/surfsense_web/components/assistant-ui/connector-popup.tsx
@@ -5,6 +5,7 @@ import { AlertTriangle, Cable, Settings } from "lucide-react";
import Link from "next/link";
import { useSearchParams } from "next/navigation";
import type { FC } from "react";
+import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms";
import {
globalNewLLMConfigsAtom,
llmPreferencesAtom,
@@ -19,7 +20,6 @@ import { Spinner } from "@/components/ui/spinner";
import { Tabs, TabsContent } from "@/components/ui/tabs";
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
import { useConnectorsElectric } from "@/hooks/use-connectors-electric";
-import { useDocuments } from "@/hooks/use-documents";
import { useInbox } from "@/hooks/use-inbox";
import { cn } from "@/lib/utils";
import { ConnectorDialogHeader } from "./connector-popup/components/connector-dialog-header";
@@ -62,10 +62,9 @@ export const ConnectorIndicator: FC<{ hideTrigger?: boolean }> = ({ hideTrigger
const llmConfigLoading = preferencesLoading || globalConfigsLoading;
- // Fetch document type counts using Electric SQL + PGlite for real-time updates
- const { typeCounts: documentTypeCounts, loading: documentTypesLoading } = useDocuments(
- searchSpaceId ? Number(searchSpaceId) : null
- );
+ // Fetch document type counts via the lightweight /type-counts endpoint (cached 10 min)
+ const { data: documentTypeCounts, isFetching: documentTypesLoading } =
+ useAtomValue(documentTypeCountsAtom);
// Fetch notifications to detect indexing failures
const { inboxItems = [] } = useInbox(
diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts
index a3e8ae272..5deee8360 100644
--- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts
+++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts
@@ -2,27 +2,28 @@ import { EnumConnectorName } from "@/contracts/enums/connector";
// OAuth Connectors (Quick Connect)
export const OAUTH_CONNECTORS = [
- {
- id: "google-drive-connector",
- title: "Google Drive",
- description: "Search your Drive files",
- connectorType: EnumConnectorName.GOOGLE_DRIVE_CONNECTOR,
- authEndpoint: "/api/v1/auth/google/drive/connector/add/",
- },
- {
- id: "google-gmail-connector",
- title: "Gmail",
- description: "Search through your emails",
- connectorType: EnumConnectorName.GOOGLE_GMAIL_CONNECTOR,
- authEndpoint: "/api/v1/auth/google/gmail/connector/add/",
- },
- {
- id: "google-calendar-connector",
- title: "Google Calendar",
- description: "Search through your events",
- connectorType: EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR,
- authEndpoint: "/api/v1/auth/google/calendar/connector/add/",
- },
+ // // Uncomment for managed Google Connections
+ // {
+ // id: "google-drive-connector",
+ // title: "Google Drive",
+ // description: "Search your Drive files",
+ // connectorType: EnumConnectorName.GOOGLE_DRIVE_CONNECTOR,
+ // authEndpoint: "/api/v1/auth/google/drive/connector/add/",
+ // },
+ // {
+ // id: "google-gmail-connector",
+ // title: "Gmail",
+ // description: "Search through your emails",
+ // connectorType: EnumConnectorName.GOOGLE_GMAIL_CONNECTOR,
+ // authEndpoint: "/api/v1/auth/google/gmail/connector/add/",
+ // },
+ // {
+ // id: "google-calendar-connector",
+ // title: "Google Calendar",
+ // description: "Search through your events",
+ // connectorType: EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR,
+ // authEndpoint: "/api/v1/auth/google/calendar/connector/add/",
+ // },
{
id: "airtable-connector",
title: "Airtable",
diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx
index cd0b4971c..98fa5b436 100644
--- a/surfsense_web/components/assistant-ui/thread.tsx
+++ b/surfsense_web/components/assistant-ui/thread.tsx
@@ -65,6 +65,7 @@ import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking";
import { Button } from "@/components/ui/button";
import { Spinner } from "@/components/ui/spinner";
import type { Document } from "@/contracts/types/document.types";
+import { useBatchCommentsPreload } from "@/hooks/use-comments";
import { useCommentsElectric } from "@/hooks/use-comments-electric";
import { documentsApiService } from "@/lib/apis/documents-api.service";
import { cn } from "@/lib/utils";
@@ -309,6 +310,22 @@ const Composer: FC = () => {
// Sync comments for the entire thread via Electric SQL (one subscription per thread)
useCommentsElectric(threadId);
+ // Batch-prefetch comments for all assistant messages so individual useComments
+ // hooks never fire their own network requests (eliminates N+1 API calls).
+ // Return a primitive string from the selector so useSyncExternalStore can
+ // compare snapshots by value and avoid infinite re-render loops.
+ const assistantIdsKey = useAssistantState(({ thread }) =>
+ thread.messages
+ .filter((m) => m.role === "assistant" && m.id?.startsWith("msg-"))
+ .map((m) => m.id!.replace("msg-", ""))
+ .join(",")
+ );
+ const assistantDbMessageIds = useMemo(
+ () => (assistantIdsKey ? assistantIdsKey.split(",").map(Number) : []),
+ [assistantIdsKey]
+ );
+ useBatchCommentsPreload(assistantDbMessageIds);
+
// Auto-focus editor on new chat page after mount
useEffect(() => {
if (isThreadEmpty && !hasAutoFocusedRef.current && editorRef.current) {
diff --git a/surfsense_web/components/contact/contact-form.tsx b/surfsense_web/components/contact/contact-form.tsx
index 6f6e9f5b4..967c1c524 100644
--- a/surfsense_web/components/contact/contact-form.tsx
+++ b/surfsense_web/components/contact/contact-form.tsx
@@ -23,12 +23,12 @@ export function ContactFormGridWithDetails() {
We'd love to hear from you!
- Schedule a meeting with our Head of Product, Eric Lammertsma, or send us an email.
+ Schedule a meeting with us, or send us an email.
- eric@surfsense.com
+ rohan@surfsense.com
diff --git a/surfsense_web/components/homepage/hero-section.tsx b/surfsense_web/components/homepage/hero-section.tsx
index 00226517c..a1aa5ac4a 100644
--- a/surfsense_web/components/homepage/hero-section.tsx
+++ b/surfsense_web/components/homepage/hero-section.tsx
@@ -96,12 +96,9 @@ export function HeroSection() {
)}
- {/* // TODO:aCTUAL DESCRITION */}
-
- Connect any AI to your documents, Drive, Notion and more,
-
-
- then chat with it, generate podcasts and reports, or even invite your team.
+
+ Connect any LLM to your internal knowledge sources and chat with it in real time alongside
+ your team.
diff --git a/surfsense_web/components/homepage/navbar.tsx b/surfsense_web/components/homepage/navbar.tsx
index 2b0d60546..ddf43e7eb 100644
--- a/surfsense_web/components/homepage/navbar.tsx
+++ b/surfsense_web/components/homepage/navbar.tsx
@@ -4,7 +4,6 @@ import {
IconBrandGithub,
IconBrandReddit,
IconMenu2,
- IconSpeakerphone,
IconX,
} from "@tabler/icons-react";
import { AnimatePresence, motion } from "motion/react";
@@ -13,7 +12,6 @@ import { useEffect, useState } from "react";
import { SignInButton } from "@/components/auth/sign-in-button";
import { Logo } from "@/components/Logo";
import { ThemeTogglerComponent } from "@/components/theme/theme-toggle";
-import { useAnnouncements } from "@/hooks/use-announcements";
import { useGithubStars } from "@/hooks/use-github-stars";
import { cn } from "@/lib/utils";
@@ -49,11 +47,7 @@ export const Navbar = () => {
const DesktopNav = ({ navItems, isScrolled }: any) => {
const [hovered, setHovered] = useState(null);
- const [mounted, setMounted] = useState(false);
const { compactFormat: githubStars, loading: loadingGithubStars } = useGithubStars();
- const { unreadCount } = useAnnouncements();
-
- useEffect(() => setMounted(true), []);
return (
{
@@ -124,17 +118,6 @@ const DesktopNav = ({ navItems, isScrolled }: any) => {
)}
-
-
- {mounted && unreadCount > 0 && (
-
- {unreadCount > 99 ? "99+" : unreadCount}
-
- )}
-
@@ -144,11 +127,7 @@ const DesktopNav = ({ navItems, isScrolled }: any) => {
const MobileNav = ({ navItems, isScrolled }: any) => {
const [open, setOpen] = useState(false);
- const [mounted, setMounted] = useState(false);
const { compactFormat: githubStars, loading: loadingGithubStars } = useGithubStars();
- const { unreadCount } = useAnnouncements();
-
- useEffect(() => setMounted(true), []);
return (
{
)}
-
-
- {mounted && unreadCount > 0 && (
-
- {unreadCount > 99 ? "99+" : unreadCount}
-
- )}
-
diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx
index 7528aeb0b..ce7b06da6 100644
--- a/surfsense_web/components/pricing/pricing-section.tsx
+++ b/surfsense_web/components/pricing/pricing-section.tsx
@@ -8,43 +8,34 @@ const demoPlans = [
price: "0",
yearlyPrice: "0",
period: "",
- billingText: "Includes 30 day PRO trial",
+ billingText: "",
features: [
- "Open source on GitHub",
+ "Self Hostable",
"Upload and chat with 300+ pages of content",
- "Connects with 8 popular sources, like Drive and Notion",
- "Includes limited access to ChatGPT, Claude, and DeepSeek models",
- "Supports 100+ more LLMs, including Gemini, Llama and many more",
- "50+ File extensions supported",
- "Generate podcasts in seconds",
- "Cross-Browser Extension for dynamic webpages including authenticated content",
+ "Includes access to ChatGPT text and audio models",
+ "Realtime Collaborative Group Chats with teammates",
"Community support on Discord",
],
- description: "Powerful features with some limitations",
+ description: "",
buttonText: "Get Started",
- href: "/",
+ href: "/login",
isPopular: false,
},
{
name: "PRO",
- price: "10",
- yearlyPrice: "10",
- period: "user / month",
- billingText: "billed annually",
+ price: "0",
+ yearlyPrice: "0",
+ period: "",
+ billingText: "Free during beta",
features: [
"Everything in Free",
- "Upload and chat with 5,000+ pages of content per user",
- "Connects with 15+ external sources, like Slack and Airtable",
- "Includes extended access to ChatGPT, Claude, and DeepSeek models",
- "Collaboration and commenting features",
- "Shared BYOK (Bring Your Own Key)",
- "Team and role management",
- "Planned: Centralized billing",
- "Priority support",
+ "Includes 6000+ pages of content",
+ "Access to more models and providers",
+ "Priority support on Discord",
],
- description: "The AI knowledge base for individuals and teams",
- buttonText: "Upgrade",
- href: "/contact",
+ description: "",
+ buttonText: "Get Started",
+ href: "/login",
isPopular: true,
},
{
@@ -55,12 +46,9 @@ const demoPlans = [
billingText: "",
features: [
"Everything in Pro",
- "Connect and chat with virtually unlimited pages of content",
- "Limit models and/or providers",
"On-prem or VPC deployment",
- "Planned: Audit logs and compliance",
- "Planned: SSO, OIDC & SAML",
- "Planned: Role-based access control (RBAC)",
+ "Audit logs and compliance",
+ "SSO, OIDC & SAML",
"White-glove setup and deployment",
"Monthly managed updates and maintenance",
"SLA commitments",
diff --git a/surfsense_web/components/sources/DocumentUploadTab.tsx b/surfsense_web/components/sources/DocumentUploadTab.tsx
index caea98890..cae78f7b7 100644
--- a/surfsense_web/components/sources/DocumentUploadTab.tsx
+++ b/surfsense_web/components/sources/DocumentUploadTab.tsx
@@ -111,8 +111,8 @@ const FILE_TYPE_CONFIG: Record> = {
const cardClass = "border border-border bg-slate-400/5 dark:bg-white/5";
-// Upload limits
-const MAX_FILES = 10;
+// Upload limits — files are sent in batches of 5 to avoid proxy timeouts
+const MAX_FILES = 50;
const MAX_TOTAL_SIZE_MB = 200;
const MAX_TOTAL_SIZE_BYTES = MAX_TOTAL_SIZE_MB * 1024 * 1024;
diff --git a/surfsense_web/contracts/types/chat-comments.types.ts b/surfsense_web/contracts/types/chat-comments.types.ts
index 46e064a4e..cdeca0a44 100644
--- a/surfsense_web/contracts/types/chat-comments.types.ts
+++ b/surfsense_web/contracts/types/chat-comments.types.ts
@@ -82,6 +82,22 @@ export const getCommentsResponse = z.object({
total_count: z.number(),
});
+/**
+ * Batch-fetch comments for multiple messages
+ *
+ * Request: `message_ids` is an array of numbers holding at least 1 and at
+ * most 200 entries.
+ */
+export const getBatchCommentsRequest = z.object({
+ message_ids: z.array(z.number()).min(1).max(200),
+});
+
+/** A list of comments together with its total count. */
+export const commentListResponse = z.object({
+ comments: z.array(comment),
+ total_count: z.number(),
+});
+
+/** Response: one comment list per entry of a record with string keys. */
+export const getBatchCommentsResponse = z.object({
+ comments_by_message: z.record(z.string(), commentListResponse),
+});
+
/**
* Create comment
*/
@@ -145,6 +161,8 @@ export type MentionComment = z.infer;
export type Mention = z.infer;
export type GetCommentsRequest = z.infer;
export type GetCommentsResponse = z.infer;
+// Types inferred from the batch-comments schemas declared above.
+export type GetBatchCommentsRequest = z.infer<typeof getBatchCommentsRequest>;
+export type GetBatchCommentsResponse = z.infer<typeof getBatchCommentsResponse>;
export type CreateCommentRequest = z.infer;
export type CreateCommentResponse = z.infer;
export type CreateReplyRequest = z.infer;
diff --git a/surfsense_web/hooks/use-comments.ts b/surfsense_web/hooks/use-comments.ts
index 4f027d67c..c02f9fe16 100644
--- a/surfsense_web/hooks/use-comments.ts
+++ b/surfsense_web/hooks/use-comments.ts
@@ -1,4 +1,5 @@
-import { useQuery } from "@tanstack/react-query";
+import { useQuery, useQueryClient } from "@tanstack/react-query";
+import { useEffect, useRef } from "react";
import { chatCommentsApiService } from "@/lib/apis/chat-comments-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
@@ -7,12 +8,109 @@ interface UseCommentsOptions {
enabled?: boolean;
}
+// ---------------------------------------------------------------------------
+// Module-level coordination: when a batch request is in-flight, individual
+// useComments queryFns piggy-back on it instead of making their own requests.
+//
+// _batchReady is a promise that resolves once the batch useEffect has had a
+// chance to set _batchInflight. Individual queryFns await this gate before
+// deciding whether to piggy-back or fetch on their own, eliminating the
+// previous race where setTimeout(0) was not enough.
+// ---------------------------------------------------------------------------
+// Promise of the batch request currently in flight, or null when there is none.
+let _batchInflight: Promise<void> | null = null;
+// Message ids covered by the in-flight batch.
+let _batchTargetIds = new Set<number>();
+// Gate that individual queryFns await before they look at _batchInflight.
+let _batchReady: Promise<void> | null = null;
+// Resolver of the current _batchReady promise.
+let _resolveBatchReady: (() => void) | null = null;
+
+/** Replace the gate with a fresh, unresolved promise and keep its resolver. */
+function resetBatchGate() {
+	_batchReady = new Promise<void>((r) => {
+		_resolveBatchReady = r;
+	});
+}
+
+// Open the initial gate immediately (no batch pending yet)
+resetBatchGate();
+_resolveBatchReady?.();
+
+/**
+ * React Query hook for the comments of a single message.
+ *
+ * The query only runs when `enabled` is true and `messageId` is truthy, and
+ * its data is treated as fresh for 30 seconds (`staleTime: 30_000`).
+ *
+ * If a batch request covering this message is in flight, the query function
+ * waits for it and returns the entry found in the query cache; when there is
+ * no such batch, or the cache holds nothing for this message afterwards, it
+ * calls `chatCommentsApiService.getComments` itself.
+ */
export function useComments({ messageId, enabled = true }: UseCommentsOptions) {
+ const queryClient = useQueryClient();
+
return useQuery({
queryKey: cacheKeys.comments.byMessage(messageId),
queryFn: async () => {
+ // Wait for the batch gate so the useEffect in useBatchCommentsPreload
+ // has a chance to set _batchInflight before we decide.
+ if (_batchReady) {
+ await _batchReady;
+ }
+
+ // Piggy-back on the batch only when it targets this message id.
+ if (_batchInflight && _batchTargetIds.has(messageId)) {
+ await _batchInflight;
+ const cached = queryClient.getQueryData(cacheKeys.comments.byMessage(messageId));
+ // A falsy cache entry means the batch did not seed this message: fall through.
+ if (cached) return cached;
+ }
+
return chatCommentsApiService.getComments({ message_id: messageId });
},
enabled: enabled && !!messageId,
+ staleTime: 30_000,
});
}
+
+/**
+ * Batch-fetch comments for all given message IDs in a single request, then
+ * seed the per-message React Query cache so individual useComments hooks
+ * resolve from cache instead of firing their own requests.
+ *
+ * Nothing is requested for an empty list, or when the sorted id list equals
+ * the one a previous run of the effect already handled. If the effect is
+ * cleaned up while its batch is still in flight, the result is discarded and
+ * the remembered id list is cleared, so a re-run for the same ids (as React
+ * StrictMode does on mount) issues a new batch instead of returning early
+ * with nothing in flight.
+ */
+export function useBatchCommentsPreload(messageIds: number[]) {
+	const queryClient = useQueryClient();
+	// Sorted, comma-joined ids of the last batch this hook started.
+	const prevKeyRef = useRef("");
+
+	useEffect(() => {
+		if (!messageIds.length) return;
+
+		const key = messageIds
+			.slice()
+			.sort((a, b) => a - b)
+			.join(",");
+		if (key === prevKeyRef.current) return;
+		prevKeyRef.current = key;
+
+		// Open a new gate so individual queryFns wait for us
+		resetBatchGate();
+
+		_batchTargetIds = new Set(messageIds);
+		let cancelled = false;
+
+		const promise = chatCommentsApiService
+			.getBatchComments({ message_ids: messageIds })
+			.then((data) => {
+				if (cancelled) return;
+				// Response keys are strings; convert each back to a numeric message id.
+				for (const [msgIdStr, commentList] of Object.entries(data.comments_by_message)) {
+					queryClient.setQueryData(cacheKeys.comments.byMessage(Number(msgIdStr)), commentList);
+				}
+			})
+			.catch(() => {
+				// Batch failed; individual queryFns will fall through to their own fetch
+			})
+			.finally(() => {
+				// Clear the shared state only if no newer batch has replaced this one.
+				if (_batchInflight === promise) {
+					_batchInflight = null;
+					_batchTargetIds = new Set();
+				}
+			});
+
+		_batchInflight = promise;
+
+		// Release the gate — individual queryFns can now check _batchInflight
+		_resolveBatchReady?.();
+
+		return () => {
+			cancelled = true;
+			if (_batchInflight === promise) {
+				_batchInflight = null;
+				_batchTargetIds = new Set();
+				// This batch was still in flight, so its result is now dropped and
+				// the cache was never seeded. Forget the key; otherwise a re-run
+				// with the same ids would hit the early return above and no batch
+				// would ever be issued for them.
+				prevKeyRef.current = "";
+			}
+		};
+	}, [messageIds, queryClient]);
+}
diff --git a/surfsense_web/hooks/use-documents.ts b/surfsense_web/hooks/use-documents.ts
index 55d48c4f1..36a359696 100644
--- a/surfsense_web/hooks/use-documents.ts
+++ b/surfsense_web/hooks/use-documents.ts
@@ -1,5 +1,6 @@
"use client";
+import { useQuery } from "@tanstack/react-query";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import type { DocumentTypeEnum } from "@/contracts/types/document.types";
import { documentsApiService } from "@/lib/apis/documents-api.service";
@@ -183,56 +184,47 @@ export function useDocuments(
[]
);
- // EFFECT 1: Load ALL documents from API (PRIMARY source of truth)
- // No type filter — always fetches everything so typeCounts stay complete
+ // STEP 1: Load ALL documents from API (PRIMARY source of truth).
+ // Uses React Query for automatic deduplication, caching, and staleTime so
+ // multiple components mounting useDocuments(sameId) share a single request.
+ const {
+ data: apiResponse,
+ isLoading: apiLoading,
+ error: apiError,
+ } = useQuery({
+ queryKey: ["documents", "all", searchSpaceId],
+ queryFn: () =>
+ documentsApiService.getDocuments({
+ queryParams: {
+ search_space_id: searchSpaceId!,
+ page: 0,
+ page_size: -1,
+ },
+ }),
+ enabled: !!searchSpaceId,
+ staleTime: 30_000,
+ });
+
+ // Seed local state from API response (runs once per fresh fetch)
useEffect(() => {
- if (!searchSpaceId) {
- setLoading(false);
- return;
+ if (!apiResponse) return;
+ populateUserCache(apiResponse.items);
+ const docs = apiResponse.items.map(apiToDisplayDoc);
+ setAllDocuments(docs);
+ apiLoadedRef.current = true;
+ setError(null);
+ }, [apiResponse, populateUserCache, apiToDisplayDoc]);
+
+ // Propagate loading / error from React Query
+ useEffect(() => {
+ setLoading(apiLoading);
+ }, [apiLoading]);
+
+ useEffect(() => {
+ if (apiError) {
+ setError(apiError instanceof Error ? apiError : new Error("Failed to load documents"));
}
-
- // Capture validated value for async closure
- const spaceId = searchSpaceId;
-
- let mounted = true;
- apiLoadedRef.current = false;
-
- async function loadFromApi() {
- try {
- setLoading(true);
- console.log("[useDocuments] Loading from API (source of truth):", spaceId);
-
- const response = await documentsApiService.getDocuments({
- queryParams: {
- search_space_id: spaceId,
- page: 0,
- page_size: -1, // Fetch all documents (unfiltered)
- },
- });
-
- if (!mounted) return;
-
- populateUserCache(response.items);
- const docs = response.items.map(apiToDisplayDoc);
- setAllDocuments(docs);
- apiLoadedRef.current = true;
- setError(null);
- console.log("[useDocuments] API loaded", docs.length, "documents");
- } catch (err) {
- if (!mounted) return;
- console.error("[useDocuments] API load failed:", err);
- setError(err instanceof Error ? err : new Error("Failed to load documents"));
- } finally {
- if (mounted) setLoading(false);
- }
- }
-
- loadFromApi();
-
- return () => {
- mounted = false;
- };
- }, [searchSpaceId, populateUserCache, apiToDisplayDoc]);
+ }, [apiError]);
// EFFECT 2: Start Electric sync + live query for real-time updates
// No type filter — syncs and queries ALL documents; filtering is client-side
diff --git a/surfsense_web/lib/apis/chat-comments-api.service.ts b/surfsense_web/lib/apis/chat-comments-api.service.ts
index 952de7a25..f1ec7a5d9 100644
--- a/surfsense_web/lib/apis/chat-comments-api.service.ts
+++ b/surfsense_web/lib/apis/chat-comments-api.service.ts
@@ -8,8 +8,11 @@ import {
type DeleteCommentRequest,
deleteCommentRequest,
deleteCommentResponse,
+ type GetBatchCommentsRequest,
type GetCommentsRequest,
type GetMentionsRequest,
+ getBatchCommentsRequest,
+ getBatchCommentsResponse,
getCommentsRequest,
getCommentsResponse,
getMentionsRequest,
@@ -22,6 +25,22 @@ import { ValidationError } from "@/lib/error";
import { baseApiService } from "./base-api.service";
class ChatCommentsApiService {
+	/**
+	 * Fetch the comments of several messages with a single POST.
+	 *
+	 * The request is validated against `getBatchCommentsRequest` first; when it
+	 * does not conform, a ValidationError listing every schema issue is thrown
+	 * and no network call is made.
+	 */
+	getBatchComments = async (request: GetBatchCommentsRequest) => {
+		const validation = getBatchCommentsRequest.safeParse(request);
+
+		if (validation.success) {
+			const { message_ids } = validation.data;
+			return baseApiService.post("/api/v1/messages/comments/batch", getBatchCommentsResponse, {
+				body: { message_ids },
+			});
+		}
+
+		// Validation failed: surface all issue messages in one error
+		const details = validation.error.issues.map(({ message }) => message).join(", ");
+		throw new ValidationError(`Invalid request: ${details}`);
+	};
+
/**
* Get comments for a message
*/
diff --git a/surfsense_web/lib/apis/documents-api.service.ts b/surfsense_web/lib/apis/documents-api.service.ts
index e3ee2bd5b..9b0d847f4 100644
--- a/surfsense_web/lib/apis/documents-api.service.ts
+++ b/surfsense_web/lib/apis/documents-api.service.ts
@@ -109,7 +109,9 @@ class DocumentsApiService {
};
/**
- * Upload document files
+ * Upload document files in batches to avoid proxy/LB timeouts.
+ * Files are split into chunks of UPLOAD_BATCH_SIZE and sent as separate
+ * requests. Results are aggregated into a single response.
*/
uploadDocument = async (request: UploadDocumentRequest) => {
const parsedRequest = uploadDocumentRequest.safeParse(request);
@@ -121,17 +123,54 @@ class DocumentsApiService {
throw new ValidationError(`Invalid request: ${errorMessage}`);
}
- // Create FormData for file upload
- const formData = new FormData();
- parsedRequest.data.files.forEach((file) => {
- formData.append("files", file);
- });
- formData.append("search_space_id", String(parsedRequest.data.search_space_id));
- formData.append("should_summarize", String(parsedRequest.data.should_summarize));
+ const { files, search_space_id, should_summarize } = parsedRequest.data;
+ const UPLOAD_BATCH_SIZE = 5;
- return baseApiService.postFormData(`/api/v1/documents/fileupload`, uploadDocumentResponse, {
- body: formData,
- });
+ const batches: File[][] = [];
+ for (let i = 0; i < files.length; i += UPLOAD_BATCH_SIZE) {
+ batches.push(files.slice(i, i + UPLOAD_BATCH_SIZE));
+ }
+
+ const allDocumentIds: number[] = [];
+ const allDuplicateIds: number[] = [];
+ let totalFiles = 0;
+ let pendingFiles = 0;
+ let skippedDuplicates = 0;
+
+ for (const batch of batches) {
+ const formData = new FormData();
+ batch.forEach((file) => formData.append("files", file));
+ formData.append("search_space_id", String(search_space_id));
+ formData.append("should_summarize", String(should_summarize));
+
+ const controller = new AbortController();
+ const timeoutId = setTimeout(() => controller.abort(), 120_000);
+
+ try {
+ const result = await baseApiService.postFormData(
+ `/api/v1/documents/fileupload`,
+ uploadDocumentResponse,
+ { body: formData, signal: controller.signal }
+ );
+
+ allDocumentIds.push(...(result.document_ids ?? []));
+ allDuplicateIds.push(...(result.duplicate_document_ids ?? []));
+ totalFiles += result.total_files ?? batch.length;
+ pendingFiles += result.pending_files ?? 0;
+ skippedDuplicates += result.skipped_duplicates ?? 0;
+ } finally {
+ clearTimeout(timeoutId);
+ }
+ }
+
+ return {
+ message: "Files uploaded for processing" as const,
+ document_ids: allDocumentIds,
+ duplicate_document_ids: allDuplicateIds,
+ total_files: totalFiles,
+ pending_files: pendingFiles,
+ skipped_duplicates: skippedDuplicates,
+ };
};
/**
diff --git a/surfsense_web/lib/query-client/client.ts b/surfsense_web/lib/query-client/client.ts
index 6c7b9ded3..0dcc2ef03 100644
--- a/surfsense_web/lib/query-client/client.ts
+++ b/surfsense_web/lib/query-client/client.ts
@@ -1,3 +1,10 @@
import { QueryClient } from "@tanstack/react-query";
-export const queryClient = new QueryClient();
+// Exported QueryClient instance. The options below are defaults that apply to
+// every query which does not override them in its own useQuery options.
+export const queryClient = new QueryClient({
+ defaultOptions: {
+ queries: {
+ // Treat fetched data as fresh for 30 s (staleTime is in milliseconds)
+ staleTime: 30_000,
+ // Do not refetch automatically when the browser window regains focus
+ refetchOnWindowFocus: false,
+ },
+ },
+});
diff --git a/surfsense_web/lib/source.ts b/surfsense_web/lib/source.ts
index 32a52c761..162cca57a 100644
--- a/surfsense_web/lib/source.ts
+++ b/surfsense_web/lib/source.ts
@@ -1,13 +1,12 @@
import { loader } from "fumadocs-core/source";
-import { docs } from "@/.source/server";
import { icons } from "lucide-react";
import { createElement } from "react";
+import { docs } from "@/.source/server";
export const source = loader({
baseUrl: "/docs",
source: docs.toFumadocsSource(),
icon(icon) {
- if (icon && icon in icons)
- return createElement(icons[icon as keyof typeof icons]);
+ if (icon && icon in icons) return createElement(icons[icon as keyof typeof icons]);
},
});