From 37a0bd453322b4bf052d855cd6e175e94ca2d888 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Fri, 27 Feb 2026 15:31:31 -0800 Subject: [PATCH 01/10] feat: update contact information and meeting links across components - Changed meeting link from Calendly to Cal.com for scheduling. - Updated email contact from eric@surfsense.com to rohan@surfsense.com in multiple components. - Revised text in the contact form to reflect the new scheduling options. --- .../[search_space_id]/more-pages/page.tsx | 6 +-- .../components/contact/contact-form.tsx | 8 ++-- .../components/pricing/pricing-section.tsx | 46 +++++++------------ 3 files changed, 24 insertions(+), 36 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx index f2e9fb731..77ca38c38 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/more-pages/page.tsx @@ -181,7 +181,7 @@ export default function MorePagesPage() {
- eric@surfsense.com + rohan@surfsense.com diff --git a/surfsense_web/components/contact/contact-form.tsx b/surfsense_web/components/contact/contact-form.tsx index 6f6e9f5b4..967c1c524 100644 --- a/surfsense_web/components/contact/contact-form.tsx +++ b/surfsense_web/components/contact/contact-form.tsx @@ -23,12 +23,12 @@ export function ContactFormGridWithDetails() { We'd love to hear from you!

- Schedule a meeting with our Head of Product, Eric Lammertsma, or send us an email. + Schedule a meeting with us, or send us an email.

- eric@surfsense.com + rohan@surfsense.com
diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx index 7528aeb0b..553fa4e7f 100644 --- a/surfsense_web/components/pricing/pricing-section.tsx +++ b/surfsense_web/components/pricing/pricing-section.tsx @@ -8,43 +8,34 @@ const demoPlans = [ price: "0", yearlyPrice: "0", period: "", - billingText: "Includes 30 day PRO trial", + billingText: "", features: [ - "Open source on GitHub", + "Self Hostable", "Upload and chat with 300+ pages of content", - "Connects with 8 popular sources, like Drive and Notion", - "Includes limited access to ChatGPT, Claude, and DeepSeek models", - "Supports 100+ more LLMs, including Gemini, Llama and many more", - "50+ File extensions supported", - "Generate podcasts in seconds", - "Cross-Browser Extension for dynamic webpages including authenticated content", + "Includes access to ChatGPT text and audio models", + "Realtime Collaborative Group Chats with teammates", "Community support on Discord", ], - description: "Powerful features with some limitations", + description: "", buttonText: "Get Started", href: "/", isPopular: false, }, { name: "PRO", - price: "10", - yearlyPrice: "10", - period: "user / month", - billingText: "billed annually", + price: "0", + yearlyPrice: "0", + period: "", + billingText: "Free during beta", features: [ "Everything in Free", - "Upload and chat with 5,000+ pages of content per user", - "Connects with 15+ external sources, like Slack and Airtable", - "Includes extended access to ChatGPT, Claude, and DeepSeek models", - "Collaboration and commenting features", - "Shared BYOK (Bring Your Own Key)", - "Team and role management", - "Planned: Centralized billing", - "Priority support", + "Includes 6000+ pages of content", + "Access to more models and providers", + "Priority support on Discord", ], - description: "The AI knowledge base for individuals and teams", - buttonText: "Upgrade", - href: "/contact", + description: "", + buttonText: "Get Started", + href: "/", isPopular: true, }, { @@ -55,12 +46,9 @@ const demoPlans = [ billingText: "", features: [ "Everything in Pro", - "Connect and chat with virtually unlimited pages of content", - "Limit models and/or providers", "On-prem or VPC deployment", - "Planned: Audit logs and compliance", - "Planned: SSO, OIDC & SAML", - "Planned: Role-based access control (RBAC)", + "Audit logs and compliance", + "SSO, OIDC & SAML", "White-glove setup and deployment", "Monthly managed updates and maintenance", "SLA commitments", From fa51ec42c6b68dfdde568096c27c49a29665d051 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Fri, 27 Feb 2026 15:45:48 -0800 Subject: [PATCH 02/10] chore: removed announcements from navbar and rewrote tagline --- surfsense_web/app/globals.css | 1 - .../components/homepage/hero-section.tsx | 9 ++---- surfsense_web/components/homepage/navbar.tsx | 32 ------------------- .../components/pricing/pricing-section.tsx | 4 +-- surfsense_web/lib/source.ts | 5 ++- 5 files changed, 7 insertions(+), 44 deletions(-) diff --git a/surfsense_web/app/globals.css b/surfsense_web/app/globals.css index 11d7d7a94..c192a27be 100644 --- a/surfsense_web/app/globals.css +++ b/surfsense_web/app/globals.css @@ -235,4 +235,3 @@ button { @source '../node_modules/streamdown/dist/*.js'; @source '../node_modules/@streamdown/code/dist/*.js'; @source '../node_modules/@streamdown/math/dist/*.js'; - diff --git a/surfsense_web/components/homepage/hero-section.tsx b/surfsense_web/components/homepage/hero-section.tsx index 00226517c..a1aa5ac4a 100644 --- a/surfsense_web/components/homepage/hero-section.tsx +++ b/surfsense_web/components/homepage/hero-section.tsx @@ -96,12 +96,9 @@ export function HeroSection() { )} - {/* // TODO:aCTUAL DESCRITION */} -

- Connect any AI to your documents, Drive, Notion and more, -

-

- then chat with it, generate podcasts and reports, or even invite your team. +

+ Connect any LLM to your internal knowledge sources and chat with it in real time alongside + your team.

diff --git a/surfsense_web/components/homepage/navbar.tsx b/surfsense_web/components/homepage/navbar.tsx index 2b0d60546..ddf43e7eb 100644 --- a/surfsense_web/components/homepage/navbar.tsx +++ b/surfsense_web/components/homepage/navbar.tsx @@ -4,7 +4,6 @@ import { IconBrandGithub, IconBrandReddit, IconMenu2, - IconSpeakerphone, IconX, } from "@tabler/icons-react"; import { AnimatePresence, motion } from "motion/react"; @@ -13,7 +12,6 @@ import { useEffect, useState } from "react"; import { SignInButton } from "@/components/auth/sign-in-button"; import { Logo } from "@/components/Logo"; import { ThemeTogglerComponent } from "@/components/theme/theme-toggle"; -import { useAnnouncements } from "@/hooks/use-announcements"; import { useGithubStars } from "@/hooks/use-github-stars"; import { cn } from "@/lib/utils"; @@ -49,11 +47,7 @@ export const Navbar = () => { const DesktopNav = ({ navItems, isScrolled }: any) => { const [hovered, setHovered] = useState(null); - const [mounted, setMounted] = useState(false); const { compactFormat: githubStars, loading: loadingGithubStars } = useGithubStars(); - const { unreadCount } = useAnnouncements(); - - useEffect(() => setMounted(true), []); return ( { @@ -124,17 +118,6 @@ const DesktopNav = ({ navItems, isScrolled }: any) => { )} - - - {mounted && unreadCount > 0 && ( - - {unreadCount > 99 ? "99+" : unreadCount} - - )} -
@@ -144,11 +127,7 @@ const DesktopNav = ({ navItems, isScrolled }: any) => { const MobileNav = ({ navItems, isScrolled }: any) => { const [open, setOpen] = useState(false); - const [mounted, setMounted] = useState(false); const { compactFormat: githubStars, loading: loadingGithubStars } = useGithubStars(); - const { unreadCount } = useAnnouncements(); - - useEffect(() => setMounted(true), []); return ( { )} - - - {mounted && unreadCount > 0 && ( - - {unreadCount > 99 ? "99+" : unreadCount} - - )} - diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx index 553fa4e7f..ce7b06da6 100644 --- a/surfsense_web/components/pricing/pricing-section.tsx +++ b/surfsense_web/components/pricing/pricing-section.tsx @@ -18,7 +18,7 @@ const demoPlans = [ ], description: "", buttonText: "Get Started", - href: "/", + href: "/login", isPopular: false, }, { @@ -35,7 +35,7 @@ const demoPlans = [ ], description: "", buttonText: "Get Started", - href: "/", + href: "/login", isPopular: true, }, { diff --git a/surfsense_web/lib/source.ts b/surfsense_web/lib/source.ts index 32a52c761..162cca57a 100644 --- a/surfsense_web/lib/source.ts +++ b/surfsense_web/lib/source.ts @@ -1,13 +1,12 @@ import { loader } from "fumadocs-core/source"; -import { docs } from "@/.source/server"; import { icons } from "lucide-react"; import { createElement } from "react"; +import { docs } from "@/.source/server"; export const source = loader({ baseUrl: "/docs", source: docs.toFumadocsSource(), icon(icon) { - if (icon && icon in icons) - return createElement(icons[icon as keyof typeof icons]); + if (icon && icon in icons) return createElement(icons[icon as keyof typeof icons]); }, }); From 664c43ca1382e48119ec8e10bacad9bdbafc8359 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Fri, 27 Feb 2026 16:32:30 -0800 Subject: [PATCH 03/10] feat: add performance logging middleware and enhance performance tracking across services - Introduced RequestPerfMiddleware to log request performance metrics, including slow request thresholds. - Updated various services and retrievers to utilize the new performance logging utility for better tracking of execution times. - Enhanced existing methods with detailed performance logs for operations such as embedding, searching, and indexing. - Removed deprecated logging setup in stream_new_chat and replaced it with the new performance logger. --- .../app/agents/new_chat/chat_deepagent.py | 3 +- .../agents/new_chat/tools/knowledge_base.py | 30 ++++- surfsense_backend/app/app.py | 61 +++++++++ .../indexing_pipeline_service.py | 33 ++++- .../app/retriever/chunks_hybrid_search.py | 42 +++++- .../app/retriever/documents_hybrid_search.py | 24 ++++ .../app/services/connector_service.py | 20 +++ .../app/services/llm_router_service.py | 77 ++++++++++- .../app/tasks/chat/stream_new_chat.py | 11 +- surfsense_backend/app/utils/perf.py | 122 ++++++++++++++++++ .../constants/connector-constants.ts | 43 +++--- 11 files changed, 430 insertions(+), 36 deletions(-) create mode 100644 surfsense_backend/app/utils/perf.py diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 3843b1687..af0d6bdc5 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -28,8 +28,9 @@ from app.agents.new_chat.system_prompt import ( from app.agents.new_chat.tools.registry import build_tools_async from app.db import ChatVisibility from app.services.connector_service import ConnectorService +from app.utils.perf import get_perf_logger -_perf_log = logging.getLogger("surfsense.perf") +_perf_log = get_perf_logger() # ============================================================================= # Connector Type Mapping diff --git a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py index 6989a1aa2..19f21bbc6 100644 --- a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py +++ b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py @@ -10,6 +10,7 @@ This module provides: import asyncio import json +import time from datetime import datetime from typing import Any @@ -19,6 +20,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.db import async_session_maker from app.services.connector_service import ConnectorService +from app.utils.perf import get_perf_logger # ============================================================================= # Connector Constants and Normalization @@ -412,6 +414,9 @@ async def search_knowledge_base_async( Returns: Formatted string with search results """ + perf = get_perf_logger() + t0 = time.perf_counter() + all_documents: list[dict[str, Any]] = [] # Resolve date range (default last 2 years) @@ -423,6 +428,10 @@ async def search_knowledge_base_async( ) connectors = _normalize_connectors(connectors_to_search, available_connectors) + perf.info( + "[kb_search] searching %d connectors: %s (space=%d, top_k=%d)", + len(connectors), connectors[:5], search_space_id, top_k, + ) connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = { "YOUTUBE_VIDEO": ("search_youtube", True, True, {}), @@ -492,20 +501,32 @@ async def search_knowledge_base_async( try: # Use isolated session per connector. Shared AsyncSession cannot safely # run concurrent DB operations. + t_conn = time.perf_counter() async with semaphore, async_session_maker() as isolated_session: isolated_connector_service = ConnectorService( isolated_session, search_space_id ) connector_method = getattr(isolated_connector_service, method_name) _, chunks = await connector_method(**kwargs) + perf.info( + "[kb_search] connector=%s results=%d in %.3fs", + connector, len(chunks), time.perf_counter() - t_conn, + ) return chunks except Exception as e: - print(f"Error searching connector {connector}: {e}") + perf.warning( + "[kb_search] connector=%s FAILED in %.3fs: %s", + connector, time.perf_counter() - t_conn, e, + ) return [] + t_gather = time.perf_counter() connector_results = await asyncio.gather( *[_search_one_connector(connector) for connector in connectors] ) + perf.info( + "[kb_search] all connectors gathered in %.3fs", time.perf_counter() - t_gather, + ) for chunks in connector_results: all_documents.extend(chunks) @@ -552,7 +573,12 @@ async def search_knowledge_base_async( deduplicated.append(doc) output_budget = _compute_tool_output_budget(max_input_tokens) - return format_documents_for_context(deduplicated, max_chars=output_budget) + result = format_documents_for_context(deduplicated, max_chars=output_budget) + perf.info( + "[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d space=%d", + time.perf_counter() - t0, len(all_documents), len(deduplicated), len(result), search_space_id, + ) + return result def _build_connector_docstring(available_connectors: list[str] | None) -> str: diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 0a549abe5..e8843878f 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -15,6 +15,9 @@ from slowapi.errors import RateLimitExceeded from slowapi.middleware import SlowAPIMiddleware from slowapi.util import get_remote_address from sqlalchemy.ext.asyncio import AsyncSession +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request as StarletteRequest +from starlette.responses import Response as StarletteResponse from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware from app.agents.new_chat.checkpointer import ( @@ -28,6 +31,7 @@ from app.routes.auth_routes import router as auth_router from app.schemas import UserCreate, UserRead, UserUpdate from app.tasks.surfsense_docs_indexer import seed_surfsense_docs from app.users import SECRET, auth_backend, current_active_user, fastapi_users +from app.utils.perf import get_perf_logger, log_system_snapshot rate_limit_logger = logging.getLogger("surfsense.rate_limit") @@ -244,6 +248,63 @@ app = FastAPI(lifespan=lifespan) app.state.limiter = limiter app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + +# --------------------------------------------------------------------------- +# Request-level performance middleware +# --------------------------------------------------------------------------- +# Logs wall-clock time, method, path, and status for every request so we can +# spot slow endpoints in production logs. + +_PERF_SLOW_REQUEST_THRESHOLD = float( + __import__("os").environ.get("PERF_SLOW_REQUEST_MS", "2000") +) + + +class RequestPerfMiddleware(BaseHTTPMiddleware): + """Middleware that logs per-request wall-clock time. + + - ALL requests are logged at DEBUG level. + - Requests exceeding PERF_SLOW_REQUEST_MS (default 2000ms) are logged at + WARNING level with a system snapshot so we can correlate slow responses + with CPU/memory usage at that moment. + """ + + async def dispatch( + self, request: StarletteRequest, call_next: RequestResponseEndpoint + ) -> StarletteResponse: + perf = get_perf_logger() + t0 = time.perf_counter() + response = await call_next(request) + elapsed_ms = (time.perf_counter() - t0) * 1000 + + path = request.url.path + method = request.method + status = response.status_code + + perf.debug( + "[request] %s %s -> %d in %.1fms", + method, + path, + status, + elapsed_ms, + ) + + if elapsed_ms > _PERF_SLOW_REQUEST_THRESHOLD: + perf.warning( + "[SLOW_REQUEST] %s %s -> %d in %.1fms (threshold=%.0fms)", + method, + path, + status, + elapsed_ms, + _PERF_SLOW_REQUEST_THRESHOLD, + ) + log_system_snapshot("slow_request") + + return response + + +app.add_middleware(RequestPerfMiddleware) + # Add SlowAPI middleware for automatic rate limiting # Uses Starlette BaseHTTPMiddleware (not the raw ASGI variant) to avoid # corrupting StreamingResponse — SlowAPIASGIMiddleware re-sends diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index eea3d6e25..e6d7977cc 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -1,4 +1,5 @@ import contextlib +import time from datetime import UTC, datetime from sqlalchemy import delete, select @@ -44,6 +45,7 @@ from app.indexing_pipeline.pipeline_logger import ( log_retryable_llm_error, log_unexpected_error, ) +from app.utils.perf import get_perf_logger class IndexingPipelineService: @@ -58,6 +60,9 @@ class IndexingPipelineService: """ Persist new documents and detect changes, returning only those that need indexing. """ + perf = get_perf_logger() + t0 = time.perf_counter() + documents = [] seen_hashes: set[str] = set() batch_ctx = PipelineLogContext( @@ -140,11 +145,12 @@ class IndexingPipelineService: try: await self.session.commit() + perf.info( + "[indexing] prepare_for_indexing in %.3fs input=%d output=%d", + time.perf_counter() - t0, len(connector_docs), len(documents), + ) return documents except IntegrityError: - # A concurrent worker committed a document with the same content_hash - # or unique_identifier_hash between our check and our INSERT. - # The document already exists — roll back and let the next sync run handle it. log_race_condition(batch_ctx) await self.session.rollback() return [] @@ -165,26 +171,39 @@ class IndexingPipelineService: unique_id=connector_doc.unique_id, doc_id=document.id, ) + perf = get_perf_logger() + t_index = time.perf_counter() try: log_index_started(ctx) document.status = DocumentStatus.processing() await self.session.commit() + t_step = time.perf_counter() if connector_doc.should_summarize and llm is not None: content = await summarize_document( connector_doc.source_markdown, llm, connector_doc.metadata ) + perf.info( + "[indexing] summarize_document doc=%d in %.3fs", + document.id, time.perf_counter() - t_step, + ) elif connector_doc.should_summarize and connector_doc.fallback_summary: content = connector_doc.fallback_summary else: content = connector_doc.source_markdown + t_step = time.perf_counter() embedding = embed_text(content) + perf.debug( + "[indexing] embed_text (summary) doc=%d in %.3fs", + document.id, time.perf_counter() - t_step, + ) await self.session.execute( delete(Chunk).where(Chunk.document_id == document.id) ) + t_step = time.perf_counter() chunks = [ Chunk(content=text, embedding=embed_text(text)) for text in chunk_text( @@ -192,6 +211,10 @@ class IndexingPipelineService: use_code_chunker=connector_doc.should_use_code_chunker, ) ] + perf.info( + "[indexing] chunk+embed doc=%d chunks=%d in %.3fs", + document.id, len(chunks), time.perf_counter() - t_step, + ) document.content = content document.embedding = embedding @@ -199,6 +222,10 @@ class IndexingPipelineService: document.updated_at = datetime.now(UTC) document.status = DocumentStatus.ready() await self.session.commit() + perf.info( + "[indexing] index TOTAL doc=%d chunks=%d in %.3fs", + document.id, len(chunks), time.perf_counter() - t_index, + ) log_index_success(ctx, chunk_count=len(chunks)) except RETRYABLE_LLM_ERRORS as e: diff --git a/surfsense_backend/app/retriever/chunks_hybrid_search.py b/surfsense_backend/app/retriever/chunks_hybrid_search.py index 9aa301386..ed3f63acc 100644 --- a/surfsense_backend/app/retriever/chunks_hybrid_search.py +++ b/surfsense_backend/app/retriever/chunks_hybrid_search.py @@ -1,5 +1,8 @@ +import time from datetime import datetime +from app.utils.perf import get_perf_logger + class ChucksHybridSearchRetriever: def __init__(self, db_session): @@ -38,9 +41,17 @@ class ChucksHybridSearchRetriever: from app.config import config from app.db import Chunk, Document + perf = get_perf_logger() + t0 = time.perf_counter() + # Get embedding for the query embedding_model = config.embedding_model_instance + t_embed = time.perf_counter() query_embedding = embedding_model.embed(query_text) + perf.debug( + "[chunk_search] vector_search embedding in %.3fs", + time.perf_counter() - t_embed, + ) # Build the query filtered by search space query = ( @@ -60,8 +71,13 @@ class ChucksHybridSearchRetriever: query = query.order_by(Chunk.embedding.op("<=>")(query_embedding)).limit(top_k) # Execute the query + t_db = time.perf_counter() result = await self.db_session.execute(query) chunks = result.scalars().all() + perf.info( + "[chunk_search] vector_search DB query in %.3fs results=%d (total %.3fs) space=%d", + time.perf_counter() - t_db, len(chunks), time.perf_counter() - t0, search_space_id, + ) return chunks @@ -91,6 +107,9 @@ class ChucksHybridSearchRetriever: from app.db import Chunk, Document + perf = get_perf_logger() + t0 = time.perf_counter() + # Create tsvector and tsquery for PostgreSQL full-text search tsvector = func.to_tsvector("english", Chunk.content) tsquery = func.plainto_tsquery("english", query_text) @@ -118,6 +137,10 @@ class ChucksHybridSearchRetriever: # Execute the query result = await self.db_session.execute(query) chunks = result.scalars().all() + perf.info( + "[chunk_search] full_text_search in %.3fs results=%d space=%d", + time.perf_counter() - t0, len(chunks), search_space_id, + ) return chunks @@ -157,9 +180,17 @@ class ChucksHybridSearchRetriever: from app.config import config from app.db import Chunk, Document, DocumentType + perf = get_perf_logger() + t0 = time.perf_counter() + # Get embedding for the query embedding_model = config.embedding_model_instance + t_embed = time.perf_counter() query_embedding = embedding_model.embed(query_text) + perf.debug( + "[chunk_search] hybrid_search embedding in %.3fs", + time.perf_counter() - t_embed, + ) # RRF constants k = 60 @@ -254,9 +285,14 @@ class ChucksHybridSearchRetriever: .limit(top_k) ) - # Execute the query + # Execute the RRF query + t_rrf = time.perf_counter() result = await self.db_session.execute(final_query) chunks_with_scores = result.all() + perf.info( + "[chunk_search] hybrid_search RRF query in %.3fs results=%d space=%d type=%s", + time.perf_counter() - t_rrf, len(chunks_with_scores), search_space_id, document_type, + ) # If no results were found, return an empty list if not chunks_with_scores: @@ -354,4 +390,8 @@ class ChucksHybridSearchRetriever: ) final_docs.append(entry) + perf.info( + "[chunk_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s", + time.perf_counter() - t0, len(final_docs), search_space_id, document_type, + ) return final_docs diff --git a/surfsense_backend/app/retriever/documents_hybrid_search.py b/surfsense_backend/app/retriever/documents_hybrid_search.py index 9ff104ff0..608e1c2e6 100644 --- a/surfsense_backend/app/retriever/documents_hybrid_search.py +++ b/surfsense_backend/app/retriever/documents_hybrid_search.py @@ -1,5 +1,8 @@ +import time from datetime import datetime +from app.utils.perf import get_perf_logger + class DocumentHybridSearchRetriever: def __init__(self, db_session): @@ -38,6 +41,9 @@ class DocumentHybridSearchRetriever: from app.config import config from app.db import Document + perf = get_perf_logger() + t0 = time.perf_counter() + # Get embedding for the query embedding_model = config.embedding_model_instance query_embedding = embedding_model.embed(query_text) @@ -63,6 +69,10 @@ class DocumentHybridSearchRetriever: # Execute the query result = await self.db_session.execute(query) documents = result.scalars().all() + perf.info( + "[doc_search] vector_search in %.3fs results=%d space=%d", + time.perf_counter() - t0, len(documents), search_space_id, + ) return documents @@ -92,6 +102,9 @@ class DocumentHybridSearchRetriever: from app.db import Document + perf = get_perf_logger() + t0 = time.perf_counter() + # Create tsvector and tsquery for PostgreSQL full-text search tsvector = func.to_tsvector("english", Document.content) tsquery = func.plainto_tsquery("english", query_text) @@ -118,6 +131,10 @@ class DocumentHybridSearchRetriever: # Execute the query result = await self.db_session.execute(query) documents = result.scalars().all() + perf.info( + "[doc_search] full_text_search in %.3fs results=%d space=%d", + time.perf_counter() - t0, len(documents), search_space_id, + ) return documents @@ -151,6 +168,9 @@ class DocumentHybridSearchRetriever: from app.config import config from app.db import Chunk, Document, DocumentType + perf = get_perf_logger() + t0 = time.perf_counter() + # Get embedding for the query embedding_model = config.embedding_model_instance query_embedding = embedding_model.embed(query_text) @@ -303,4 +323,8 @@ class DocumentHybridSearchRetriever: ) final_docs.append(entry) + perf.info( + "[doc_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s", + time.perf_counter() - t0, len(final_docs), search_space_id, document_type, + ) return final_docs diff --git a/surfsense_backend/app/services/connector_service.py b/surfsense_backend/app/services/connector_service.py index 3bd9a4421..fa91de391 100644 --- a/surfsense_backend/app/services/connector_service.py +++ b/surfsense_backend/app/services/connector_service.py @@ -1,4 +1,5 @@ import asyncio +import time from datetime import datetime from typing import Any from urllib.parse import urljoin @@ -18,6 +19,7 @@ from app.db import ( ) from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever from app.retriever.documents_hybrid_search import DocumentHybridSearchRetriever +from app.utils.perf import get_perf_logger class ConnectorService: @@ -246,6 +248,9 @@ class ConnectorService: Returns: List of combined and deduplicated document results """ + perf = get_perf_logger() + t0 = time.perf_counter() + # RRF constant k = 60 @@ -259,6 +264,7 @@ class ConnectorService: # "This session is provisioning a new connection; concurrent operations are not permitted" # # So we run them sequentially. + t_chunk = time.perf_counter() chunk_results = await self.chunk_retriever.hybrid_search( query_text=query_text, top_k=retriever_top_k, @@ -267,6 +273,12 @@ class ConnectorService: start_date=start_date, end_date=end_date, ) + perf.info( + "[connector_svc] _combined_rrf chunk_retriever in %.3fs results=%d type=%s", + time.perf_counter() - t_chunk, len(chunk_results), document_type, + ) + + t_doc = time.perf_counter() doc_results = await self.document_retriever.hybrid_search( query_text=query_text, top_k=retriever_top_k, @@ -275,6 +287,10 @@ class ConnectorService: start_date=start_date, end_date=end_date, ) + perf.info( + "[connector_svc] _combined_rrf doc_retriever in %.3fs results=%d type=%s", + time.perf_counter() - t_doc, len(doc_results), document_type, + ) # Helper to extract document_id from our doc-grouped result def _doc_id(item: dict[str, Any]) -> int | None: @@ -335,6 +351,10 @@ class ConnectorService: result["chunks"] = doc_data[did]["chunks"] combined_results.append(result) + perf.info( + "[connector_svc] _combined_rrf_search TOTAL in %.3fs results=%d type=%s space=%d", + time.perf_counter() - t0, len(combined_results), document_type, search_space_id, + ) return combined_results def _get_doc_url(self, metadata: dict[str, Any]) -> str: diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 2e517f0ba..e9b84c5cd 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -13,6 +13,7 @@ synchronous ChatLiteLLM-like interface and async methods. import logging import re +import time from typing import Any from langchain_core.callbacks import CallbackManagerForLLMRun @@ -26,6 +27,8 @@ from litellm.exceptions import ( ContextWindowExceededError, ) +from app.utils.perf import get_perf_logger + logger = logging.getLogger(__name__) _CONTEXT_OVERFLOW_PATTERNS = re.compile( @@ -410,6 +413,10 @@ class ChatLiteLLMRouter(BaseChatModel): if not self._router: raise ValueError("Router not initialized") + perf = get_perf_logger() + t0 = time.perf_counter() + msg_count = len(messages) + # Convert LangChain messages to OpenAI format formatted_messages = self._convert_messages(messages) @@ -428,12 +435,28 @@ class ChatLiteLLMRouter(BaseChatModel): **call_kwargs, ) except ContextWindowExceededError as e: + perf.warning( + "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs", + msg_count, time.perf_counter() - t0, + ) raise ContextOverflowError(str(e)) from e except LiteLLMBadRequestError as e: if _is_context_overflow_error(e): + perf.warning( + "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs", + msg_count, time.perf_counter() - t0, + ) raise ContextOverflowError(str(e)) from e raise + elapsed = time.perf_counter() - t0 + perf.info( + "[llm_router] _generate completed msgs=%d tools=%d in %.3fs", + msg_count, + len(self._bound_tools) if self._bound_tools else 0, + elapsed, + ) + # Convert response to ChatResult with potential tool calls message = self._convert_response_to_message(response.choices[0].message) generation = ChatGeneration(message=message) @@ -453,6 +476,10 @@ class ChatLiteLLMRouter(BaseChatModel): if not self._router: raise ValueError("Router not initialized") + perf = get_perf_logger() + t0 = time.perf_counter() + msg_count = len(messages) + # Convert LangChain messages to OpenAI format formatted_messages = self._convert_messages(messages) @@ -471,12 +498,28 @@ class ChatLiteLLMRouter(BaseChatModel): **call_kwargs, ) except ContextWindowExceededError as e: + perf.warning( + "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs", + msg_count, time.perf_counter() - t0, + ) raise ContextOverflowError(str(e)) from e except LiteLLMBadRequestError as e: if _is_context_overflow_error(e): + perf.warning( + "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs", + msg_count, time.perf_counter() - t0, + ) raise ContextOverflowError(str(e)) from e raise + elapsed = time.perf_counter() - t0 + perf.info( + "[llm_router] _agenerate completed msgs=%d tools=%d in %.3fs", + msg_count, + len(self._bound_tools) if self._bound_tools else 0, + elapsed, + ) + # Convert response to ChatResult with potential tool calls message = self._convert_response_to_message(response.choices[0].message) generation = ChatGeneration(message=message) @@ -541,6 +584,10 @@ class ChatLiteLLMRouter(BaseChatModel): if not self._router: raise ValueError("Router not initialized") + perf = get_perf_logger() + t0 = time.perf_counter() + msg_count = len(messages) + formatted_messages = self._convert_messages(messages) # Add tools if bound @@ -559,20 +606,48 @@ class ChatLiteLLMRouter(BaseChatModel): **call_kwargs, ) except ContextWindowExceededError as e: + perf.warning( + "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs", + msg_count, time.perf_counter() - t0, + ) raise ContextOverflowError(str(e)) from e except LiteLLMBadRequestError as e: if _is_context_overflow_error(e): + perf.warning( + "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs", + msg_count, time.perf_counter() - t0, + ) raise ContextOverflowError(str(e)) from e raise - # Yield chunks asynchronously + t_first_chunk = time.perf_counter() + perf.info( + "[llm_router] _astream connection established msgs=%d in %.3fs", + msg_count, t_first_chunk - t0, + ) + + chunk_count = 0 + first_chunk_logged = False async for chunk in response: if hasattr(chunk, "choices") and chunk.choices: delta = chunk.choices[0].delta chunk_msg = self._convert_delta_to_chunk(delta) if chunk_msg: + chunk_count += 1 + if not first_chunk_logged: + perf.info( + "[llm_router] _astream first chunk in %.3fs (total %.3fs from start)", + time.perf_counter() - t_first_chunk, + time.perf_counter() - t0, + ) + first_chunk_logged = True yield ChatGenerationChunk(message=chunk_msg) + perf.info( + "[llm_router] _astream completed chunks=%d total=%.3fs", + chunk_count, time.perf_counter() - t0, + ) + def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]: """Convert LangChain messages to OpenAI format.""" from langchain_core.messages import ( diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index ddadbc48b..9e91a8735 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -56,14 +56,9 @@ from app.services.chat_session_state_service import ( from app.services.connector_service import ConnectorService from app.services.new_streaming_service import VercelStreamingService from app.utils.content_utils import bootstrap_history_from_db +from app.utils.perf import get_perf_logger, log_system_snapshot -_perf_log = logging.getLogger("surfsense.perf") -_perf_log.setLevel(logging.DEBUG) -if not _perf_log.handlers: - _h = logging.StreamHandler() - _h.setFormatter(logging.Formatter("%(asctime)s [PERF] %(message)s")) - _perf_log.addHandler(_h) - _perf_log.propagate = False +_perf_log = get_perf_logger() _background_tasks: set[asyncio.Task] = set() @@ -1050,6 +1045,7 @@ async def stream_new_chat( streaming_service = VercelStreamingService() stream_result = StreamResult() _t_total = time.perf_counter() + log_system_snapshot("stream_new_chat_START") try: # Mark AI as responding to this user for live collaboration @@ -1382,6 +1378,7 @@ async def stream_new_chat( time.perf_counter() - _t_stream_start, chat_id, ) + log_system_snapshot("stream_new_chat_END") if stream_result.is_interrupted: yield streaming_service.format_finish_step() diff --git a/surfsense_backend/app/utils/perf.py b/surfsense_backend/app/utils/perf.py new file mode 100644 index 000000000..301498048 --- /dev/null +++ b/surfsense_backend/app/utils/perf.py @@ -0,0 +1,122 @@ +""" +Centralized performance monitoring for SurfSense backend. + +Provides: +- A shared [PERF] logger used across all modules +- perf_timer context manager for timing code blocks +- perf_async_timer for async code blocks +- system_snapshot() for CPU/memory profiling +- RequestPerfMiddleware for per-request timing +""" + +import logging +import os +import time +from contextlib import asynccontextmanager, contextmanager +from typing import Any + +_perf_log: logging.Logger | None = None + + +def get_perf_logger() -> logging.Logger: + """Return the singleton [PERF] logger, creating it once on first call.""" + global _perf_log + if _perf_log is None: + _perf_log = logging.getLogger("surfsense.perf") + _perf_log.setLevel(logging.DEBUG) + if not _perf_log.handlers: + h = logging.StreamHandler() + h.setFormatter(logging.Formatter("%(asctime)s [PERF] %(message)s")) + _perf_log.addHandler(h) + _perf_log.propagate = False + return _perf_log + + +@contextmanager +def perf_timer(label: str, *, extra: dict[str, Any] | None = None): + """Synchronous context manager that logs elapsed wall-clock time. + + Usage: + with perf_timer("[my_func] heavy computation"): + ... + """ + log = get_perf_logger() + t0 = time.perf_counter() + yield + elapsed = time.perf_counter() - t0 + suffix = "" + if extra: + suffix = " " + " ".join(f"{k}={v}" for k, v in extra.items()) + log.info("%s in %.3fs%s", label, elapsed, suffix) + + +@asynccontextmanager +async def perf_async_timer(label: str, *, extra: dict[str, Any] | None = None): + """Async context manager that logs elapsed wall-clock time. + + Usage: + async with perf_async_timer("[search] vector search"): + ... + """ + log = get_perf_logger() + t0 = time.perf_counter() + yield + elapsed = time.perf_counter() - t0 + suffix = "" + if extra: + suffix = " " + " ".join(f"{k}={v}" for k, v in extra.items()) + log.info("%s in %.3fs%s", label, elapsed, suffix) + + +def system_snapshot() -> dict[str, Any]: + """Capture a lightweight CPU + memory snapshot of the current process. + + Returns a dict with: + - rss_mb: Resident Set Size in MB + - cpu_percent: CPU usage % since last call (per-process) + - threads: number of active threads + - open_fds: number of open file descriptors (Linux only) + - asyncio_tasks: number of asyncio tasks currently alive + """ + import asyncio + + snapshot: dict[str, Any] = {} + try: + import psutil + + proc = psutil.Process(os.getpid()) + mem = proc.memory_info() + snapshot["rss_mb"] = round(mem.rss / 1024 / 1024, 1) + snapshot["cpu_percent"] = proc.cpu_percent(interval=None) + snapshot["threads"] = proc.num_threads() + try: + snapshot["open_fds"] = proc.num_fds() + except AttributeError: + snapshot["open_fds"] = -1 + except ImportError: + snapshot["rss_mb"] = -1 + snapshot["cpu_percent"] = -1 + snapshot["threads"] = -1 + snapshot["open_fds"] = -1 + + try: + all_tasks = asyncio.all_tasks() + snapshot["asyncio_tasks"] = len(all_tasks) + except RuntimeError: + snapshot["asyncio_tasks"] = -1 + + return snapshot + + +def log_system_snapshot(label: str = "system_snapshot") -> None: + """Capture and log a system snapshot.""" + snap = system_snapshot() + get_perf_logger().info( + "[%s] rss=%.1fMB cpu=%.1f%% threads=%d fds=%d asyncio_tasks=%d", + label, + snap["rss_mb"], + snap["cpu_percent"], + snap["threads"], + snap["open_fds"], + snap["asyncio_tasks"], + ) diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index a3e8ae272..5deee8360 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -2,27 +2,28 @@ import { EnumConnectorName } from "@/contracts/enums/connector"; // OAuth Connectors (Quick Connect) export const OAUTH_CONNECTORS = [ - { - id: "google-drive-connector", - title: "Google Drive", - description: "Search your Drive files", - connectorType: EnumConnectorName.GOOGLE_DRIVE_CONNECTOR, - authEndpoint: "/api/v1/auth/google/drive/connector/add/", - }, - { - id: "google-gmail-connector", - title: "Gmail", - description: "Search through your emails", - connectorType: EnumConnectorName.GOOGLE_GMAIL_CONNECTOR, - authEndpoint: "/api/v1/auth/google/gmail/connector/add/", - }, - { - id: "google-calendar-connector", - title: "Google Calendar", - description: "Search through your events", - connectorType: EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR, - authEndpoint: "/api/v1/auth/google/calendar/connector/add/", - }, + // // Uncomment for managed Google Connections + // { + // id: "google-drive-connector", + // title: "Google Drive", + // description: "Search your Drive files", + // connectorType: EnumConnectorName.GOOGLE_DRIVE_CONNECTOR, + // authEndpoint: "/api/v1/auth/google/drive/connector/add/", + // }, + // { + // id: "google-gmail-connector", + // title: "Gmail", + // description: "Search through your emails", + // connectorType: EnumConnectorName.GOOGLE_GMAIL_CONNECTOR, + // authEndpoint: "/api/v1/auth/google/gmail/connector/add/", + // }, + // { + // id: "google-calendar-connector", + // title: "Google Calendar", + // description: "Search through your events", + // connectorType: EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR, + // authEndpoint: "/api/v1/auth/google/calendar/connector/add/", + // }, { id: "airtable-connector", title: "Airtable", From 0e723a5b8b26d2c05a5553cfb372ed9f7b6a1f9b Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Fri, 27 Feb 2026 17:19:25 -0800 Subject: [PATCH 04/10] feat: perf optimizations - improved search_knowledgebase_tool - Added new endpoint to batch-fetch comments for multiple messages, reducing the number of API calls. - Introduced CommentBatchRequest and CommentBatchResponse schemas for handling batch requests and responses. - Updated chat_comments_service to validate message existence and permissions before fetching comments. - Enhanced frontend with useBatchCommentsPreload hook to optimize comment loading for assistant messages. --- .../agents/new_chat/tools/knowledge_base.py | 22 ++- .../indexing_pipeline_service.py | 18 ++- .../app/retriever/chunks_hybrid_search.py | 37 +++-- .../app/retriever/documents_hybrid_search.py | 22 ++- .../app/routes/chat_comments_routes.py | 13 ++ .../app/schemas/chat_comments.py | 12 ++ .../app/services/chat_comments_service.py | 141 ++++++++++++++++++ .../app/services/connector_service.py | 73 +++++---- .../app/services/llm_router_service.py | 24 ++- .../components/assistant-ui/thread.tsx | 17 +++ .../contracts/types/chat-comments.types.ts | 18 +++ surfsense_web/hooks/use-comments.ts | 75 +++++++++- .../lib/apis/chat-comments-api.service.ts | 19 +++ 13 files changed, 424 insertions(+), 67 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py index 19f21bbc6..9394d68b4 100644 --- a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py +++ b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py @@ -430,7 +430,10 @@ async def search_knowledge_base_async( connectors = _normalize_connectors(connectors_to_search, available_connectors) perf.info( "[kb_search] searching %d connectors: %s (space=%d, top_k=%d)", - len(connectors), connectors[:5], search_space_id, top_k, + len(connectors), + connectors[:5], + search_space_id, + top_k, ) connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = { @@ -510,13 +513,17 @@ async def search_knowledge_base_async( _, chunks = await connector_method(**kwargs) perf.info( "[kb_search] connector=%s results=%d in %.3fs", - connector, len(chunks), time.perf_counter() - t_conn, + connector, + len(chunks), + time.perf_counter() - t_conn, ) return chunks except Exception as e: perf.warning( "[kb_search] connector=%s FAILED in %.3fs: %s", - connector, time.perf_counter() - t_conn, e, + connector, + time.perf_counter() - t_conn, + e, ) return [] @@ -525,7 +532,8 @@ async def search_knowledge_base_async( *[_search_one_connector(connector) for connector in connectors] ) perf.info( - "[kb_search] all connectors gathered in %.3fs", time.perf_counter() - t_gather, + "[kb_search] all connectors gathered in %.3fs", + time.perf_counter() - t_gather, ) for chunks in connector_results: all_documents.extend(chunks) @@ -576,7 +584,11 @@ async def search_knowledge_base_async( result = format_documents_for_context(deduplicated, max_chars=output_budget) perf.info( "[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d space=%d", - time.perf_counter() - t0, len(all_documents), len(deduplicated), len(result), search_space_id, + time.perf_counter() - t0, + len(all_documents), + len(deduplicated), + len(result), + search_space_id, ) return result diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index e6d7977cc..9460f900c 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -147,7 +147,9 @@ class IndexingPipelineService: await self.session.commit() perf.info( "[indexing] prepare_for_indexing in %.3fs input=%d output=%d", - time.perf_counter() - t0, len(connector_docs), len(documents), + time.perf_counter() - t0, + len(connector_docs), + len(documents), ) return documents except IntegrityError: @@ -185,7 +187,8 @@ class IndexingPipelineService: ) perf.info( "[indexing] summarize_document doc=%d in %.3fs", - document.id, time.perf_counter() - t_step, + document.id, + time.perf_counter() - t_step, ) elif connector_doc.should_summarize and connector_doc.fallback_summary: content = connector_doc.fallback_summary @@ -196,7 +199,8 @@ class IndexingPipelineService: embedding = embed_text(content) perf.debug( "[indexing] embed_text (summary) doc=%d in %.3fs", - document.id, time.perf_counter() - t_step, + document.id, + time.perf_counter() - t_step, ) await self.session.execute( @@ -213,7 +217,9 @@ class IndexingPipelineService: ] perf.info( "[indexing] chunk+embed doc=%d chunks=%d in %.3fs", - document.id, len(chunks), time.perf_counter() - t_step, + document.id, + len(chunks), + time.perf_counter() - t_step, ) document.content = content @@ -224,7 +230,9 @@ class IndexingPipelineService: await self.session.commit() perf.info( "[indexing] index TOTAL doc=%d chunks=%d in %.3fs", - document.id, len(chunks), time.perf_counter() - t_index, + document.id, + len(chunks), + time.perf_counter() - t_index, ) log_index_success(ctx, chunk_count=len(chunks)) diff --git a/surfsense_backend/app/retriever/chunks_hybrid_search.py b/surfsense_backend/app/retriever/chunks_hybrid_search.py index ed3f63acc..38ecba96c 100644 --- a/surfsense_backend/app/retriever/chunks_hybrid_search.py +++ b/surfsense_backend/app/retriever/chunks_hybrid_search.py @@ -76,7 +76,10 @@ class ChucksHybridSearchRetriever: chunks = result.scalars().all() perf.info( "[chunk_search] vector_search DB query in %.3fs results=%d (total %.3fs) space=%d", - time.perf_counter() - t_db, len(chunks), time.perf_counter() - t0, search_space_id, + time.perf_counter() - t_db, + len(chunks), + time.perf_counter() - t0, + search_space_id, ) return chunks @@ -139,7 +142,9 @@ class ChucksHybridSearchRetriever: chunks = result.scalars().all() perf.info( "[chunk_search] full_text_search in %.3fs results=%d space=%d", - time.perf_counter() - t0, len(chunks), search_space_id, + time.perf_counter() - t0, + len(chunks), + search_space_id, ) return chunks @@ -152,6 +157,7 @@ class ChucksHybridSearchRetriever: document_type: str | None = None, start_date: datetime | None = None, end_date: datetime | None = None, + query_embedding: list | None = None, ) -> list: """ Hybrid search that returns **documents** (not individual chunks). @@ -166,6 +172,7 @@ class ChucksHybridSearchRetriever: document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL") start_date: Optional start date for filtering documents by updated_at end_date: Optional end date for filtering documents by updated_at + query_embedding: Pre-computed embedding vector. If None, will be computed here. Returns: List of dictionaries containing document data and relevance scores. Each dict contains: @@ -183,14 +190,14 @@ class ChucksHybridSearchRetriever: perf = get_perf_logger() t0 = time.perf_counter() - # Get embedding for the query - embedding_model = config.embedding_model_instance - t_embed = time.perf_counter() - query_embedding = embedding_model.embed(query_text) - perf.debug( - "[chunk_search] hybrid_search embedding in %.3fs", - time.perf_counter() - t_embed, - ) + if query_embedding is None: + embedding_model = config.embedding_model_instance + t_embed = time.perf_counter() + query_embedding = embedding_model.embed(query_text) + perf.debug( + "[chunk_search] hybrid_search embedding in %.3fs", + time.perf_counter() - t_embed, + ) # RRF constants k = 60 @@ -291,7 +298,10 @@ class ChucksHybridSearchRetriever: chunks_with_scores = result.all() perf.info( "[chunk_search] hybrid_search RRF query in %.3fs results=%d space=%d type=%s", - time.perf_counter() - t_rrf, len(chunks_with_scores), search_space_id, document_type, + time.perf_counter() - t_rrf, + len(chunks_with_scores), + search_space_id, + document_type, ) # If no results were found, return an empty list @@ -392,6 +402,9 @@ class ChucksHybridSearchRetriever: perf.info( "[chunk_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s", - time.perf_counter() - t0, len(final_docs), search_space_id, document_type, + time.perf_counter() - t0, + len(final_docs), + search_space_id, + document_type, ) return final_docs diff --git a/surfsense_backend/app/retriever/documents_hybrid_search.py b/surfsense_backend/app/retriever/documents_hybrid_search.py index 608e1c2e6..f4daf8e26 100644 --- a/surfsense_backend/app/retriever/documents_hybrid_search.py +++ b/surfsense_backend/app/retriever/documents_hybrid_search.py @@ -71,7 +71,9 @@ class DocumentHybridSearchRetriever: documents = result.scalars().all() perf.info( "[doc_search] vector_search in %.3fs results=%d space=%d", - time.perf_counter() - t0, len(documents), search_space_id, + time.perf_counter() - t0, + len(documents), + search_space_id, ) return documents @@ -133,7 +135,9 @@ class DocumentHybridSearchRetriever: documents = result.scalars().all() perf.info( "[doc_search] full_text_search in %.3fs results=%d space=%d", - time.perf_counter() - t0, len(documents), search_space_id, + time.perf_counter() - t0, + len(documents), + search_space_id, ) return documents @@ -146,6 +150,7 @@ class DocumentHybridSearchRetriever: document_type: str | None = None, start_date: datetime | None = None, end_date: datetime | None = None, + query_embedding: list | None = None, ) -> list: """ Hybrid search that returns **documents** (not individual chunks). @@ -160,7 +165,7 @@ class DocumentHybridSearchRetriever: document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL") start_date: Optional start date for filtering documents by updated_at end_date: Optional end date for filtering documents by updated_at - + query_embedding: Pre-computed embedding vector. If None, will be computed here. """ from sqlalchemy import func, select, text from sqlalchemy.orm import joinedload @@ -171,9 +176,9 @@ class DocumentHybridSearchRetriever: perf = get_perf_logger() t0 = time.perf_counter() - # Get embedding for the query - embedding_model = config.embedding_model_instance - query_embedding = embedding_model.embed(query_text) + if query_embedding is None: + embedding_model = config.embedding_model_instance + query_embedding = embedding_model.embed(query_text) # RRF constants k = 60 @@ -325,6 +330,9 @@ class DocumentHybridSearchRetriever: perf.info( "[doc_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s", - time.perf_counter() - t0, len(final_docs), search_space_id, document_type, + time.perf_counter() - t0, + len(final_docs), + search_space_id, + document_type, ) return final_docs diff --git a/surfsense_backend/app/routes/chat_comments_routes.py b/surfsense_backend/app/routes/chat_comments_routes.py index 1c21c0f4a..f5a8fd0af 100644 --- a/surfsense_backend/app/routes/chat_comments_routes.py +++ b/surfsense_backend/app/routes/chat_comments_routes.py @@ -7,6 +7,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.db import User, get_async_session from app.schemas.chat_comments import ( + CommentBatchRequest, + CommentBatchResponse, CommentCreateRequest, CommentListResponse, CommentReplyResponse, @@ -19,6 +21,7 @@ from app.services.chat_comments_service import ( create_reply, delete_comment, get_comments_for_message, + get_comments_for_messages_batch, get_user_mentions, update_comment, ) @@ -27,6 +30,16 @@ from app.users import current_active_user router = APIRouter() +@router.post("/messages/comments/batch", response_model=CommentBatchResponse) +async def batch_list_comments( + request: CommentBatchRequest, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Batch-fetch comments for multiple messages in one request.""" + return await get_comments_for_messages_batch(session, request.message_ids, user) + + @router.get("/messages/{message_id}/comments", response_model=CommentListResponse) async def list_comments( message_id: int, diff --git a/surfsense_backend/app/schemas/chat_comments.py b/surfsense_backend/app/schemas/chat_comments.py index b87ee58a4..984e8b812 100644 --- a/surfsense_backend/app/schemas/chat_comments.py +++ b/surfsense_backend/app/schemas/chat_comments.py @@ -87,6 +87,18 @@ class CommentListResponse(BaseModel): total_count: int +class CommentBatchRequest(BaseModel): + """Request for batch-fetching comments for multiple messages.""" + + message_ids: list[int] = Field(..., min_length=1, max_length=200) + + +class CommentBatchResponse(BaseModel): + """Batch response keyed by message_id.""" + + comments_by_message: dict[int, CommentListResponse] + + # ============================================================================= # Mention Schemas # ============================================================================= diff --git a/surfsense_backend/app/services/chat_comments_service.py b/surfsense_backend/app/services/chat_comments_service.py index c9ca920f6..c2bb65aee 100644 --- a/surfsense_backend/app/services/chat_comments_service.py +++ b/surfsense_backend/app/services/chat_comments_service.py @@ -22,6 +22,7 @@ from app.db import ( ) from app.schemas.chat_comments import ( AuthorResponse, + CommentBatchResponse, CommentListResponse, CommentReplyResponse, CommentResponse, @@ -264,6 +265,146 @@ async def get_comments_for_message( ) +async def get_comments_for_messages_batch( + session: AsyncSession, + message_ids: list[int], + user: User, +) -> CommentBatchResponse: + """ + Batch-fetch comments for multiple messages in a single DB round-trip. + + Validates that all messages exist and belong to search spaces the user + can read comments in, then loads all comments with eager-loaded authors + and replies. + """ + if not message_ids: + return CommentBatchResponse(comments_by_message={}) + + unique_ids = list(set(message_ids)) + + result = await session.execute( + select(NewChatMessage) + .options(selectinload(NewChatMessage.thread)) + .filter(NewChatMessage.id.in_(unique_ids)) + ) + messages = result.scalars().all() + msg_map = {m.id: m for m in messages} + + search_space_ids = {m.thread.search_space_id for m in messages} + permissions_cache: dict[int, set] = {} + for ss_id in search_space_ids: + await check_permission( + session, + user, + ss_id, + Permission.COMMENTS_READ.value, + "You don't have permission to read comments in this search space", + ) + permissions_cache[ss_id] = await get_user_permissions(session, user.id, ss_id) + + result = await session.execute( + select(ChatComment) + .options( + selectinload(ChatComment.author), + selectinload(ChatComment.replies).selectinload(ChatComment.author), + ) + .filter( + ChatComment.message_id.in_(unique_ids), + ChatComment.parent_id.is_(None), + ) + .order_by(ChatComment.created_at) + ) + top_level_comments = result.scalars().all() + + all_mentioned_uuids: set[UUID] = set() + for comment in top_level_comments: + all_mentioned_uuids.update(parse_mentions(comment.content)) + for reply in comment.replies: + all_mentioned_uuids.update(parse_mentions(reply.content)) + + user_names = await get_user_names_for_mentions(session, all_mentioned_uuids) + + comments_by_msg: dict[int, list[ChatComment]] = {mid: [] for mid in unique_ids} + for comment in top_level_comments: + comments_by_msg.setdefault(comment.message_id, []).append(comment) + + comments_by_message: dict[int, CommentListResponse] = {} + for mid in unique_ids: + msg = msg_map.get(mid) + if msg is None: + comments_by_message[mid] = CommentListResponse(comments=[], total_count=0) + continue + + ss_id = msg.thread.search_space_id + user_perms = permissions_cache.get(ss_id, set()) + can_delete_any = has_permission(user_perms, Permission.COMMENTS_DELETE.value) + + comment_responses = [] + for comment in comments_by_msg.get(mid, []): + author = None + if comment.author: + author = AuthorResponse( + id=comment.author.id, + display_name=comment.author.display_name, + avatar_url=comment.author.avatar_url, + email=comment.author.email, + ) + + replies = [] + for reply in sorted(comment.replies, key=lambda r: r.created_at): + reply_author = None + if reply.author: + reply_author = AuthorResponse( + id=reply.author.id, + display_name=reply.author.display_name, + avatar_url=reply.author.avatar_url, + email=reply.author.email, + ) + is_reply_author = ( + reply.author_id == user.id if reply.author_id else False + ) + replies.append( + CommentReplyResponse( + id=reply.id, + content=reply.content, + content_rendered=render_mentions(reply.content, user_names), + author=reply_author, + created_at=reply.created_at, + updated_at=reply.updated_at, + is_edited=reply.updated_at > reply.created_at, + can_edit=is_reply_author, + can_delete=is_reply_author or can_delete_any, + ) + ) + + is_comment_author = ( + comment.author_id == user.id if comment.author_id else False + ) + comment_responses.append( + CommentResponse( + id=comment.id, + message_id=comment.message_id, + content=comment.content, + content_rendered=render_mentions(comment.content, user_names), + author=author, + created_at=comment.created_at, + updated_at=comment.updated_at, + is_edited=comment.updated_at > comment.created_at, + can_edit=is_comment_author, + can_delete=is_comment_author or can_delete_any, + reply_count=len(replies), + replies=replies, + ) + ) + + comments_by_message[mid] = CommentListResponse( + comments=comment_responses, + total_count=len(comment_responses), + ) + + return CommentBatchResponse(comments_by_message=comments_by_message) + + async def create_comment( session: AsyncSession, message_id: int, diff --git a/surfsense_backend/app/services/connector_service.py b/surfsense_backend/app/services/connector_service.py index fa91de391..157e0bab5 100644 --- a/surfsense_backend/app/services/connector_service.py +++ b/surfsense_backend/app/services/connector_service.py @@ -16,6 +16,7 @@ from app.db import ( Document, SearchSourceConnector, SearchSourceConnectorType, + async_session_maker, ) from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever from app.retriever.documents_hybrid_search import DocumentHybridSearchRetriever @@ -248,6 +249,8 @@ class ConnectorService: Returns: List of combined and deduplicated document results """ + from app.config import config + perf = get_perf_logger() t0 = time.perf_counter() @@ -257,39 +260,48 @@ class ConnectorService: # Get more results from each retriever for better fusion retriever_top_k = top_k * 2 - # IMPORTANT: - # These retrievers share the same AsyncSession. AsyncSession does not permit - # concurrent awaits that require DB IO on the same session/connection. - # Running these in parallel can raise: - # "This session is provisioning a new connection; concurrent operations are not permitted" - # - # So we run them sequentially. - t_chunk = time.perf_counter() - chunk_results = await self.chunk_retriever.hybrid_search( - query_text=query_text, - top_k=retriever_top_k, - search_space_id=search_space_id, - document_type=document_type, - start_date=start_date, - end_date=end_date, - ) + # Pre-compute the embedding once so both retrievers reuse it. + t_embed = time.perf_counter() + query_embedding = config.embedding_model_instance.embed(query_text) perf.info( - "[connector_svc] _combined_rrf chunk_retriever in %.3fs results=%d type=%s", - time.perf_counter() - t_chunk, len(chunk_results), document_type, + "[connector_svc] _combined_rrf embedding in %.3fs type=%s", + time.perf_counter() - t_embed, + document_type, ) - t_doc = time.perf_counter() - doc_results = await self.document_retriever.hybrid_search( - query_text=query_text, - top_k=retriever_top_k, - search_space_id=search_space_id, - document_type=document_type, - start_date=start_date, - end_date=end_date, + search_kwargs = { + "query_text": query_text, + "top_k": retriever_top_k, + "search_space_id": search_space_id, + "document_type": document_type, + "start_date": start_date, + "end_date": end_date, + "query_embedding": query_embedding, + } + + # Run chunk and document retrievers in parallel using separate DB sessions + # so they don't contend on a shared AsyncSession connection. + async def _run_chunk_search() -> list[dict[str, Any]]: + async with async_session_maker() as session: + retriever = ChucksHybridSearchRetriever(session) + return await retriever.hybrid_search(**search_kwargs) + + async def _run_doc_search() -> list[dict[str, Any]]: + async with async_session_maker() as session: + retriever = DocumentHybridSearchRetriever(session) + return await retriever.hybrid_search(**search_kwargs) + + t_parallel = time.perf_counter() + chunk_results, doc_results = await asyncio.gather( + _run_chunk_search(), _run_doc_search() ) perf.info( - "[connector_svc] _combined_rrf doc_retriever in %.3fs results=%d type=%s", - time.perf_counter() - t_doc, len(doc_results), document_type, + "[connector_svc] _combined_rrf parallel retrievers in %.3fs " + "chunk_results=%d doc_results=%d type=%s", + time.perf_counter() - t_parallel, + len(chunk_results), + len(doc_results), + document_type, ) # Helper to extract document_id from our doc-grouped result @@ -353,7 +365,10 @@ class ConnectorService: perf.info( "[connector_svc] _combined_rrf_search TOTAL in %.3fs results=%d type=%s space=%d", - time.perf_counter() - t0, len(combined_results), document_type, search_space_id, + time.perf_counter() - t0, + len(combined_results), + document_type, + search_space_id, ) return combined_results diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index e9b84c5cd..7839e4014 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -437,14 +437,16 @@ class ChatLiteLLMRouter(BaseChatModel): except ContextWindowExceededError as e: perf.warning( "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs", - msg_count, time.perf_counter() - t0, + msg_count, + time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e except LiteLLMBadRequestError as e: if _is_context_overflow_error(e): perf.warning( "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs", - msg_count, time.perf_counter() - t0, + msg_count, + time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e raise @@ -500,14 +502,16 @@ class ChatLiteLLMRouter(BaseChatModel): except ContextWindowExceededError as e: perf.warning( "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs", - msg_count, time.perf_counter() - t0, + msg_count, + time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e except LiteLLMBadRequestError as e: if _is_context_overflow_error(e): perf.warning( "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs", - msg_count, time.perf_counter() - t0, + msg_count, + time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e raise @@ -608,14 +612,16 @@ class ChatLiteLLMRouter(BaseChatModel): except ContextWindowExceededError as e: perf.warning( "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs", - msg_count, time.perf_counter() - t0, + msg_count, + time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e except LiteLLMBadRequestError as e: if _is_context_overflow_error(e): perf.warning( "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs", - msg_count, time.perf_counter() - t0, + msg_count, + time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e raise @@ -623,7 +629,8 @@ class ChatLiteLLMRouter(BaseChatModel): t_first_chunk = time.perf_counter() perf.info( "[llm_router] _astream connection established msgs=%d in %.3fs", - msg_count, t_first_chunk - t0, + msg_count, + t_first_chunk - t0, ) chunk_count = 0 @@ -645,7 +652,8 @@ class ChatLiteLLMRouter(BaseChatModel): perf.info( "[llm_router] _astream completed chunks=%d total=%.3fs", - chunk_count, time.perf_counter() - t0, + chunk_count, + time.perf_counter() - t0, ) def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]: diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index cd0b4971c..98fa5b436 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -65,6 +65,7 @@ import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking"; import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; import type { Document } from "@/contracts/types/document.types"; +import { useBatchCommentsPreload } from "@/hooks/use-comments"; import { useCommentsElectric } from "@/hooks/use-comments-electric"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { cn } from "@/lib/utils"; @@ -309,6 +310,22 @@ const Composer: FC = () => { // Sync comments for the entire thread via Electric SQL (one subscription per thread) useCommentsElectric(threadId); + // Batch-prefetch comments for all assistant messages so individual useComments + // hooks never fire their own network requests (eliminates N+1 API calls). + // Return a primitive string from the selector so useSyncExternalStore can + // compare snapshots by value and avoid infinite re-render loops. + const assistantIdsKey = useAssistantState(({ thread }) => + thread.messages + .filter((m) => m.role === "assistant" && m.id?.startsWith("msg-")) + .map((m) => m.id!.replace("msg-", "")) + .join(",") + ); + const assistantDbMessageIds = useMemo( + () => (assistantIdsKey ? assistantIdsKey.split(",").map(Number) : []), + [assistantIdsKey] + ); + useBatchCommentsPreload(assistantDbMessageIds); + // Auto-focus editor on new chat page after mount useEffect(() => { if (isThreadEmpty && !hasAutoFocusedRef.current && editorRef.current) { diff --git a/surfsense_web/contracts/types/chat-comments.types.ts b/surfsense_web/contracts/types/chat-comments.types.ts index 46e064a4e..cdeca0a44 100644 --- a/surfsense_web/contracts/types/chat-comments.types.ts +++ b/surfsense_web/contracts/types/chat-comments.types.ts @@ -82,6 +82,22 @@ export const getCommentsResponse = z.object({ total_count: z.number(), }); +/** + * Batch-fetch comments for multiple messages + */ +export const getBatchCommentsRequest = z.object({ + message_ids: z.array(z.number()).min(1).max(200), +}); + +export const commentListResponse = z.object({ + comments: z.array(comment), + total_count: z.number(), +}); + +export const getBatchCommentsResponse = z.object({ + comments_by_message: z.record(z.string(), commentListResponse), +}); + /** * Create comment */ @@ -145,6 +161,8 @@ export type MentionComment = z.infer; export type Mention = z.infer; export type GetCommentsRequest = z.infer; export type GetCommentsResponse = z.infer; +export type GetBatchCommentsRequest = z.infer; +export type GetBatchCommentsResponse = z.infer; export type CreateCommentRequest = z.infer; export type CreateCommentResponse = z.infer; export type CreateReplyRequest = z.infer; diff --git a/surfsense_web/hooks/use-comments.ts b/surfsense_web/hooks/use-comments.ts index 4f027d67c..562f7ae02 100644 --- a/surfsense_web/hooks/use-comments.ts +++ b/surfsense_web/hooks/use-comments.ts @@ -1,4 +1,5 @@ -import { useQuery } from "@tanstack/react-query"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { useEffect, useRef } from "react"; import { chatCommentsApiService } from "@/lib/apis/chat-comments-api.service"; import { cacheKeys } from "@/lib/query-client/cache-keys"; @@ -7,12 +8,84 @@ interface UseCommentsOptions { enabled?: boolean; } +// --------------------------------------------------------------------------- +// Module-level coordination: when a batch request is in-flight, individual +// useComments queryFns piggy-back on it instead of making their own requests. +// --------------------------------------------------------------------------- +let _batchInflight: Promise | null = null; +let _batchTargetIds = new Set(); + export function useComments({ messageId, enabled = true }: UseCommentsOptions) { + const queryClient = useQueryClient(); + return useQuery({ queryKey: cacheKeys.comments.byMessage(messageId), queryFn: async () => { + // Yield one macro-task so the batch prefetch useEffect (which sets + // _batchInflight) has a chance to fire before we decide to fetch. + await new Promise((r) => setTimeout(r, 0)); + + if (_batchInflight && _batchTargetIds.has(messageId)) { + await _batchInflight; + const cached = queryClient.getQueryData(cacheKeys.comments.byMessage(messageId)); + if (cached) return cached; + } + return chatCommentsApiService.getComments({ message_id: messageId }); }, enabled: enabled && !!messageId, + staleTime: 30_000, }); } + +/** + * Batch-fetch comments for all given message IDs in a single request, then + * seed the per-message React Query cache so individual useComments hooks + * resolve from cache instead of firing their own requests. + */ +export function useBatchCommentsPreload(messageIds: number[]) { + const queryClient = useQueryClient(); + const prevKeyRef = useRef(""); + + useEffect(() => { + if (!messageIds.length) return; + + const key = messageIds + .slice() + .sort((a, b) => a - b) + .join(","); + if (key === prevKeyRef.current) return; + prevKeyRef.current = key; + + _batchTargetIds = new Set(messageIds); + let cancelled = false; + + const promise = chatCommentsApiService + .getBatchComments({ message_ids: messageIds }) + .then((data) => { + if (cancelled) return; + for (const [msgIdStr, commentList] of Object.entries(data.comments_by_message)) { + queryClient.setQueryData(cacheKeys.comments.byMessage(Number(msgIdStr)), commentList); + } + }) + .catch(() => { + // Batch failed; individual queryFns will fall through to their own fetch + }) + .finally(() => { + if (_batchInflight === promise) { + _batchInflight = null; + _batchTargetIds = new Set(); + } + }); + + _batchInflight = promise; + + return () => { + cancelled = true; + if (_batchInflight === promise) { + _batchInflight = null; + _batchTargetIds = new Set(); + } + }; + }, [messageIds, queryClient]); +} diff --git a/surfsense_web/lib/apis/chat-comments-api.service.ts b/surfsense_web/lib/apis/chat-comments-api.service.ts index 952de7a25..f1ec7a5d9 100644 --- a/surfsense_web/lib/apis/chat-comments-api.service.ts +++ b/surfsense_web/lib/apis/chat-comments-api.service.ts @@ -8,8 +8,11 @@ import { type DeleteCommentRequest, deleteCommentRequest, deleteCommentResponse, + type GetBatchCommentsRequest, type GetCommentsRequest, type GetMentionsRequest, + getBatchCommentsRequest, + getBatchCommentsResponse, getCommentsRequest, getCommentsResponse, getMentionsRequest, @@ -22,6 +25,22 @@ import { ValidationError } from "@/lib/error"; import { baseApiService } from "./base-api.service"; class ChatCommentsApiService { + /** + * Batch-fetch comments for multiple messages in one request + */ + getBatchComments = async (request: GetBatchCommentsRequest) => { + const parsed = getBatchCommentsRequest.safeParse(request); + + if (!parsed.success) { + const errorMessage = parsed.error.issues.map((issue) => issue.message).join(", "); + throw new ValidationError(`Invalid request: ${errorMessage}`); + } + + return baseApiService.post("/api/v1/messages/comments/batch", getBatchCommentsResponse, { + body: { message_ids: parsed.data.message_ids }, + }); + }; + /** * Get comments for a message */ From f4b2ab0899fe83cf411321f76611403809184e3b Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Fri, 27 Feb 2026 17:56:00 -0800 Subject: [PATCH 05/10] feat: enhance caching mechanisms to prevent memory leaks - Improved in-memory rate limiting by evicting timestamps outside the current window and cleaning up empty keys. - Updated LLM router service to cache context profiles and avoid redundant computations. - Introduced cache eviction logic for MCP tools and sandbox instances to manage memory usage effectively. - Added garbage collection triggers in chat streaming functions to reclaim resources promptly. --- .../app/agents/new_chat/llm_config.py | 3 +- .../app/agents/new_chat/sandbox.py | 7 ++ .../app/agents/new_chat/tools/mcp_tool.py | 22 ++++ surfsense_backend/app/app.py | 16 +-- .../app/services/llm_router_service.py | 118 ++++++++++-------- surfsense_backend/app/services/llm_service.py | 3 +- .../app/tasks/chat/stream_new_chat.py | 18 +++ 7 files changed, 127 insertions(+), 60 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/llm_config.py b/surfsense_backend/app/agents/new_chat/llm_config.py index bf16b2fe9..2b1c07cda 100644 --- a/surfsense_backend/app/agents/new_chat/llm_config.py +++ b/surfsense_backend/app/agents/new_chat/llm_config.py @@ -22,6 +22,7 @@ from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, LLMRouterService, + get_auto_mode_llm, is_auto_mode, ) @@ -389,7 +390,7 @@ def create_chat_litellm_from_agent_config( print("Error: Auto mode requested but LLM Router not initialized") return None try: - return ChatLiteLLMRouter() + return get_auto_mode_llm() except Exception as e: print(f"Error creating ChatLiteLLMRouter: {e}") return None diff --git a/surfsense_backend/app/agents/new_chat/sandbox.py b/surfsense_backend/app/agents/new_chat/sandbox.py index 7696f67f2..8b634993b 100644 --- a/surfsense_backend/app/agents/new_chat/sandbox.py +++ b/surfsense_backend/app/agents/new_chat/sandbox.py @@ -58,6 +58,7 @@ class _TimeoutAwareSandbox(DaytonaSandbox): _daytona_client: Daytona | None = None _sandbox_cache: dict[str, _TimeoutAwareSandbox] = {} +_SANDBOX_CACHE_MAX_SIZE = 20 THREAD_LABEL_KEY = "surfsense_thread" @@ -144,6 +145,12 @@ async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox: return cached sandbox = await asyncio.to_thread(_find_or_create, key) _sandbox_cache[key] = sandbox + + if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE: + oldest_key = next(iter(_sandbox_cache)) + _sandbox_cache.pop(oldest_key, None) + logger.debug("Evicted oldest sandbox cache entry: %s", oldest_key) + return sandbox diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 20cf3ec33..2fb7ffb06 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -27,9 +27,24 @@ from app.db import SearchSourceConnector, SearchSourceConnectorType logger = logging.getLogger(__name__) _MCP_CACHE_TTL_SECONDS = 300 # 5 minutes +_MCP_CACHE_MAX_SIZE = 50 _mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {} +def _evict_expired_mcp_cache() -> None: + """Remove expired entries from the MCP tools cache to prevent unbounded growth.""" + now = time.monotonic() + expired = [ + k + for k, (ts, _) in _mcp_tools_cache.items() + if now - ts >= _MCP_CACHE_TTL_SECONDS + ] + for k in expired: + del _mcp_tools_cache[k] + if expired: + logger.debug("Evicted %d expired MCP cache entries", len(expired)) + + def _create_dynamic_input_model_from_schema( tool_name: str, input_schema: dict[str, Any], @@ -392,6 +407,8 @@ async def load_mcp_tools( List of LangChain StructuredTool instances """ + _evict_expired_mcp_cache() + now = time.monotonic() cached = _mcp_tools_cache.get(search_space_id) if cached is not None: @@ -445,6 +462,11 @@ async def load_mcp_tools( ) _mcp_tools_cache[search_space_id] = (now, tools) + + if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE: + oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0]) + del _mcp_tools_cache[oldest_key] + logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}") return tools diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index e8843878f..47cf86ea3 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -103,22 +103,24 @@ def _check_rate_limit_memory( now = time.monotonic() with _memory_lock: - # Evict timestamps outside the current window - _memory_rate_limits[key] = [ - t for t in _memory_rate_limits[key] if now - t < window_seconds - ] + timestamps = [t for t in _memory_rate_limits[key] if now - t < window_seconds] - if len(_memory_rate_limits[key]) >= max_requests: + if not timestamps: + _memory_rate_limits.pop(key, None) + else: + _memory_rate_limits[key] = timestamps + + if len(timestamps) >= max_requests: rate_limit_logger.warning( f"Rate limit exceeded (in-memory fallback) on {scope} for IP {client_ip} " - f"({len(_memory_rate_limits[key])}/{max_requests} in {window_seconds}s)" + f"({len(timestamps)}/{max_requests} in {window_seconds}s)" ) raise HTTPException( status_code=429, detail="RATE_LIMIT_EXCEEDED", ) - _memory_rate_limits[key].append(now) + _memory_rate_limits[key] = [*timestamps, now] def _check_rate_limit( diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 7839e4014..3bad7be14 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -250,6 +250,48 @@ class LLMRouterService: return len(instance._model_list) +_cached_context_profile: dict | None = None +_cached_context_profile_computed: bool = False + +# Cached singleton instances keyed by (streaming,) to avoid re-creating on every call +_router_instance_cache: dict[bool, "ChatLiteLLMRouter"] = {} + + +def _get_cached_context_profile(router: Router) -> dict | None: + """Compute and cache the min context profile across all router deployments. + + Called once on first ChatLiteLLMRouter creation; subsequent calls return + the cached value. This avoids calling litellm.get_model_info() for every + deployment on every request. + """ + global _cached_context_profile, _cached_context_profile_computed + if _cached_context_profile_computed: + return _cached_context_profile + + from litellm import get_model_info + + min_ctx: int | None = None + for deployment in router.model_list: + params = deployment.get("litellm_params", {}) + base_model = params.get("base_model") or params.get("model", "") + try: + info = get_model_info(base_model) + ctx = info.get("max_input_tokens") + if isinstance(ctx, int) and ctx > 0 and (min_ctx is None or ctx < min_ctx): + min_ctx = ctx + except Exception: + continue + + if min_ctx is not None: + logger.info("ChatLiteLLMRouter profile: max_input_tokens=%d", min_ctx) + _cached_context_profile = {"max_input_tokens": min_ctx} + else: + _cached_context_profile = None + + _cached_context_profile_computed = True + return _cached_context_profile + + class ChatLiteLLMRouter(BaseChatModel): """ A LangChain-compatible chat model that uses LiteLLM Router for load balancing. @@ -260,6 +302,10 @@ class ChatLiteLLMRouter(BaseChatModel): Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context window across all router deployments so that deepagents SummarizationMiddleware can use fraction-based triggers. + + **Singleton-ish**: Use ``get_auto_mode_llm()`` or call ``ChatLiteLLMRouter()`` + directly — instances without bound tools are cached per streaming flag to + avoid per-request re-initialization overhead and memory growth. """ # Use model_config for Pydantic v2 compatibility @@ -281,14 +327,6 @@ class ChatLiteLLMRouter(BaseChatModel): tool_choice: str | dict | None = None, **kwargs, ): - """ - Initialize the ChatLiteLLMRouter. - - Args: - router: LiteLLM Router instance. If None, uses the global singleton. - bound_tools: Pre-bound tools for tool calling - tool_choice: Tool choice configuration - """ try: super().__init__(**kwargs) resolved_router = router or LLMRouterService.get_router() @@ -300,51 +338,20 @@ class ChatLiteLLMRouter(BaseChatModel): "LLM Router not initialized. Call LLMRouterService.initialize() first." ) - # Set profile so deepagents SummarizationMiddleware gets fraction-based triggers - computed_profile = self._compute_min_context_profile() + computed_profile = _get_cached_context_profile(self._router) if computed_profile is not None: object.__setattr__(self, "profile", computed_profile) - logger.info( - f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models" + logger.debug( + "ChatLiteLLMRouter ready (models=%d, streaming=%s, has_tools=%s)", + LLMRouterService.get_model_count(), + self.streaming, + bound_tools is not None, ) except Exception as e: logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}") raise - def _compute_min_context_profile(self) -> dict | None: - """Derive a profile dict with max_input_tokens from router deployments. - - Uses litellm.get_model_info to look up each deployment's context window - and picks the *minimum* so that summarization triggers before ANY model - in the pool overflows. - """ - from litellm import get_model_info - - if not self._router: - return None - - min_ctx: int | None = None - for deployment in self._router.model_list: - params = deployment.get("litellm_params", {}) - base_model = params.get("base_model") or params.get("model", "") - try: - info = get_model_info(base_model) - ctx = info.get("max_input_tokens") - if ( - isinstance(ctx, int) - and ctx > 0 - and (min_ctx is None or ctx < min_ctx) - ): - min_ctx = ctx - except Exception: - continue - - if min_ctx is not None: - logger.info(f"ChatLiteLLMRouter profile: max_input_tokens={min_ctx}") - return {"max_input_tokens": min_ctx} - return None - @property def _llm_type(self) -> str: return "litellm-router" @@ -770,19 +777,28 @@ class ChatLiteLLMRouter(BaseChatModel): return None -def get_auto_mode_llm() -> ChatLiteLLMRouter | None: - """ - Get a ChatLiteLLMRouter instance for auto mode. +def get_auto_mode_llm( + *, + streaming: bool = True, +) -> ChatLiteLLMRouter | None: + """Return a cached ChatLiteLLMRouter for auto mode. - Returns: - ChatLiteLLMRouter instance or None if router not initialized + Base (no tools) instances are cached per ``streaming`` flag so we + avoid re-constructing them on every request. ``bind_tools()`` still + returns a fresh instance because bound tools differ per agent. """ if not LLMRouterService.is_initialized(): logger.warning("LLM Router not initialized for auto mode") return None + cached = _router_instance_cache.get(streaming) + if cached is not None: + return cached + try: - return ChatLiteLLMRouter() + instance = ChatLiteLLMRouter(streaming=streaming) + _router_instance_cache[streaming] = instance + return instance except Exception as e: logger.error(f"Failed to create ChatLiteLLMRouter: {e}") return None diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 4833e62a6..c91df391c 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -12,6 +12,7 @@ from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, LLMRouterService, + get_auto_mode_llm, is_auto_mode, ) @@ -221,7 +222,7 @@ async def get_search_space_llm_instance( logger.debug( f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}" ) - return ChatLiteLLMRouter(disable_streaming=disable_streaming) + return get_auto_mode_llm(streaming=not disable_streaming) except Exception as e: logger.error(f"Failed to create ChatLiteLLMRouter: {e}") return None diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 9e91a8735..3eaf993ff 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -10,6 +10,7 @@ Supports loading LLM configurations from: """ import asyncio +import gc import json import logging import re @@ -1476,6 +1477,16 @@ async def stream_new_chat( _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files) + # Trigger a GC pass so LangGraph agent graphs, tool closures, and + # LLM wrappers with potential circular refs are reclaimed promptly. + collected = gc.collect() + if collected: + _perf_log.info( + "[stream_new_chat] gc.collect() reclaimed %d objects (chat_id=%s)", + collected, + chat_id, + ) + async def stream_resume_chat( chat_id: int, @@ -1662,3 +1673,10 @@ async def stream_resume_chat( ) _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files) + collected = gc.collect() + if collected: + _perf_log.info( + "[stream_resume] gc.collect() reclaimed %d objects (chat_id=%s)", + collected, + chat_id, + ) From 1bb9f479e1dcdaa85b5cce8aef4e750c900a623a Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Sat, 28 Feb 2026 01:54:54 -0800 Subject: [PATCH 06/10] feat: refactor document fetching and improve comment batching - Replaced the useDocuments hook with React Query for better caching and deduplication of document requests. - Updated the ConnectorIndicator component to fetch document type counts using a new atom for real-time updates. - Enhanced the useComments hook to manage batch requests more effectively, reducing race conditions and improving performance. - Set default query options in the query client to optimize stale time and refetch behavior. --- .../assistant-ui/connector-popup.tsx | 9 +- surfsense_web/hooks/use-comments.ts | 31 ++++++- surfsense_web/hooks/use-documents.ts | 88 +++++++++---------- surfsense_web/lib/query-client/client.ts | 9 +- 4 files changed, 80 insertions(+), 57 deletions(-) diff --git a/surfsense_web/components/assistant-ui/connector-popup.tsx b/surfsense_web/components/assistant-ui/connector-popup.tsx index 98964013d..332694676 100644 --- a/surfsense_web/components/assistant-ui/connector-popup.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup.tsx @@ -5,6 +5,7 @@ import { AlertTriangle, Cable, Settings } from "lucide-react"; import Link from "next/link"; import { useSearchParams } from "next/navigation"; import type { FC } from "react"; +import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms"; import { globalNewLLMConfigsAtom, llmPreferencesAtom, @@ -19,7 +20,6 @@ import { Spinner } from "@/components/ui/spinner"; import { Tabs, TabsContent } from "@/components/ui/tabs"; import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { useConnectorsElectric } from "@/hooks/use-connectors-electric"; -import { useDocuments } from "@/hooks/use-documents"; import { useInbox } from "@/hooks/use-inbox"; import { cn } from "@/lib/utils"; import { ConnectorDialogHeader } from "./connector-popup/components/connector-dialog-header"; @@ -62,10 +62,9 @@ export const ConnectorIndicator: FC<{ hideTrigger?: boolean }> = ({ hideTrigger const llmConfigLoading = preferencesLoading || globalConfigsLoading; - // Fetch document type counts using Electric SQL + PGlite for real-time updates - const { typeCounts: documentTypeCounts, loading: documentTypesLoading } = useDocuments( - searchSpaceId ? Number(searchSpaceId) : null - ); + // Fetch document type counts via the lightweight /type-counts endpoint (cached 10 min) + const { data: documentTypeCounts, isFetching: documentTypesLoading } = + useAtomValue(documentTypeCountsAtom); // Fetch notifications to detect indexing failures const { inboxItems = [] } = useInbox( diff --git a/surfsense_web/hooks/use-comments.ts b/surfsense_web/hooks/use-comments.ts index 562f7ae02..c02f9fe16 100644 --- a/surfsense_web/hooks/use-comments.ts +++ b/surfsense_web/hooks/use-comments.ts @@ -11,9 +11,26 @@ interface UseCommentsOptions { // --------------------------------------------------------------------------- // Module-level coordination: when a batch request is in-flight, individual // useComments queryFns piggy-back on it instead of making their own requests. +// +// _batchReady is a promise that resolves once the batch useEffect has had a +// chance to set _batchInflight. Individual queryFns await this gate before +// deciding whether to piggy-back or fetch on their own, eliminating the +// previous race where setTimeout(0) was not enough. // --------------------------------------------------------------------------- let _batchInflight: Promise | null = null; let _batchTargetIds = new Set(); +let _batchReady: Promise | null = null; +let _resolveBatchReady: (() => void) | null = null; + +function resetBatchGate() { + _batchReady = new Promise((r) => { + _resolveBatchReady = r; + }); +} + +// Open the initial gate immediately (no batch pending yet) +resetBatchGate(); +_resolveBatchReady?.(); export function useComments({ messageId, enabled = true }: UseCommentsOptions) { const queryClient = useQueryClient(); @@ -21,9 +38,11 @@ export function useComments({ messageId, enabled = true }: UseCommentsOptions) { return useQuery({ queryKey: cacheKeys.comments.byMessage(messageId), queryFn: async () => { - // Yield one macro-task so the batch prefetch useEffect (which sets - // _batchInflight) has a chance to fire before we decide to fetch. - await new Promise((r) => setTimeout(r, 0)); + // Wait for the batch gate so the useEffect in useBatchCommentsPreload + // has a chance to set _batchInflight before we decide. + if (_batchReady) { + await _batchReady; + } if (_batchInflight && _batchTargetIds.has(messageId)) { await _batchInflight; @@ -57,6 +76,9 @@ export function useBatchCommentsPreload(messageIds: number[]) { if (key === prevKeyRef.current) return; prevKeyRef.current = key; + // Open a new gate so individual queryFns wait for us + resetBatchGate(); + _batchTargetIds = new Set(messageIds); let cancelled = false; @@ -80,6 +102,9 @@ export function useBatchCommentsPreload(messageIds: number[]) { _batchInflight = promise; + // Release the gate — individual queryFns can now check _batchInflight + _resolveBatchReady?.(); + return () => { cancelled = true; if (_batchInflight === promise) { diff --git a/surfsense_web/hooks/use-documents.ts b/surfsense_web/hooks/use-documents.ts index 55d48c4f1..36a359696 100644 --- a/surfsense_web/hooks/use-documents.ts +++ b/surfsense_web/hooks/use-documents.ts @@ -1,5 +1,6 @@ "use client"; +import { useQuery } from "@tanstack/react-query"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import type { DocumentTypeEnum } from "@/contracts/types/document.types"; import { documentsApiService } from "@/lib/apis/documents-api.service"; @@ -183,56 +184,47 @@ export function useDocuments( [] ); - // EFFECT 1: Load ALL documents from API (PRIMARY source of truth) - // No type filter — always fetches everything so typeCounts stay complete + // STEP 1: Load ALL documents from API (PRIMARY source of truth). + // Uses React Query for automatic deduplication, caching, and staleTime so + // multiple components mounting useDocuments(sameId) share a single request. + const { + data: apiResponse, + isLoading: apiLoading, + error: apiError, + } = useQuery({ + queryKey: ["documents", "all", searchSpaceId], + queryFn: () => + documentsApiService.getDocuments({ + queryParams: { + search_space_id: searchSpaceId!, + page: 0, + page_size: -1, + }, + }), + enabled: !!searchSpaceId, + staleTime: 30_000, + }); + + // Seed local state from API response (runs once per fresh fetch) useEffect(() => { - if (!searchSpaceId) { - setLoading(false); - return; + if (!apiResponse) return; + populateUserCache(apiResponse.items); + const docs = apiResponse.items.map(apiToDisplayDoc); + setAllDocuments(docs); + apiLoadedRef.current = true; + setError(null); + }, [apiResponse, populateUserCache, apiToDisplayDoc]); + + // Propagate loading / error from React Query + useEffect(() => { + setLoading(apiLoading); + }, [apiLoading]); + + useEffect(() => { + if (apiError) { + setError(apiError instanceof Error ? apiError : new Error("Failed to load documents")); } - - // Capture validated value for async closure - const spaceId = searchSpaceId; - - let mounted = true; - apiLoadedRef.current = false; - - async function loadFromApi() { - try { - setLoading(true); - console.log("[useDocuments] Loading from API (source of truth):", spaceId); - - const response = await documentsApiService.getDocuments({ - queryParams: { - search_space_id: spaceId, - page: 0, - page_size: -1, // Fetch all documents (unfiltered) - }, - }); - - if (!mounted) return; - - populateUserCache(response.items); - const docs = response.items.map(apiToDisplayDoc); - setAllDocuments(docs); - apiLoadedRef.current = true; - setError(null); - console.log("[useDocuments] API loaded", docs.length, "documents"); - } catch (err) { - if (!mounted) return; - console.error("[useDocuments] API load failed:", err); - setError(err instanceof Error ? err : new Error("Failed to load documents")); - } finally { - if (mounted) setLoading(false); - } - } - - loadFromApi(); - - return () => { - mounted = false; - }; - }, [searchSpaceId, populateUserCache, apiToDisplayDoc]); + }, [apiError]); // EFFECT 2: Start Electric sync + live query for real-time updates // No type filter — syncs and queries ALL documents; filtering is client-side diff --git a/surfsense_web/lib/query-client/client.ts b/surfsense_web/lib/query-client/client.ts index 6c7b9ded3..0dcc2ef03 100644 --- a/surfsense_web/lib/query-client/client.ts +++ b/surfsense_web/lib/query-client/client.ts @@ -1,3 +1,10 @@ import { QueryClient } from "@tanstack/react-query"; -export const queryClient = new QueryClient(); +export const queryClient = new QueryClient({ + defaultOptions: { + queries: { + staleTime: 30_000, + refetchOnWindowFocus: false, + }, + }, +}); From d959a6a6c8d14d78062b0b07eab8166f8a51d101 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Sat, 28 Feb 2026 17:22:34 -0800 Subject: [PATCH 07/10] feat: optimize document upload process and enhance memory management - Increased maximum file upload limit from 10 to 50 to improve user experience. - Implemented batch processing for document uploads to avoid proxy timeouts, splitting files into manageable chunks. - Enhanced garbage collection in chat streaming functions to prevent memory leaks and improve performance. - Added memory delta tracking in system snapshots for better monitoring of resource usage. - Updated LLM router and service configurations to prevent unbounded internal accumulation and improve efficiency. --- surfsense_backend/app/app.py | 18 +-- surfsense_backend/app/db.py | 9 +- .../app/routes/documents_routes.py | 125 ++++++++---------- .../app/services/llm_router_service.py | 4 + surfsense_backend/app/services/llm_service.py | 7 + .../app/tasks/celery_tasks/__init__.py | 27 ++++ .../app/tasks/celery_tasks/connector_tasks.py | 19 +-- .../celery_tasks/document_reindex_tasks.py | 14 +- .../app/tasks/celery_tasks/document_tasks.py | 18 +-- .../app/tasks/celery_tasks/podcast_tasks.py | 17 +-- .../celery_tasks/schedule_checker_task.py | 14 +- .../stale_notification_cleanup_task.py | 15 +-- .../app/tasks/chat/stream_new_chat.py | 19 ++- surfsense_backend/app/utils/perf.py | 35 ++++- .../components/sources/DocumentUploadTab.tsx | 4 +- .../lib/apis/documents-api.service.ts | 61 +++++++-- 16 files changed, 219 insertions(+), 187 deletions(-) diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 47cf86ea3..e6db5670e 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -1,4 +1,5 @@ import asyncio +import gc import logging import time from collections import defaultdict @@ -212,18 +213,16 @@ def _enable_slow_callback_logging(threshold_sec: float = 0.5) -> None: @asynccontextmanager async def lifespan(app: FastAPI): - # Enable slow-callback detection (set PERF_DEBUG=1 env var to activate) + # Tune GC: lower gen-2 threshold so long-lived garbage is collected + # sooner (default 700/10/10 → 700/10/5). This reduces peak RSS + # with minimal CPU overhead. + gc.set_threshold(700, 10, 5) + _enable_slow_callback_logging(threshold_sec=0.5) - # Not needed if you setup a migration system like Alembic await create_db_and_tables() - # Setup LangGraph checkpointer tables for conversation persistence await setup_checkpointer_tables() - # Initialize LLM Router for Auto mode load balancing initialize_llm_router() - # Initialize Image Generation Router for Auto mode load balancing initialize_image_gen_router() - # Seed Surfsense documentation (with timeout so a slow embedding API - # doesn't block startup indefinitely and make the container unresponsive) try: await asyncio.wait_for(seed_surfsense_docs(), timeout=120) except TimeoutError: @@ -231,8 +230,11 @@ async def lifespan(app: FastAPI): "Surfsense docs seeding timed out after 120s — skipping. " "Docs will be indexed on the next restart." ) + + log_system_snapshot("startup_complete") + yield - # Cleanup: close checkpointer connection on shutdown + await close_checkpointer() diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 771689a13..ba926c9ad 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -1856,7 +1856,14 @@ class RefreshToken(Base, TimestampMixin): return not self.is_expired and not self.is_revoked -engine = create_async_engine(DATABASE_URL) +engine = create_async_engine( + DATABASE_URL, + pool_size=30, + max_overflow=150, + pool_recycle=1800, + pool_pre_ping=True, + pool_timeout=30, +) async_session_maker = async_sessionmaker(engine, expire_on_commit=False) diff --git a/surfsense_backend/app/routes/documents_routes.py b/surfsense_backend/app/routes/documents_routes.py index 1ce5082ca..865fdf7b3 100644 --- a/surfsense_backend/app/routes/documents_routes.py +++ b/surfsense_backend/app/routes/documents_routes.py @@ -133,6 +133,8 @@ async def create_documents_file_upload( Requires DOCUMENTS_CREATE permission. """ + import os + import tempfile from datetime import datetime from app.db import DocumentStatus @@ -143,7 +145,6 @@ async def create_documents_file_upload( from app.utils.document_converters import generate_unique_identifier_hash try: - # Check permission await check_permission( session, user, @@ -179,69 +180,64 @@ async def create_documents_file_upload( f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.", ) - created_documents: list[Document] = [] - files_to_process: list[ - tuple[Document, str, str] - ] = [] # (document, temp_path, filename) - skipped_duplicates = 0 - duplicate_document_ids: list[int] = [] - actual_total_size = 0 + # ===== Read all files concurrently to avoid blocking the event loop ===== + async def _read_and_save(file: UploadFile) -> tuple[str, str, int]: + """Read upload content and write to temp file off the event loop.""" + content = await file.read() + file_size = len(content) + filename = file.filename or "unknown" - # ===== PHASE 1: Create pending documents for all files ===== - # This makes ALL documents visible in the UI immediately with pending status - for file in files: - try: - import os - import tempfile - - # Save file to temp location - with tempfile.NamedTemporaryFile( - delete=False, suffix=os.path.splitext(file.filename or "")[1] - ) as temp_file: - temp_path = temp_file.name - - content = await file.read() - file_size = len(content) - - if file_size > MAX_FILE_SIZE_BYTES: - os.unlink(temp_path) - raise HTTPException( - status_code=413, - detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) " - f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.", - ) - - actual_total_size += file_size - if actual_total_size > MAX_TOTAL_SIZE_BYTES: - os.unlink(temp_path) - raise HTTPException( - status_code=413, - detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) " - f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.", - ) - - with open(temp_path, "wb") as f: - f.write(content) - - # Generate unique identifier for deduplication check - unique_identifier_hash = generate_unique_identifier_hash( - DocumentType.FILE, file.filename or "unknown", search_space_id + if file_size > MAX_FILE_SIZE_BYTES: + raise HTTPException( + status_code=413, + detail=f"File '{filename}' ({file_size / (1024 * 1024):.1f} MB) " + f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.", + ) + + def _write_temp() -> str: + with tempfile.NamedTemporaryFile( + delete=False, suffix=os.path.splitext(filename)[1] + ) as tmp: + tmp.write(content) + return tmp.name + + temp_path = await asyncio.to_thread(_write_temp) + return temp_path, filename, file_size + + saved_files = await asyncio.gather(*(_read_and_save(f) for f in files)) + + actual_total_size = sum(size for _, _, size in saved_files) + if actual_total_size > MAX_TOTAL_SIZE_BYTES: + for temp_path, _, _ in saved_files: + os.unlink(temp_path) + raise HTTPException( + status_code=413, + detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) " + f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.", + ) + + # ===== PHASE 1: Create pending documents for all files ===== + created_documents: list[Document] = [] + files_to_process: list[tuple[Document, str, str]] = [] + skipped_duplicates = 0 + duplicate_document_ids: list[int] = [] + + for temp_path, filename, file_size in saved_files: + try: + unique_identifier_hash = generate_unique_identifier_hash( + DocumentType.FILE, filename, search_space_id ) - # Check if document already exists (by unique identifier) existing = await check_document_by_unique_identifier( session, unique_identifier_hash ) if existing: if DocumentStatus.is_state(existing.status, DocumentStatus.READY): - # True duplicate — content already indexed, skip os.unlink(temp_path) skipped_duplicates += 1 duplicate_document_ids.append(existing.id) continue - # Existing document is stuck (failed/pending/processing) - # Reset it to pending and re-dispatch for processing existing.status = DocumentStatus.pending() existing.content = "Processing..." existing.document_metadata = { @@ -251,50 +247,45 @@ async def create_documents_file_upload( } existing.updated_at = get_current_timestamp() created_documents.append(existing) - files_to_process.append( - (existing, temp_path, file.filename or "unknown") - ) + files_to_process.append((existing, temp_path, filename)) continue - # Create pending document (visible immediately in UI via ElectricSQL) document = Document( search_space_id=search_space_id, - title=file.filename or "Uploaded File", + title=filename if filename != "unknown" else "Uploaded File", document_type=DocumentType.FILE, document_metadata={ - "FILE_NAME": file.filename, + "FILE_NAME": filename, "file_size": file_size, "upload_time": datetime.now().isoformat(), }, - content="Processing...", # Placeholder until processed - content_hash=unique_identifier_hash, # Temporary, updated when ready + content="Processing...", + content_hash=unique_identifier_hash, unique_identifier_hash=unique_identifier_hash, embedding=None, - status=DocumentStatus.pending(), # Shows "pending" in UI + status=DocumentStatus.pending(), updated_at=get_current_timestamp(), created_by_id=str(user.id), ) session.add(document) created_documents.append(document) - files_to_process.append( - (document, temp_path, file.filename or "unknown") - ) + files_to_process.append((document, temp_path, filename)) + except HTTPException: + raise except Exception as e: + os.unlink(temp_path) raise HTTPException( status_code=422, - detail=f"Failed to process file {file.filename}: {e!s}", + detail=f"Failed to process file {filename}: {e!s}", ) from e - # Commit all pending documents - they appear in UI immediately via ElectricSQL if created_documents: await session.commit() - # Refresh to get generated IDs for doc in created_documents: await session.refresh(doc) # ===== PHASE 2: Dispatch tasks for each file ===== - # Each task will update document status: pending → processing → ready/failed for document, temp_path, filename in files_to_process: await dispatcher.dispatch_file_processing( document_id=document.id, diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 3bad7be14..2465834f4 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -16,6 +16,7 @@ import re import time from typing import Any +import litellm from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.exceptions import ContextOverflowError from langchain_core.language_models import BaseChatModel @@ -29,6 +30,9 @@ from litellm.exceptions import ( from app.utils.perf import get_perf_logger +litellm.json_logs = False +litellm.store_audit_logs = False + logger = logging.getLogger(__name__) _CONTEXT_OVERFLOW_PATTERNS = re.compile( diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index c91df391c..fc28f477f 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -19,6 +19,13 @@ from app.services.llm_router_service import ( # Configure litellm to automatically drop unsupported parameters litellm.drop_params = True +# Memory controls: prevent unbounded internal accumulation +litellm.telemetry = False +litellm.cache = None +litellm.success_callback = [] +litellm.failure_callback = [] +litellm.input_callback = [] + logger = logging.getLogger(__name__) diff --git a/surfsense_backend/app/tasks/celery_tasks/__init__.py b/surfsense_backend/app/tasks/celery_tasks/__init__.py index 9abc472fe..5b1f2cd13 100644 --- a/surfsense_backend/app/tasks/celery_tasks/__init__.py +++ b/surfsense_backend/app/tasks/celery_tasks/__init__.py @@ -1 +1,28 @@ """Celery tasks package.""" + +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from sqlalchemy.pool import NullPool + +from app.config import config + +_celery_engine = None +_celery_session_maker = None + + +def get_celery_session_maker() -> async_sessionmaker: + """Return a shared async session maker for Celery tasks. + + A single NullPool engine is created per worker process and reused + across all task invocations to avoid leaking engine objects. + """ + global _celery_engine, _celery_session_maker + if _celery_session_maker is None: + _celery_engine = create_async_engine( + config.DATABASE_URL, + poolclass=NullPool, + echo=False, + ) + _celery_session_maker = async_sessionmaker( + _celery_engine, expire_on_commit=False + ) + return _celery_session_maker diff --git a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py index a35528a93..9d52add9c 100644 --- a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py @@ -3,11 +3,8 @@ import logging import traceback -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine -from sqlalchemy.pool import NullPool - from app.celery_app import celery_app -from app.config import config +from app.tasks.celery_tasks import get_celery_session_maker logger = logging.getLogger(__name__) @@ -42,20 +39,6 @@ def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> N ) -def get_celery_session_maker(): - """ - Create a new async session maker for Celery tasks. - This is necessary because Celery tasks run in a new event loop, - and the default session maker is bound to the main app's event loop. - """ - engine = create_async_engine( - config.DATABASE_URL, - poolclass=NullPool, # Don't use connection pooling for Celery tasks - echo=False, - ) - return async_sessionmaker(engine, expire_on_commit=False) - - @celery_app.task(name="index_slack_messages", bind=True) def index_slack_messages_task( self, diff --git a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py index a2a0d635d..a1fca469e 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py @@ -4,15 +4,13 @@ import logging from sqlalchemy import delete, select from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.orm import selectinload -from sqlalchemy.pool import NullPool from app.celery_app import celery_app -from app.config import config from app.db import Document from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService +from app.tasks.celery_tasks import get_celery_session_maker from app.utils.document_converters import ( create_document_chunks, generate_document_summary, @@ -21,16 +19,6 @@ from app.utils.document_converters import ( logger = logging.getLogger(__name__) -def get_celery_session_maker(): - """Create async session maker for Celery tasks.""" - engine = create_async_engine( - config.DATABASE_URL, - poolclass=NullPool, - echo=False, - ) - return async_sessionmaker(engine, expire_on_commit=False) - - @celery_app.task(name="reindex_document", bind=True) def reindex_document_task(self, document_id: int, user_id: str): """ diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py index 60cd21f97..dcb791d3b 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py @@ -5,13 +5,11 @@ import logging import os from uuid import UUID -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine -from sqlalchemy.pool import NullPool - from app.celery_app import celery_app from app.config import config from app.services.notification_service import NotificationService from app.services.task_logging_service import TaskLoggingService +from app.tasks.celery_tasks import get_celery_session_maker from app.tasks.document_processors import ( add_extension_received_document, add_youtube_video_document, @@ -91,20 +89,6 @@ async def _run_heartbeat_loop(notification_id: int): pass # Normal cancellation when task completes -def get_celery_session_maker(): - """ - Create a new async session maker for Celery tasks. - This is necessary because Celery tasks run in a new event loop, - and the default session maker is bound to the main app's event loop. - """ - engine = create_async_engine( - config.DATABASE_URL, - poolclass=NullPool, # Don't use connection pooling for Celery tasks - echo=False, - ) - return async_sessionmaker(engine, expire_on_commit=False) - - @celery_app.task(name="process_extension_document", bind=True) def process_extension_document_task( self, individual_document_dict, search_space_id: int, user_id: str diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py index 973e7e750..42378fe5e 100644 --- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py @@ -5,14 +5,13 @@ import logging import sys from sqlalchemy import select -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine -from sqlalchemy.pool import NullPool from app.agents.podcaster.graph import graph as podcaster_graph from app.agents.podcaster.state import State as PodcasterState from app.celery_app import celery_app from app.config import config from app.db import Podcast, PodcastStatus +from app.tasks.celery_tasks import get_celery_session_maker logger = logging.getLogger(__name__) @@ -25,20 +24,6 @@ if sys.platform.startswith("win"): ) -def get_celery_session_maker(): - """ - Create a new async session maker for Celery tasks. - This is necessary because Celery tasks run in a new event loop, - and the default session maker is bound to the main app's event loop. - """ - engine = create_async_engine( - config.DATABASE_URL, - poolclass=NullPool, # Don't use connection pooling for Celery tasks - echo=False, - ) - return async_sessionmaker(engine, expire_on_commit=False) - - # ============================================================================= # Content-based podcast generation (for new-chat) # ============================================================================= diff --git a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py index 80d271aaa..0ba8bc80a 100644 --- a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py @@ -3,28 +3,16 @@ import logging from datetime import UTC, datetime -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.future import select -from sqlalchemy.pool import NullPool from app.celery_app import celery_app -from app.config import config from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType +from app.tasks.celery_tasks import get_celery_session_maker from app.utils.indexing_locks import is_connector_indexing_locked logger = logging.getLogger(__name__) -def get_celery_session_maker(): - """Create async session maker for Celery tasks.""" - engine = create_async_engine( - config.DATABASE_URL, - poolclass=NullPool, - echo=False, - ) - return async_sessionmaker(engine, expire_on_commit=False) - - @celery_app.task(name="check_periodic_schedules") def check_periodic_schedules_task(): """ diff --git a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py index c2c82dd2c..e05ae9435 100644 --- a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py @@ -29,20 +29,17 @@ from datetime import UTC, datetime import redis from sqlalchemy import and_, or_, text -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.future import select -from sqlalchemy.pool import NullPool from app.celery_app import celery_app from app.config import config from app.db import Document, DocumentStatus, Notification +from app.tasks.celery_tasks import get_celery_session_maker logger = logging.getLogger(__name__) -# Redis client for checking heartbeats _redis_client: redis.Redis | None = None -# Error messages shown to users when tasks are interrupted STALE_SYNC_ERROR_MESSAGE = "Sync was interrupted unexpectedly. Please retry." STALE_PROCESSING_ERROR_MESSAGE = "Syncing was interrupted unexpectedly. Please retry." @@ -60,16 +57,6 @@ def _get_heartbeat_key(notification_id: int) -> str: return f"indexing:heartbeat:{notification_id}" -def get_celery_session_maker(): - """Create async session maker for Celery tasks.""" - engine = create_async_engine( - config.DATABASE_URL, - poolclass=NullPool, - echo=False, - ) - return async_sessionmaker(engine, expire_on_commit=False) - - @celery_app.task(name="cleanup_stale_indexing_notifications") def cleanup_stale_indexing_notifications_task(): """ diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 3eaf993ff..ae7001a40 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1477,15 +1477,21 @@ async def stream_new_chat( _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files) - # Trigger a GC pass so LangGraph agent graphs, tool closures, and - # LLM wrappers with potential circular refs are reclaimed promptly. - collected = gc.collect() + # Break circular refs held by the agent graph, tools, and LLM + # wrappers so the GC can reclaim them in a single pass. + agent = llm = connector_service = sandbox_backend = None + mentioned_documents = mentioned_surfsense_docs = None + recent_reports = langchain_messages = input_state = None + stream_result = None + + collected = gc.collect(0) + gc.collect(1) + gc.collect(2) if collected: _perf_log.info( "[stream_new_chat] gc.collect() reclaimed %d objects (chat_id=%s)", collected, chat_id, ) + log_system_snapshot("stream_new_chat_END") async def stream_resume_chat( @@ -1673,10 +1679,15 @@ async def stream_resume_chat( ) _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files) - collected = gc.collect() + + agent = llm = connector_service = sandbox_backend = None + stream_result = None + + collected = gc.collect(0) + gc.collect(1) + gc.collect(2) if collected: _perf_log.info( "[stream_resume] gc.collect() reclaimed %d objects (chat_id=%s)", collected, chat_id, ) + log_system_snapshot("stream_resume_chat_END") diff --git a/surfsense_backend/app/utils/perf.py b/surfsense_backend/app/utils/perf.py index 301498048..d6602d666 100644 --- a/surfsense_backend/app/utils/perf.py +++ b/surfsense_backend/app/utils/perf.py @@ -9,6 +9,7 @@ Provides: - RequestPerfMiddleware for per-request timing """ +import gc import logging import os import time @@ -16,6 +17,7 @@ from contextlib import asynccontextmanager, contextmanager from typing import Any _perf_log: logging.Logger | None = None +_last_rss_mb: float = 0.0 def get_perf_logger() -> logging.Logger: @@ -73,20 +75,29 @@ def system_snapshot() -> dict[str, Any]: Returns a dict with: - rss_mb: Resident Set Size in MB + - rss_delta_mb: Change in RSS since the last snapshot - cpu_percent: CPU usage % since last call (per-process) - threads: number of active threads - open_fds: number of open file descriptors (Linux only) - asyncio_tasks: number of asyncio tasks currently alive + - gc_counts: tuple of object counts per gc generation """ import asyncio + global _last_rss_mb + snapshot: dict[str, Any] = {} try: import psutil proc = psutil.Process(os.getpid()) mem = proc.memory_info() - snapshot["rss_mb"] = round(mem.rss / 1024 / 1024, 1) + rss_mb = round(mem.rss / 1024 / 1024, 1) + snapshot["rss_mb"] = rss_mb + snapshot["rss_delta_mb"] = ( + round(rss_mb - _last_rss_mb, 1) if _last_rss_mb else 0.0 + ) + _last_rss_mb = rss_mb snapshot["cpu_percent"] = proc.cpu_percent(interval=None) snapshot["threads"] = proc.num_threads() try: @@ -95,6 +106,7 @@ def system_snapshot() -> dict[str, Any]: snapshot["open_fds"] = -1 except ImportError: snapshot["rss_mb"] = -1 + snapshot["rss_delta_mb"] = 0.0 snapshot["cpu_percent"] = -1 snapshot["threads"] = -1 snapshot["open_fds"] = -1 @@ -105,18 +117,35 @@ def system_snapshot() -> dict[str, Any]: except RuntimeError: snapshot["asyncio_tasks"] = -1 + snapshot["gc_counts"] = gc.get_count() + return snapshot def log_system_snapshot(label: str = "system_snapshot") -> None: - """Capture and log a system snapshot.""" + """Capture and log a system snapshot with memory delta tracking.""" snap = system_snapshot() + delta_str = "" + if snap["rss_delta_mb"]: + sign = "+" if snap["rss_delta_mb"] > 0 else "" + delta_str = f" delta={sign}{snap['rss_delta_mb']}MB" get_perf_logger().info( - "[%s] rss=%.1fMB cpu=%.1f%% threads=%d fds=%d asyncio_tasks=%d", + "[%s] rss=%.1fMB%s cpu=%.1f%% threads=%d fds=%d asyncio_tasks=%d gc=%s", label, snap["rss_mb"], + delta_str, snap["cpu_percent"], snap["threads"], snap["open_fds"], snap["asyncio_tasks"], + snap["gc_counts"], ) + + if snap["rss_mb"] > 0 and snap["rss_delta_mb"] > 500: + get_perf_logger().warning( + "[MEMORY_SPIKE] %s: RSS jumped by %.1fMB (now %.1fMB). " + "Possible leak — check recent operations.", + label, + snap["rss_delta_mb"], + snap["rss_mb"], + ) diff --git a/surfsense_web/components/sources/DocumentUploadTab.tsx b/surfsense_web/components/sources/DocumentUploadTab.tsx index caea98890..cae78f7b7 100644 --- a/surfsense_web/components/sources/DocumentUploadTab.tsx +++ b/surfsense_web/components/sources/DocumentUploadTab.tsx @@ -111,8 +111,8 @@ const FILE_TYPE_CONFIG: Record> = { const cardClass = "border border-border bg-slate-400/5 dark:bg-white/5"; -// Upload limits -const MAX_FILES = 10; +// Upload limits — files are sent in batches of 5 to avoid proxy timeouts +const MAX_FILES = 50; const MAX_TOTAL_SIZE_MB = 200; const MAX_TOTAL_SIZE_BYTES = MAX_TOTAL_SIZE_MB * 1024 * 1024; diff --git a/surfsense_web/lib/apis/documents-api.service.ts b/surfsense_web/lib/apis/documents-api.service.ts index e3ee2bd5b..9b0d847f4 100644 --- a/surfsense_web/lib/apis/documents-api.service.ts +++ b/surfsense_web/lib/apis/documents-api.service.ts @@ -109,7 +109,9 @@ class DocumentsApiService { }; /** - * Upload document files + * Upload document files in batches to avoid proxy/LB timeouts. + * Files are split into chunks of UPLOAD_BATCH_SIZE and sent as separate + * requests. Results are aggregated into a single response. */ uploadDocument = async (request: UploadDocumentRequest) => { const parsedRequest = uploadDocumentRequest.safeParse(request); @@ -121,17 +123,54 @@ class DocumentsApiService { throw new ValidationError(`Invalid request: ${errorMessage}`); } - // Create FormData for file upload - const formData = new FormData(); - parsedRequest.data.files.forEach((file) => { - formData.append("files", file); - }); - formData.append("search_space_id", String(parsedRequest.data.search_space_id)); - formData.append("should_summarize", String(parsedRequest.data.should_summarize)); + const { files, search_space_id, should_summarize } = parsedRequest.data; + const UPLOAD_BATCH_SIZE = 5; - return baseApiService.postFormData(`/api/v1/documents/fileupload`, uploadDocumentResponse, { - body: formData, - }); + const batches: File[][] = []; + for (let i = 0; i < files.length; i += UPLOAD_BATCH_SIZE) { + batches.push(files.slice(i, i + UPLOAD_BATCH_SIZE)); + } + + const allDocumentIds: number[] = []; + const allDuplicateIds: number[] = []; + let totalFiles = 0; + let pendingFiles = 0; + let skippedDuplicates = 0; + + for (const batch of batches) { + const formData = new FormData(); + batch.forEach((file) => formData.append("files", file)); + formData.append("search_space_id", String(search_space_id)); + formData.append("should_summarize", String(should_summarize)); + + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), 120_000); + + try { + const result = await baseApiService.postFormData( + `/api/v1/documents/fileupload`, + uploadDocumentResponse, + { body: formData, signal: controller.signal } + ); + + allDocumentIds.push(...(result.document_ids ?? [])); + allDuplicateIds.push(...(result.duplicate_document_ids ?? [])); + totalFiles += result.total_files ?? batch.length; + pendingFiles += result.pending_files ?? 0; + skippedDuplicates += result.skipped_duplicates ?? 0; + } finally { + clearTimeout(timeoutId); + } + } + + return { + message: "Files uploaded for processing" as const, + document_ids: allDocumentIds, + duplicate_document_ids: allDuplicateIds, + total_files: totalFiles, + pending_files: pendingFiles, + skipped_duplicates: skippedDuplicates, + }; }; /** From 40a091f8cc2dafac4c2e5ceea4ff2f321251d9d0 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Sat, 28 Feb 2026 19:40:24 -0800 Subject: [PATCH 08/10] feat: enhance knowledge base search and document retrieval - Introduced a mechanism to identify degenerate queries that lack meaningful search signals, improving search accuracy. - Implemented a fallback method for browsing recent documents when queries are degenerate, ensuring relevant results are returned. - Added limits on the number of chunks fetched per document to optimize performance and prevent excessive data loading. - Updated the ConnectorService to allow for reusable query embeddings, enhancing efficiency in search operations. - Enhanced LLM router service to support context window fallbacks, improving robustness during context window limitations. --- .../agents/new_chat/tools/knowledge_base.py | 423 ++++++++++++++---- .../app/agents/new_chat/tools/registry.py | 6 +- .../app/agents/new_chat/tools/report.py | 2 + .../app/retriever/chunks_hybrid_search.py | 22 +- .../app/retriever/documents_hybrid_search.py | 16 +- .../app/services/connector_service.py | 18 +- .../app/services/llm_router_service.py | 89 +++- 7 files changed, 476 insertions(+), 100 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py index 9394d68b4..16cad80e5 100644 --- a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py +++ b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py @@ -10,6 +10,7 @@ This module provides: import asyncio import json +import re import time from datetime import datetime from typing import Any @@ -22,6 +23,149 @@ from app.db import async_session_maker from app.services.connector_service import ConnectorService from app.utils.perf import get_perf_logger +# Connectors that call external live-search APIs (no local DB / embedding needed). +# These are never filtered by available_document_types. +_LIVE_SEARCH_CONNECTORS: set[str] = { + "TAVILY_API", + "SEARXNG_API", + "LINKUP_API", + "BAIDU_SEARCH_API", +} + +# Patterns that indicate the query has no meaningful search signal. +# plainto_tsquery('english', '*') produces an empty tsquery and an embedding +# of '*' is random noise, so both keyword and semantic search degrade to +# arbitrary ordering — large documents (many chunks) dominate by chance. +_DEGENERATE_QUERY_RE = re.compile( + r"^[\s*?_.#@!\-/\\]+$" # only wildcards, punctuation, whitespace +) + +# Max chunks per document when doing a recency-based browse instead of +# a real search. We want breadth (many docs) over depth (many chunks). +_BROWSE_MAX_CHUNKS_PER_DOC = 5 + + +def _is_degenerate_query(query: str) -> bool: + """Return True when the query carries no meaningful search signal. + + Catches wildcard patterns (``*``, ``**``), empty / whitespace-only + strings, and single-character non-word tokens. These queries cause + both keyword search (empty tsquery) and semantic search (meaningless + embedding) to return effectively random results. + """ + stripped = query.strip() + if not stripped: + return True + return bool(_DEGENERATE_QUERY_RE.match(stripped)) + + +async def _browse_recent_documents( + search_space_id: int, + document_type: str | None, + top_k: int, + start_date: datetime | None, + end_date: datetime | None, +) -> list[dict[str, Any]]: + """Return the most-recent documents (recency-ordered, no search ranking). + + Used as a fallback when the search query is degenerate (e.g. ``*``) and + semantic / keyword search would produce arbitrary results. Returns + document-grouped dicts in the same shape as ``_combined_rrf_search`` + so the rest of the pipeline works unchanged. + """ + from sqlalchemy import select + from sqlalchemy.orm import joinedload + + from app.db import Chunk, Document, DocumentType + + perf = get_perf_logger() + t0 = time.perf_counter() + + base_conditions = [Document.search_space_id == search_space_id] + + if document_type is not None: + if isinstance(document_type, str): + try: + doc_type_enum = DocumentType[document_type] + base_conditions.append(Document.document_type == doc_type_enum) + except KeyError: + return [] + else: + base_conditions.append(Document.document_type == document_type) + + if start_date is not None: + base_conditions.append(Document.updated_at >= start_date) + if end_date is not None: + base_conditions.append(Document.updated_at <= end_date) + + async with async_session_maker() as session: + doc_query = ( + select(Document) + .options(joinedload(Document.search_space)) + .where(*base_conditions) + .order_by(Document.updated_at.desc()) + .limit(top_k) + ) + result = await session.execute(doc_query) + documents = result.scalars().unique().all() + + if not documents: + return [] + + doc_ids = [d.id for d in documents] + + chunk_query = ( + select(Chunk) + .where(Chunk.document_id.in_(doc_ids)) + .order_by(Chunk.document_id, Chunk.id) + ) + chunk_result = await session.execute(chunk_query) + raw_chunks = chunk_result.scalars().all() + + doc_chunk_counts: dict[int, int] = {} + doc_chunks: dict[int, list[dict]] = {d.id: [] for d in documents} + for chunk in raw_chunks: + did = chunk.document_id + count = doc_chunk_counts.get(did, 0) + if count < _BROWSE_MAX_CHUNKS_PER_DOC: + doc_chunks[did].append({"chunk_id": chunk.id, "content": chunk.content}) + doc_chunk_counts[did] = count + 1 + + results: list[dict[str, Any]] = [] + for doc in documents: + chunks_list = doc_chunks.get(doc.id, []) + results.append( + { + "document_id": doc.id, + "content": "\n\n".join( + c["content"] for c in chunks_list if c.get("content") + ), + "score": 0.0, + "chunks": chunks_list, + "document": { + "id": doc.id, + "title": doc.title, + "document_type": doc.document_type.value + if getattr(doc, "document_type", None) + else None, + "metadata": doc.document_metadata or {}, + }, + "source": doc.document_type.value + if getattr(doc, "document_type", None) + else None, + } + ) + + perf.info( + "[kb_browse] recency browse in %.3fs docs=%d space=%d type=%s", + time.perf_counter() - t0, + len(results), + search_space_id, + document_type, + ) + return results + + # ============================================================================= # Connector Constants and Normalization # ============================================================================= @@ -184,9 +328,23 @@ _CHARS_PER_TOKEN = 4 # Hard-floor / ceiling so the budget is always sensible regardless of what # the model reports. _MIN_TOOL_OUTPUT_CHARS = 20_000 # ~5K tokens -_MAX_TOOL_OUTPUT_CHARS = 400_000 # ~100K tokens +_MAX_TOOL_OUTPUT_CHARS = 200_000 # ~50K tokens _MAX_CHUNK_CHARS = 8_000 +# Rank-adaptive per-document budget allocation. +# Top-ranked (most relevant) documents get a larger share of the budget so +# we pack as much high-quality context as possible. +# +# fraction(rank) = _TOP_DOC_BUDGET_FRACTION / (1 + rank * _RANK_DECAY) +# +# Examples (128K budget, 8K chunk cap): +# rank 0 → 40% → 6 chunks | rank 3 → 19% → 3 chunks +# rank 1 → 30% → 4 chunks | rank 10 → 10% → 3 chunks (floor) +# rank 2 → 24% → 3 chunks | +_TOP_DOC_BUDGET_FRACTION = 0.40 +_RANK_DECAY = 0.35 +_MIN_CHUNKS_PER_DOC = 3 + def _compute_tool_output_budget(max_input_tokens: int | None) -> int: """Derive a character budget from the model's context window. @@ -208,18 +366,24 @@ def format_documents_for_context( *, max_chars: int = _MAX_TOOL_OUTPUT_CHARS, max_chunk_chars: int = _MAX_CHUNK_CHARS, + max_chunks_per_doc: int = 0, ) -> str: """ Format retrieved documents into a readable context string for the LLM. Documents are added in order (highest relevance first) until the character - budget is reached. Individual chunks are capped at ``max_chunk_chars`` so - a single oversized chunk cannot monopolize the output. + budget is reached. Individual chunks are capped at ``max_chunk_chars`` and + each document is limited to a dynamically computed chunk cap so a single + large document cannot monopolize the output while still maximising the use + of available context space. Args: documents: List of document dictionaries from connector search max_chars: Approximate character budget for the entire output. max_chunk_chars: Per-chunk character cap (content is tail-truncated). + max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means + auto-compute per document using a rank-adaptive formula so + higher-ranked documents receive more chunks. Returns: Formatted string with document contents and metadata @@ -342,7 +506,23 @@ def format_documents_for_context( "", ] - for ch in g["chunks"]: + # Rank-adaptive per-document chunk cap: top results get more chunks. + if max_chunks_per_doc > 0: + chunks_allowed = max_chunks_per_doc + else: + doc_fraction = _TOP_DOC_BUDGET_FRACTION / (1 + doc_idx * _RANK_DECAY) + max_doc_chars = int(max_chars * doc_fraction) + xml_overhead = 500 + chunks_allowed = max( + (max_doc_chars - xml_overhead) // max(max_chunk_chars, 1), + _MIN_CHUNKS_PER_DOC, + ) + + chunks = g["chunks"] + if len(chunks) > chunks_allowed: + chunks = chunks[:chunks_allowed] + + for ch in chunks: ch_content = ch["content"] if max_chunk_chars and len(ch_content) > max_chunk_chars: ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)" @@ -359,9 +539,11 @@ def format_documents_for_context( doc_xml = "\n".join(doc_lines) doc_len = len(doc_xml) - # Always include at least the first document; afterwards enforce budget. - if doc_idx > 0 and total_chars + doc_len > max_chars: + if total_chars + doc_len > max_chars: remaining = total_docs - doc_idx + if doc_idx == 0: + parts.append(doc_xml) + total_chars += doc_len parts.append( f"" + result = result[: max_chars - len(truncation_msg)] + truncation_msg + + return result # ============================================================================= @@ -390,6 +580,7 @@ async def search_knowledge_base_async( start_date: datetime | None = None, end_date: datetime | None = None, available_connectors: list[str] | None = None, + available_document_types: list[str] | None = None, max_input_tokens: int | None = None, ) -> str: """ @@ -408,6 +599,9 @@ async def search_knowledge_base_async( end_date: Optional end datetime (UTC) for filtering documents available_connectors: Optional list of connectors actually available in the search space. If provided, only these connectors will be searched. + available_document_types: Optional list of document types that actually have indexed + data. When provided, local connectors whose document type is + absent are skipped entirely (no embedding / DB round-trip). max_input_tokens: Model context window size (tokens). Used to dynamically size the output so it fits within the model's limits. @@ -428,6 +622,23 @@ async def search_knowledge_base_async( ) connectors = _normalize_connectors(connectors_to_search, available_connectors) + + # --- Optimization 1: skip local connectors that have zero indexed documents --- + if available_document_types: + doc_types_set = set(available_document_types) + before_count = len(connectors) + connectors = [ + c for c in connectors if c in _LIVE_SEARCH_CONNECTORS or c in doc_types_set + ] + skipped = before_count - len(connectors) + if skipped: + perf.info( + "[kb_search] skipped %d empty connectors (had %d, now %d)", + skipped, + before_count, + len(connectors), + ) + perf.info( "[kb_search] searching %d connectors: %s (space=%d, top_k=%d)", len(connectors), @@ -436,81 +647,126 @@ async def search_knowledge_base_async( top_k, ) - connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = { - "YOUTUBE_VIDEO": ("search_youtube", True, True, {}), - "EXTENSION": ("search_extension", True, True, {}), - "CRAWLED_URL": ("search_crawled_urls", True, True, {}), - "FILE": ("search_files", True, True, {}), - "SLACK_CONNECTOR": ("search_slack", True, True, {}), - "TEAMS_CONNECTOR": ("search_teams", True, True, {}), - "NOTION_CONNECTOR": ("search_notion", True, True, {}), - "GITHUB_CONNECTOR": ("search_github", True, True, {}), - "LINEAR_CONNECTOR": ("search_linear", True, True, {}), + # --- Fast-path: degenerate queries (*, **, empty, etc.) --- + # Semantic embedding of '*' is noise and plainto_tsquery('english', '*') + # yields an empty tsquery, so both retrieval signals are useless. + # Fall back to a recency-ordered browse that returns diverse results. + if _is_degenerate_query(query): + perf.info( + "[kb_search] degenerate query %r detected - falling back to recency browse", + query, + ) + local_connectors = [c for c in connectors if c not in _LIVE_SEARCH_CONNECTORS] + if not local_connectors: + local_connectors = [None] # type: ignore[list-item] + + browse_results = await asyncio.gather( + *[ + _browse_recent_documents( + search_space_id=search_space_id, + document_type=c, + top_k=top_k, + start_date=resolved_start_date, + end_date=resolved_end_date, + ) + for c in local_connectors + ] + ) + for docs in browse_results: + all_documents.extend(docs) + + # Skip dedup + formatting below (browse already returns unique docs) + # but still cap output budget. + output_budget = _compute_tool_output_budget(max_input_tokens) + result = format_documents_for_context( + all_documents, + max_chars=output_budget, + max_chunks_per_doc=_BROWSE_MAX_CHUNKS_PER_DOC, + ) + perf.info( + "[kb_search] TOTAL (browse) in %.3fs total_docs=%d output_chars=%d " + "budget=%d space=%d", + time.perf_counter() - t0, + len(all_documents), + len(result), + output_budget, + search_space_id, + ) + return result + + # Specs for live-search connectors (external APIs, no local DB/embedding). + live_connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = { "TAVILY_API": ("search_tavily", False, True, {}), "SEARXNG_API": ("search_searxng", False, True, {}), "LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}), "BAIDU_SEARCH_API": ("search_baidu", False, True, {}), - "DISCORD_CONNECTOR": ("search_discord", True, True, {}), - "JIRA_CONNECTOR": ("search_jira", True, True, {}), - "GOOGLE_CALENDAR_CONNECTOR": ("search_google_calendar", True, True, {}), - "AIRTABLE_CONNECTOR": ("search_airtable", True, True, {}), - "GOOGLE_GMAIL_CONNECTOR": ("search_google_gmail", True, True, {}), - "GOOGLE_DRIVE_FILE": ("search_google_drive", True, True, {}), - "CONFLUENCE_CONNECTOR": ("search_confluence", True, True, {}), - "CLICKUP_CONNECTOR": ("search_clickup", True, True, {}), - "LUMA_CONNECTOR": ("search_luma", True, True, {}), - "ELASTICSEARCH_CONNECTOR": ("search_elasticsearch", True, True, {}), - "NOTE": ("search_notes", True, True, {}), - "BOOKSTACK_CONNECTOR": ("search_bookstack", True, True, {}), - "CIRCLEBACK": ("search_circleback", True, True, {}), - "OBSIDIAN_CONNECTOR": ("search_obsidian", True, True, {}), - # Composio connectors - "COMPOSIO_GOOGLE_DRIVE_CONNECTOR": ( - "search_composio_google_drive", - True, - True, - {}, - ), - "COMPOSIO_GMAIL_CONNECTOR": ("search_composio_gmail", True, True, {}), - "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": ( - "search_composio_google_calendar", - True, - True, - {}, - ), } - # Keep a conservative cap to avoid overloading DB/external services. + # --- Optimization 2: compute the query embedding once, share across all local searches --- + precomputed_embedding: list[float] | None = None + has_local_connectors = any(c not in _LIVE_SEARCH_CONNECTORS for c in connectors) + if has_local_connectors: + from app.config import config as app_config + + t_embed = time.perf_counter() + precomputed_embedding = app_config.embedding_model_instance.embed(query) + perf.info( + "[kb_search] shared embedding computed in %.3fs", + time.perf_counter() - t_embed, + ) + max_parallel_searches = 4 semaphore = asyncio.Semaphore(max_parallel_searches) async def _search_one_connector(connector: str) -> list[dict[str, Any]]: - spec = connector_specs.get(connector) - if spec is None: - return [] + is_live = connector in _LIVE_SEARCH_CONNECTORS - method_name, includes_date_range, includes_top_k, extra_kwargs = spec - kwargs: dict[str, Any] = { - "user_query": query, - "search_space_id": search_space_id, - **extra_kwargs, - } - if includes_top_k: - kwargs["top_k"] = top_k - if includes_date_range: - kwargs["start_date"] = resolved_start_date - kwargs["end_date"] = resolved_end_date + if is_live: + spec = live_connector_specs.get(connector) + if spec is None: + return [] + method_name, includes_date_range, includes_top_k, extra_kwargs = spec + kwargs: dict[str, Any] = { + "user_query": query, + "search_space_id": search_space_id, + **extra_kwargs, + } + if includes_top_k: + kwargs["top_k"] = top_k + if includes_date_range: + kwargs["start_date"] = resolved_start_date + kwargs["end_date"] = resolved_end_date + try: + t_conn = time.perf_counter() + async with semaphore, async_session_maker() as isolated_session: + svc = ConnectorService(isolated_session, search_space_id) + _, chunks = await getattr(svc, method_name)(**kwargs) + perf.info( + "[kb_search] connector=%s results=%d in %.3fs", + connector, + len(chunks), + time.perf_counter() - t_conn, + ) + return chunks + except Exception as e: + perf.warning("[kb_search] connector=%s FAILED: %s", connector, e) + return [] + + # --- Optimization 3: call _combined_rrf_search directly with shared embedding --- try: - # Use isolated session per connector. Shared AsyncSession cannot safely - # run concurrent DB operations. t_conn = time.perf_counter() async with semaphore, async_session_maker() as isolated_session: - isolated_connector_service = ConnectorService( - isolated_session, search_space_id + svc = ConnectorService(isolated_session, search_space_id) + chunks = await svc._combined_rrf_search( + query_text=query, + search_space_id=search_space_id, + document_type=connector, + top_k=top_k, + start_date=resolved_start_date, + end_date=resolved_end_date, + query_embedding=precomputed_embedding, ) - connector_method = getattr(isolated_connector_service, method_name) - _, chunks = await connector_method(**kwargs) perf.info( "[kb_search] connector=%s results=%d in %.3fs", connector, @@ -519,12 +775,7 @@ async def search_knowledge_base_async( ) return chunks except Exception as e: - perf.warning( - "[kb_search] connector=%s FAILED in %.3fs: %s", - connector, - time.perf_counter() - t_conn, - e, - ) + perf.warning("[kb_search] connector=%s FAILED: %s", connector, e) return [] t_gather = time.perf_counter() @@ -582,12 +833,24 @@ async def search_knowledge_base_async( output_budget = _compute_tool_output_budget(max_input_tokens) result = format_documents_for_context(deduplicated, max_chars=output_budget) + + if len(result) > output_budget: + perf.warning( + "[kb_search] output STILL exceeds budget after format (%d > %d), " + "hard truncation should have fired", + len(result), + output_budget, + ) + perf.info( - "[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d space=%d", + "[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d " + "budget=%d max_input_tokens=%s space=%d", time.perf_counter() - t0, len(all_documents), len(deduplicated), len(result), + output_budget, + max_input_tokens, search_space_id, ) return result @@ -628,11 +891,15 @@ class SearchKnowledgeBaseInput(BaseModel): """Input schema for the search_knowledge_base tool.""" query: str = Field( - description="The search query - be specific and include key terms" + description=( + "The search query - use specific natural language terms. " + "NEVER use wildcards like '*' or '**'; instead describe what you want " + "(e.g. 'recent meeting notes' or 'project architecture overview')." + ), ) top_k: int = Field( default=10, - description="Number of results to retrieve (default: 10)", + description="Number of results to retrieve (default: 10). Keep ≤20 for focused searches.", ) start_date: str | None = Field( default=None, @@ -695,6 +962,10 @@ Focus searches on these types for best results.""" Use this tool to find documents, notes, files, web pages, and other content that may help answer the user's question. IMPORTANT: +- Always craft specific, descriptive search queries using natural language keywords. + Good: "quarterly sales report Q3", "Python API authentication design". + Bad: "*", "**", "everything", single characters. Wildcard/empty queries yield poor results. +- Prefer multiple focused searches over a single broad one with high top_k. - If the user requests a specific source type (e.g. "my notes", "Slack messages"), pass `connectors_to_search=[...]` using the enums below. - If `connectors_to_search` is omitted/empty, the system will search broadly. - Only connectors that are enabled/configured for this search space are available.{doc_types_info} @@ -710,6 +981,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type # Capture for closure _available_connectors = available_connectors + _available_document_types = available_document_types async def _search_knowledge_base_impl( query: str, @@ -739,6 +1011,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type start_date=parsed_start, end_date=parsed_end, available_connectors=_available_connectors, + available_document_types=_available_document_types, max_input_tokens=max_input_tokens, ) diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index f36f0de13..99cb09b38 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -145,10 +145,12 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ thread_id=deps["thread_id"], connector_service=deps.get("connector_service"), available_connectors=deps.get("available_connectors"), + available_document_types=deps.get("available_document_types"), ), requires=["search_space_id", "thread_id"], - # connector_service and available_connectors are optional — - # when missing, source_strategy="kb_search" degrades gracefully to "provided" + # connector_service, available_connectors, and available_document_types + # are optional — when missing, source_strategy="kb_search" degrades + # gracefully to "provided" ), # Link preview tool - fetches Open Graph metadata for URLs ToolDefinition( diff --git a/surfsense_backend/app/agents/new_chat/tools/report.py b/surfsense_backend/app/agents/new_chat/tools/report.py index 0896fea4b..5212c2c3b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/report.py +++ b/surfsense_backend/app/agents/new_chat/tools/report.py @@ -559,6 +559,7 @@ def create_generate_report_tool( thread_id: int | None = None, connector_service: ConnectorService | None = None, available_connectors: list[str] | None = None, + available_document_types: list[str] | None = None, ): """ Factory function to create the generate_report tool with injected dependencies. @@ -838,6 +839,7 @@ def create_generate_report_tool( connector_service=kb_connector_svc, top_k=10, available_connectors=available_connectors, + available_document_types=available_document_types, ) kb_results = await asyncio.gather( diff --git a/surfsense_backend/app/retriever/chunks_hybrid_search.py b/surfsense_backend/app/retriever/chunks_hybrid_search.py index 38ecba96c..4787e8147 100644 --- a/surfsense_backend/app/retriever/chunks_hybrid_search.py +++ b/surfsense_backend/app/retriever/chunks_hybrid_search.py @@ -3,6 +3,8 @@ from datetime import datetime from app.utils.perf import get_perf_logger +_MAX_FETCH_CHUNKS_PER_DOC = 30 + class ChucksHybridSearchRetriever: def __init__(self, db_session): @@ -346,8 +348,9 @@ class ChucksHybridSearchRetriever: if not doc_ids: return [] - # Fetch ALL chunks for selected documents in a single query so the final prompt can cite - # any chunk from those documents. + # Fetch chunks for selected documents. We cap per document to avoid + # loading hundreds of chunks for a single large file while still + # ensuring the chunks that matched the RRF query are always included. chunk_query = ( select(Chunk) .options(joinedload(Chunk.document)) @@ -357,7 +360,20 @@ class ChucksHybridSearchRetriever: .order_by(Chunk.document_id, Chunk.id) ) chunks_result = await self.db_session.execute(chunk_query) - all_chunks = chunks_result.scalars().all() + raw_chunks = chunks_result.scalars().all() + + matched_chunk_ids: set[int] = { + item["chunk_id"] for item in serialized_chunk_results + } + + doc_chunk_counts: dict[int, int] = {} + all_chunks: list = [] + for chunk in raw_chunks: + did = chunk.document_id + count = doc_chunk_counts.get(did, 0) + if chunk.id in matched_chunk_ids or count < _MAX_FETCH_CHUNKS_PER_DOC: + all_chunks.append(chunk) + doc_chunk_counts[did] = count + 1 # Assemble final doc-grouped results in the same order as doc_ids doc_map: dict[int, dict] = { diff --git a/surfsense_backend/app/retriever/documents_hybrid_search.py b/surfsense_backend/app/retriever/documents_hybrid_search.py index f4daf8e26..69e97384f 100644 --- a/surfsense_backend/app/retriever/documents_hybrid_search.py +++ b/surfsense_backend/app/retriever/documents_hybrid_search.py @@ -3,6 +3,8 @@ from datetime import datetime from app.utils.perf import get_perf_logger +_MAX_FETCH_CHUNKS_PER_DOC = 30 + class DocumentHybridSearchRetriever: def __init__(self, db_session): @@ -279,7 +281,8 @@ class DocumentHybridSearchRetriever: # Collect document IDs for chunk fetching doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores] - # Fetch ALL chunks for these documents in a single query + # Fetch chunks for these documents, capped per document to avoid + # loading hundreds of chunks for a single large file. chunks_query = ( select(Chunk) .options(joinedload(Chunk.document)) @@ -287,7 +290,16 @@ class DocumentHybridSearchRetriever: .order_by(Chunk.document_id, Chunk.id) ) chunks_result = await self.db_session.execute(chunks_query) - chunks = chunks_result.scalars().all() + raw_chunks = chunks_result.scalars().all() + + doc_chunk_counts: dict[int, int] = {} + chunks: list = [] + for chunk in raw_chunks: + did = chunk.document_id + count = doc_chunk_counts.get(did, 0) + if count < _MAX_FETCH_CHUNKS_PER_DOC: + chunks.append(chunk) + doc_chunk_counts[did] = count + 1 # Assemble doc-grouped results doc_map: dict[int, dict] = { diff --git a/surfsense_backend/app/services/connector_service.py b/surfsense_backend/app/services/connector_service.py index 157e0bab5..0aa48eccd 100644 --- a/surfsense_backend/app/services/connector_service.py +++ b/surfsense_backend/app/services/connector_service.py @@ -224,6 +224,7 @@ class ConnectorService: top_k: int = 20, start_date: datetime | None = None, end_date: datetime | None = None, + query_embedding: list[float] | None = None, ) -> list[dict[str, Any]]: """ Perform combined search using both chunk-based and document-based hybrid search, @@ -260,14 +261,15 @@ class ConnectorService: # Get more results from each retriever for better fusion retriever_top_k = top_k * 2 - # Pre-compute the embedding once so both retrievers reuse it. - t_embed = time.perf_counter() - query_embedding = config.embedding_model_instance.embed(query_text) - perf.info( - "[connector_svc] _combined_rrf embedding in %.3fs type=%s", - time.perf_counter() - t_embed, - document_type, - ) + # Reuse caller-provided embedding or compute once for both retrievers. + if query_embedding is None: + t_embed = time.perf_counter() + query_embedding = config.embedding_model_instance.embed(query_text) + perf.info( + "[connector_svc] _combined_rrf embedding in %.3fs type=%s", + time.perf_counter() - t_embed, + document_type, + ) search_kwargs = { "query_text": query_text, diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 2465834f4..e8c0d2d47 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -159,26 +159,95 @@ class LLMRouterService: # Merge with provided settings final_settings = {**default_settings, **instance._router_settings} + # Build a "auto-large" fallback group with deployments whose context + # window exceeds the smallest deployment. This lets the router + # automatically fall back to a bigger-context model when gpt-4o (128K) + # hits ContextWindowExceededError. + full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list) + try: - instance._router = Router( - model_list=model_list, - routing_strategy=final_settings.get( + router_kwargs: dict[str, Any] = { + "model_list": full_model_list, + "routing_strategy": final_settings.get( "routing_strategy", "usage-based-routing" ), - num_retries=final_settings.get("num_retries", 3), - allowed_fails=final_settings.get("allowed_fails", 3), - cooldown_time=final_settings.get("cooldown_time", 60), - set_verbose=False, # Disable verbose logging in production - ) + "num_retries": final_settings.get("num_retries", 3), + "allowed_fails": final_settings.get("allowed_fails", 3), + "cooldown_time": final_settings.get("cooldown_time", 60), + "set_verbose": False, + } + if ctx_fallbacks: + router_kwargs["context_window_fallbacks"] = ctx_fallbacks + + instance._router = Router(**router_kwargs) instance._initialized = True logger.info( - f"LLM Router initialized with {len(model_list)} deployments, " - f"strategy: {final_settings.get('routing_strategy')}" + "LLM Router initialized with %d deployments, " + "strategy: %s, context_window_fallbacks: %s", + len(model_list), + final_settings.get("routing_strategy"), + ctx_fallbacks or "none", ) except Exception as e: logger.error(f"Failed to initialize LLM Router: {e}") instance._router = None + @classmethod + def _build_context_fallback_groups( + cls, model_list: list[dict] + ) -> tuple[list[dict], list[dict[str, list[str]]] | None]: + """Create an ``auto-large`` model group for context-window fallbacks. + + Uses ``litellm.get_model_info`` to discover the context window of each + deployment. Deployments whose ``max_input_tokens`` exceeds the smallest + window are duplicated into an ``auto-large`` group. The returned + fallback config tells the Router: on ``ContextWindowExceededError`` for + ``auto``, retry with ``auto-large``. + + Returns: + (full_model_list, context_window_fallbacks) — ``full_model_list`` + contains the original entries plus any ``auto-large`` duplicates. + ``context_window_fallbacks`` is ``None`` when every deployment has + the same context size (no useful fallback). + """ + from litellm import get_model_info + + ctx_map: dict[str, int] = {} + for dep in model_list: + params = dep.get("litellm_params", {}) + base_model = params.get("base_model") or params.get("model", "") + try: + info = get_model_info(base_model) + ctx = info.get("max_input_tokens") + if isinstance(ctx, int) and ctx > 0: + ctx_map[base_model] = ctx + except Exception: + continue + + if not ctx_map: + return model_list, None + + min_ctx = min(ctx_map.values()) + + large_deployments: list[dict] = [] + for dep in model_list: + params = dep.get("litellm_params", {}) + base_model = params.get("base_model") or params.get("model", "") + if ctx_map.get(base_model, 0) > min_ctx: + dup = {**dep, "model_name": "auto-large"} + large_deployments.append(dup) + + if not large_deployments: + return model_list, None + + logger.info( + "Context-window fallback: %d large-context deployments " + "(min_ctx=%d) added to 'auto-large' group", + len(large_deployments), + min_ctx, + ) + return model_list + large_deployments, [{"auto": ["auto-large"]}] + @classmethod def _config_to_deployment(cls, config: dict) -> dict | None: """ From dd3da2bc368c1550fc23764be66559198c2c42aa Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Sat, 28 Feb 2026 23:17:11 -0800 Subject: [PATCH 09/10] refactor: improve session management and cleanup in chat streaming - Added proper session closure to prevent connection leaks during streaming. - Implemented a fresh session for cleanup tasks to ensure data integrity. - Enhanced error handling during session operations to improve robustness. - Removed unnecessary session parameters from function signatures for clarity. --- .../app/routes/new_chat_routes.py | 52 ++++++----- .../app/tasks/chat/stream_new_chat.py | 87 +++++++++++++------ 2 files changed, 91 insertions(+), 48 deletions(-) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index c997cba68..8952907a0 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -31,6 +31,7 @@ from app.db import ( Permission, SearchSpace, User, + async_session_maker, get_async_session, ) from app.schemas.new_chat import ( @@ -1092,13 +1093,18 @@ async def handle_new_chat( # on searchspaces/documents for the entire duration of the stream. # expire_on_commit=False keeps loaded ORM attrs usable. await session.commit() + # Close the dependency session now so its connection returns to + # the pool before streaming begins. Without this, Starlette's + # BaseHTTPMiddleware cancels the scope on client disconnect and + # the dependency generator's __aexit__ never runs, orphaning the + # connection (the "Exception terminating connection" errors). + await session.close() return StreamingResponse( stream_new_chat( user_query=request.user_query, search_space_id=request.search_space_id, chat_id=request.chat_id, - session=session, user_id=str(user.id), llm_config_id=llm_config_id, mentioned_document_ids=request.mentioned_document_ids, @@ -1323,6 +1329,7 @@ async def regenerate_response( # on searchspaces/documents for the entire duration of the stream. # expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable. await session.commit() + await session.close() # Create a wrapper generator that deletes messages only AFTER streaming succeeds # This prevents data loss if streaming fails (network error, LLM error, etc.) @@ -1333,7 +1340,6 @@ async def regenerate_response( user_query=user_query_to_use, search_space_id=request.search_space_id, chat_id=thread_id, - session=session, user_id=str(user.id), llm_config_id=llm_config_id, mentioned_document_ids=request.mentioned_document_ids, @@ -1344,29 +1350,35 @@ async def regenerate_response( current_user_display_name=user.display_name or "A team member", ): yield chunk - # If we get here, streaming completed successfully streaming_completed = True finally: - # Only delete old messages if streaming completed successfully - # This ensures we don't lose data on streaming failures - if streaming_completed and messages_to_delete: + # Only delete old messages if streaming completed successfully. + # Uses a fresh session since stream_new_chat manages its own. + if streaming_completed and message_ids_to_delete: try: - for msg in messages_to_delete: - await session.delete(msg) - await session.commit() + async with async_session_maker() as cleanup_session: + for msg_id in message_ids_to_delete: + _res = await cleanup_session.execute( + select(NewChatMessage).filter( + NewChatMessage.id == msg_id + ) + ) + _msg = _res.scalars().first() + if _msg: + await cleanup_session.delete(_msg) + await cleanup_session.commit() - # Delete any public snapshots that contain the modified messages - from app.services.public_chat_service import ( - delete_affected_snapshots, - ) + from app.services.public_chat_service import ( + delete_affected_snapshots, + ) - await delete_affected_snapshots( - session, thread_id, message_ids_to_delete - ) + await delete_affected_snapshots( + cleanup_session, thread_id, message_ids_to_delete + ) except Exception as cleanup_error: - # Log but don't fail - the new messages are already streamed - print( - f"[regenerate] Warning: Failed to delete old messages: {cleanup_error}" + _logger.warning( + "[regenerate] Failed to delete old messages: %s", + cleanup_error, ) # Return streaming response with checkpoint_id for rewinding @@ -1440,13 +1452,13 @@ async def resume_chat( # Release the read-transaction so we don't hold ACCESS SHARE locks # on searchspaces/documents for the entire duration of the stream. await session.commit() + await session.close() return StreamingResponse( stream_resume_chat( chat_id=thread_id, search_space_id=request.search_space_id, decisions=decisions, - session=session, user_id=str(user.id), llm_config_id=llm_config_id, thread_visibility=thread.visibility, diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index ae7001a40..34ea6ec82 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -10,6 +10,7 @@ Supports loading LLM configurations from: """ import asyncio +import contextlib import gc import json import logging @@ -20,9 +21,9 @@ from dataclasses import dataclass, field from typing import Any from uuid import UUID +import anyio from langchain_core.messages import HumanMessage from sqlalchemy import func -from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload @@ -1012,7 +1013,6 @@ async def stream_new_chat( user_query: str, search_space_id: int, chat_id: int, - session: AsyncSession, user_id: str | None = None, llm_config_id: int = -1, mentioned_document_ids: list[int] | None = None, @@ -1028,11 +1028,13 @@ async def stream_new_chat( This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming. The chat_id is used as LangGraph's thread_id for memory/checkpointing. + The function creates and manages its own database session to guarantee proper + cleanup even when Starlette's middleware cancels the task on client disconnect. + Args: user_query: The user's query search_space_id: The search space ID chat_id: The chat ID (used as LangGraph thread_id for memory) - session: The database session user_id: The current user's UUID string (for memory tools and session state) llm_config_id: The LLM configuration ID (default: -1 for first global config) needs_history_bootstrap: If True, load message history from DB (for cloned chats) @@ -1048,6 +1050,7 @@ async def stream_new_chat( _t_total = time.perf_counter() log_system_snapshot("stream_new_chat_START") + session = async_session_maker() try: # Mark AI as responding to this user for live collaboration if user_id: @@ -1283,6 +1286,12 @@ async def stream_new_chat( # short-lived transactions (or use isolated sessions). await session.commit() + # Detach heavy ORM objects (documents with chunks, reports, etc.) + # from the session identity map now that we've extracted the data + # we need. This prevents them from accumulating in memory for the + # entire duration of LLM streaming (which can be several minutes). + session.expunge_all() + _perf_log.info( "[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)", time.perf_counter() - _t_total, @@ -1459,23 +1468,35 @@ async def stream_new_chat( yield streaming_service.format_done() finally: - # Clear AI responding state for live collaboration. - # The original session may be broken (client disconnect / CancelledError - # can corrupt the underlying DB connection), so we try a rollback first - # and fall back to a fresh session if the original is unusable. - try: - await session.rollback() - await clear_ai_responding(session, chat_id) - except Exception: + # Shield the ENTIRE async cleanup from anyio cancel-scope + # cancellation. Starlette's BaseHTTPMiddleware uses anyio task + # groups; on client disconnect, it cancels the scope with + # level-triggered cancellation — every unshielded `await` inside + # the cancelled scope raises CancelledError immediately. Without + # this shield the very first `await` (session.rollback) would + # raise CancelledError, `except Exception` wouldn't catch it + # (CancelledError is a BaseException), and the rest of the + # finally block — including session.close() — would never run. + with anyio.CancelScope(shield=True): try: - async with async_session_maker() as fresh_session: - await clear_ai_responding(fresh_session, chat_id) + await session.rollback() + await clear_ai_responding(session, chat_id) except Exception: - logging.getLogger(__name__).warning( - "Failed to clear AI responding state for thread %s", chat_id - ) + try: + async with async_session_maker() as fresh_session: + await clear_ai_responding(fresh_session, chat_id) + except Exception: + logging.getLogger(__name__).warning( + "Failed to clear AI responding state for thread %s", chat_id + ) - _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files) + _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files) + + with contextlib.suppress(Exception): + session.expunge_all() + + with contextlib.suppress(Exception): + await session.close() # Break circular refs held by the agent graph, tools, and LLM # wrappers so the GC can reclaim them in a single pass. @@ -1483,6 +1504,7 @@ async def stream_new_chat( mentioned_documents = mentioned_surfsense_docs = None recent_reports = langchain_messages = input_state = None stream_result = None + session = None collected = gc.collect(0) + gc.collect(1) + gc.collect(2) if collected: @@ -1498,7 +1520,6 @@ async def stream_resume_chat( chat_id: int, search_space_id: int, decisions: list[dict], - session: AsyncSession, user_id: str | None = None, llm_config_id: int = -1, thread_visibility: ChatVisibility | None = None, @@ -1507,6 +1528,7 @@ async def stream_resume_chat( stream_result = StreamResult() _t_total = time.perf_counter() + session = async_session_maker() try: if user_id: await set_ai_responding(session, chat_id, UUID(user_id)) @@ -1603,6 +1625,7 @@ async def stream_resume_chat( # Release the transaction before streaming (same rationale as stream_new_chat). await session.commit() + session.expunge_all() _perf_log.info( "[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)", @@ -1666,22 +1689,30 @@ async def stream_resume_chat( yield streaming_service.format_done() finally: - try: - await session.rollback() - await clear_ai_responding(session, chat_id) - except Exception: + with anyio.CancelScope(shield=True): try: - async with async_session_maker() as fresh_session: - await clear_ai_responding(fresh_session, chat_id) + await session.rollback() + await clear_ai_responding(session, chat_id) except Exception: - logging.getLogger(__name__).warning( - "Failed to clear AI responding state for thread %s", chat_id - ) + try: + async with async_session_maker() as fresh_session: + await clear_ai_responding(fresh_session, chat_id) + except Exception: + logging.getLogger(__name__).warning( + "Failed to clear AI responding state for thread %s", chat_id + ) - _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files) + _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files) + + with contextlib.suppress(Exception): + session.expunge_all() + + with contextlib.suppress(Exception): + await session.close() agent = llm = connector_service = sandbox_backend = None stream_result = None + session = None collected = gc.collect(0) + gc.collect(1) + gc.collect(2) if collected: From ecb0a25cc80c8db533cf0dfb7396589361ef2b27 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Sat, 28 Feb 2026 23:59:28 -0800 Subject: [PATCH 10/10] feat: enhance memory management and session handling in database operations - Introduced a shielded async session context manager to ensure safe session closure during cancellations. - Updated various database operations to utilize the new shielded session, preventing orphaned connections. - Added environment variables to optimize glibc memory management, improving overall application performance. - Implemented a function to trim the native heap, allowing for better memory reclamation on Linux systems. --- surfsense_backend/Dockerfile | 7 ++++++ .../agents/new_chat/tools/knowledge_base.py | 8 +++---- .../app/agents/new_chat/tools/report.py | 10 ++++---- surfsense_backend/app/db.py | 22 ++++++++++++++++++ .../app/routes/new_chat_routes.py | 4 ++-- .../app/tasks/chat/stream_new_chat.py | 19 ++++++++++----- surfsense_backend/app/utils/perf.py | 23 +++++++++++++++++++ 7 files changed, 76 insertions(+), 17 deletions(-) diff --git a/surfsense_backend/Dockerfile b/surfsense_backend/Dockerfile index 4f24d2b05..1222b36b6 100644 --- a/surfsense_backend/Dockerfile +++ b/surfsense_backend/Dockerfile @@ -88,6 +88,13 @@ ENV TMPDIR=/shared_tmp ENV PYTHONPATH=/app ENV UVICORN_LOOP=asyncio +# Tune glibc malloc to return freed memory to the OS more aggressively. +# Without these, Python's gc.collect() frees objects but the underlying +# C heap pages stay mapped (RSS never drops) due to sbrk fragmentation. +ENV MALLOC_MMAP_THRESHOLD_=65536 +ENV MALLOC_TRIM_THRESHOLD_=131072 +ENV MALLOC_MMAP_MAX_=65536 + # SERVICE_ROLE controls which process this container runs: # api – FastAPI backend only (runs migrations on startup) # worker – Celery worker only diff --git a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py index 16cad80e5..f1d3d16b8 100644 --- a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py +++ b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py @@ -19,7 +19,7 @@ from langchain_core.tools import StructuredTool from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession -from app.db import async_session_maker +from app.db import shielded_async_session from app.services.connector_service import ConnectorService from app.utils.perf import get_perf_logger @@ -98,7 +98,7 @@ async def _browse_recent_documents( if end_date is not None: base_conditions.append(Document.updated_at <= end_date) - async with async_session_maker() as session: + async with shielded_async_session() as session: doc_query = ( select(Document) .options(joinedload(Document.search_space)) @@ -739,7 +739,7 @@ async def search_knowledge_base_async( try: t_conn = time.perf_counter() - async with semaphore, async_session_maker() as isolated_session: + async with semaphore, shielded_async_session() as isolated_session: svc = ConnectorService(isolated_session, search_space_id) _, chunks = await getattr(svc, method_name)(**kwargs) perf.info( @@ -756,7 +756,7 @@ async def search_knowledge_base_async( # --- Optimization 3: call _combined_rrf_search directly with shared embedding --- try: t_conn = time.perf_counter() - async with semaphore, async_session_maker() as isolated_session: + async with semaphore, shielded_async_session() as isolated_session: svc = ConnectorService(isolated_session, search_space_id) chunks = await svc._combined_rrf_search( query_text=query, diff --git a/surfsense_backend/app/agents/new_chat/tools/report.py b/surfsense_backend/app/agents/new_chat/tools/report.py index 5212c2c3b..fe5181f54 100644 --- a/surfsense_backend/app/agents/new_chat/tools/report.py +++ b/surfsense_backend/app/agents/new_chat/tools/report.py @@ -33,7 +33,7 @@ from langchain_core.callbacks import dispatch_custom_event from langchain_core.messages import HumanMessage from langchain_core.tools import tool -from app.db import Report, async_session_maker +from app.db import Report, shielded_async_session from app.services.connector_service import ConnectorService from app.services.llm_service import get_document_summary_llm @@ -717,7 +717,7 @@ def create_generate_report_tool( async def _save_failed_report(error_msg: str) -> int | None: """Persist a failed report row using a short-lived session.""" try: - async with async_session_maker() as session: + async with shielded_async_session() as session: failed_report = Report( title=topic, content=None, @@ -751,7 +751,7 @@ def create_generate_report_tool( # ── Phase 1: READ (short-lived session) ────────────────────── # Fetch parent report and LLM config, then close the session # so no DB connection is held during the long LLM call. - async with async_session_maker() as read_session: + async with shielded_async_session() as read_session: if parent_report_id: parent_report = await read_session.get(Report, parent_report_id) if parent_report: @@ -828,7 +828,7 @@ def create_generate_report_tool( # Run all queries in parallel, each with its own session async def _run_single_query(q: str) -> str: - async with async_session_maker() as kb_session: + async with shielded_async_session() as kb_session: kb_connector_svc = ConnectorService( kb_session, search_space_id ) @@ -1028,7 +1028,7 @@ def create_generate_report_tool( # ── Phase 3: WRITE (short-lived session) ───────────────────── # Save the report to the database, then close the session. - async with async_session_maker() as write_session: + async with shielded_async_session() as write_session: report = Report( title=topic, content=report_content, diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index ba926c9ad..510f64cc3 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -1,7 +1,9 @@ from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from datetime import UTC, datetime from enum import StrEnum +import anyio from fastapi import Depends from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase from pgvector.sqlalchemy import Vector @@ -1867,6 +1869,26 @@ engine = create_async_engine( async_session_maker = async_sessionmaker(engine, expire_on_commit=False) +@asynccontextmanager +async def shielded_async_session(): + """Cancellation-safe async session context manager. + + Starlette's BaseHTTPMiddleware cancels the task via an anyio cancel + scope when a client disconnects. A plain ``async with async_session_maker()`` + has its ``__aexit__`` (which awaits ``session.close()``) cancelled by the + scope, orphaning the underlying database connection. + + This wrapper ensures ``session.close()`` always completes by running it + inside ``anyio.CancelScope(shield=True)``. + """ + session = async_session_maker() + try: + yield session + finally: + with anyio.CancelScope(shield=True): + await session.close() + + async def setup_indexes(): async with engine.begin() as conn: # Create indexes diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 8952907a0..e0d78696f 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -31,8 +31,8 @@ from app.db import ( Permission, SearchSpace, User, - async_session_maker, get_async_session, + shielded_async_session, ) from app.schemas.new_chat import ( NewChatMessageAppend, @@ -1356,7 +1356,7 @@ async def regenerate_response( # Uses a fresh session since stream_new_chat manages its own. if streaming_completed and message_ids_to_delete: try: - async with async_session_maker() as cleanup_session: + async with shielded_async_session() as cleanup_session: for msg_id in message_ids_to_delete: _res = await cleanup_session.execute( select(NewChatMessage).filter( diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 34ea6ec82..8d09ff387 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -49,6 +49,7 @@ from app.db import ( SearchSourceConnectorType, SurfsenseDocsDocument, async_session_maker, + shielded_async_session, ) from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE from app.services.chat_session_state_service import ( @@ -58,7 +59,7 @@ from app.services.chat_session_state_service import ( from app.services.connector_service import ConnectorService from app.services.new_streaming_service import VercelStreamingService from app.utils.content_utils import bootstrap_history_from_db -from app.utils.perf import get_perf_logger, log_system_snapshot +from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap _perf_log = get_perf_logger() @@ -1359,6 +1360,12 @@ async def stream_new_chat( items=initial_items, ) + # These ORM objects (with eagerly-loaded chunks) can be very large. + # They're only needed to build context strings already copied into + # final_query / langchain_messages — release them before streaming. + del mentioned_documents, mentioned_surfsense_docs, recent_reports + del langchain_messages, final_query + _t_stream_start = time.perf_counter() _first_event_logged = False async for sse in _stream_agent_events( @@ -1483,7 +1490,7 @@ async def stream_new_chat( await clear_ai_responding(session, chat_id) except Exception: try: - async with async_session_maker() as fresh_session: + async with shielded_async_session() as fresh_session: await clear_ai_responding(fresh_session, chat_id) except Exception: logging.getLogger(__name__).warning( @@ -1501,9 +1508,7 @@ async def stream_new_chat( # Break circular refs held by the agent graph, tools, and LLM # wrappers so the GC can reclaim them in a single pass. agent = llm = connector_service = sandbox_backend = None - mentioned_documents = mentioned_surfsense_docs = None - recent_reports = langchain_messages = input_state = None - stream_result = None + input_state = stream_result = None session = None collected = gc.collect(0) + gc.collect(1) + gc.collect(2) @@ -1513,6 +1518,7 @@ async def stream_new_chat( collected, chat_id, ) + trim_native_heap() log_system_snapshot("stream_new_chat_END") @@ -1695,7 +1701,7 @@ async def stream_resume_chat( await clear_ai_responding(session, chat_id) except Exception: try: - async with async_session_maker() as fresh_session: + async with shielded_async_session() as fresh_session: await clear_ai_responding(fresh_session, chat_id) except Exception: logging.getLogger(__name__).warning( @@ -1721,4 +1727,5 @@ async def stream_resume_chat( collected, chat_id, ) + trim_native_heap() log_system_snapshot("stream_resume_chat_END") diff --git a/surfsense_backend/app/utils/perf.py b/surfsense_backend/app/utils/perf.py index d6602d666..b2b26897c 100644 --- a/surfsense_backend/app/utils/perf.py +++ b/surfsense_backend/app/utils/perf.py @@ -149,3 +149,26 @@ def log_system_snapshot(label: str = "system_snapshot") -> None: snap["rss_delta_mb"], snap["rss_mb"], ) + + +def trim_native_heap() -> bool: + """Ask glibc to return free heap pages to the OS via ``malloc_trim(0)``. + + On Linux (glibc), ``free()`` does not release memory back to the OS if + it sits below the brk watermark. ``malloc_trim`` forces the allocator + to give back as many freed pages as possible. + + Returns True if trimming was performed, False otherwise (non-Linux or + libc unavailable). + """ + import ctypes + import sys + + if sys.platform != "linux": + return False + try: + libc = ctypes.CDLL("libc.so.6") + libc.malloc_trim(0) + return True + except (OSError, AttributeError): + return False