Merge remote-tracking branch 'upstream/dev' into refactor/upload-document-adapter-class

This commit is contained in:
Anish Sarkar 2026-03-01 22:35:17 +05:30
commit 6d00b0debf
47 changed files with 1880 additions and 578 deletions

View file

@ -88,6 +88,13 @@ ENV TMPDIR=/shared_tmp
ENV PYTHONPATH=/app
ENV UVICORN_LOOP=asyncio
# Tune glibc malloc to return freed memory to the OS more aggressively.
# Without these, Python's gc.collect() frees objects but the underlying
# C heap pages stay mapped (RSS never drops) due to sbrk fragmentation.
ENV MALLOC_MMAP_THRESHOLD_=65536
ENV MALLOC_TRIM_THRESHOLD_=131072
ENV MALLOC_MMAP_MAX_=65536
# SERVICE_ROLE controls which process this container runs:
# api FastAPI backend only (runs migrations on startup)
# worker Celery worker only

View file

@ -28,8 +28,9 @@ from app.agents.new_chat.system_prompt import (
from app.agents.new_chat.tools.registry import build_tools_async
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger
_perf_log = logging.getLogger("surfsense.perf")
_perf_log = get_perf_logger()
# =============================================================================
# Connector Type Mapping

View file

@ -22,6 +22,7 @@ from app.services.llm_router_service import (
AUTO_MODE_ID,
ChatLiteLLMRouter,
LLMRouterService,
get_auto_mode_llm,
is_auto_mode,
)
@ -389,7 +390,7 @@ def create_chat_litellm_from_agent_config(
print("Error: Auto mode requested but LLM Router not initialized")
return None
try:
return ChatLiteLLMRouter()
return get_auto_mode_llm()
except Exception as e:
print(f"Error creating ChatLiteLLMRouter: {e}")
return None

View file

@ -58,6 +58,7 @@ class _TimeoutAwareSandbox(DaytonaSandbox):
_daytona_client: Daytona | None = None
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
_SANDBOX_CACHE_MAX_SIZE = 20
THREAD_LABEL_KEY = "surfsense_thread"
@ -144,6 +145,12 @@ async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
return cached
sandbox = await asyncio.to_thread(_find_or_create, key)
_sandbox_cache[key] = sandbox
if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
oldest_key = next(iter(_sandbox_cache))
_sandbox_cache.pop(oldest_key, None)
logger.debug("Evicted oldest sandbox cache entry: %s", oldest_key)
return sandbox

View file

@ -10,6 +10,8 @@ This module provides:
import asyncio
import json
import re
import time
from datetime import datetime
from typing import Any
@ -17,8 +19,152 @@ from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from app.db import shielded_async_session
from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger
# Connectors that call external live-search APIs (no local DB / embedding needed).
# These are never filtered by available_document_types.
_LIVE_SEARCH_CONNECTORS: set[str] = {
"TAVILY_API",
"SEARXNG_API",
"LINKUP_API",
"BAIDU_SEARCH_API",
}
# Patterns that indicate the query has no meaningful search signal.
# plainto_tsquery('english', '*') produces an empty tsquery and an embedding
# of '*' is random noise, so both keyword and semantic search degrade to
# arbitrary ordering — large documents (many chunks) dominate by chance.
_DEGENERATE_QUERY_RE = re.compile(
r"^[\s*?_.#@!\-/\\]+$" # only wildcards, punctuation, whitespace
)
# Max chunks per document when doing a recency-based browse instead of
# a real search. We want breadth (many docs) over depth (many chunks).
_BROWSE_MAX_CHUNKS_PER_DOC = 5
def _is_degenerate_query(query: str) -> bool:
"""Return True when the query carries no meaningful search signal.
Catches wildcard patterns (``*``, ``**``), empty / whitespace-only
strings, and single-character non-word tokens. These queries cause
both keyword search (empty tsquery) and semantic search (meaningless
embedding) to return effectively random results.
"""
stripped = query.strip()
if not stripped:
return True
return bool(_DEGENERATE_QUERY_RE.match(stripped))
async def _browse_recent_documents(
search_space_id: int,
document_type: str | None,
top_k: int,
start_date: datetime | None,
end_date: datetime | None,
) -> list[dict[str, Any]]:
"""Return the most-recent documents (recency-ordered, no search ranking).
Used as a fallback when the search query is degenerate (e.g. ``*``) and
semantic / keyword search would produce arbitrary results. Returns
document-grouped dicts in the same shape as ``_combined_rrf_search``
so the rest of the pipeline works unchanged.
"""
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from app.db import Chunk, Document, DocumentType
perf = get_perf_logger()
t0 = time.perf_counter()
base_conditions = [Document.search_space_id == search_space_id]
if document_type is not None:
if isinstance(document_type, str):
try:
doc_type_enum = DocumentType[document_type]
base_conditions.append(Document.document_type == doc_type_enum)
except KeyError:
return []
else:
base_conditions.append(Document.document_type == document_type)
if start_date is not None:
base_conditions.append(Document.updated_at >= start_date)
if end_date is not None:
base_conditions.append(Document.updated_at <= end_date)
async with shielded_async_session() as session:
doc_query = (
select(Document)
.options(joinedload(Document.search_space))
.where(*base_conditions)
.order_by(Document.updated_at.desc())
.limit(top_k)
)
result = await session.execute(doc_query)
documents = result.scalars().unique().all()
if not documents:
return []
doc_ids = [d.id for d in documents]
chunk_query = (
select(Chunk)
.where(Chunk.document_id.in_(doc_ids))
.order_by(Chunk.document_id, Chunk.id)
)
chunk_result = await session.execute(chunk_query)
raw_chunks = chunk_result.scalars().all()
doc_chunk_counts: dict[int, int] = {}
doc_chunks: dict[int, list[dict]] = {d.id: [] for d in documents}
for chunk in raw_chunks:
did = chunk.document_id
count = doc_chunk_counts.get(did, 0)
if count < _BROWSE_MAX_CHUNKS_PER_DOC:
doc_chunks[did].append({"chunk_id": chunk.id, "content": chunk.content})
doc_chunk_counts[did] = count + 1
results: list[dict[str, Any]] = []
for doc in documents:
chunks_list = doc_chunks.get(doc.id, [])
results.append(
{
"document_id": doc.id,
"content": "\n\n".join(
c["content"] for c in chunks_list if c.get("content")
),
"score": 0.0,
"chunks": chunks_list,
"document": {
"id": doc.id,
"title": doc.title,
"document_type": doc.document_type.value
if getattr(doc, "document_type", None)
else None,
"metadata": doc.document_metadata or {},
},
"source": doc.document_type.value
if getattr(doc, "document_type", None)
else None,
}
)
perf.info(
"[kb_browse] recency browse in %.3fs docs=%d space=%d type=%s",
time.perf_counter() - t0,
len(results),
search_space_id,
document_type,
)
return results
# =============================================================================
# Connector Constants and Normalization
@ -182,9 +328,23 @@ _CHARS_PER_TOKEN = 4
# Hard-floor / ceiling so the budget is always sensible regardless of what
# the model reports.
_MIN_TOOL_OUTPUT_CHARS = 20_000 # ~5K tokens
_MAX_TOOL_OUTPUT_CHARS = 400_000 # ~100K tokens
_MAX_TOOL_OUTPUT_CHARS = 200_000 # ~50K tokens
_MAX_CHUNK_CHARS = 8_000
# Rank-adaptive per-document budget allocation.
# Top-ranked (most relevant) documents get a larger share of the budget so
# we pack as much high-quality context as possible.
#
# fraction(rank) = _TOP_DOC_BUDGET_FRACTION / (1 + rank * _RANK_DECAY)
#
# Examples (128K budget, 8K chunk cap):
# rank 0 → 40% → 6 chunks | rank 3 → 19% → 3 chunks
# rank 1 → 30% → 4 chunks | rank 10 → 10% → 3 chunks (floor)
# rank 2 → 24% → 3 chunks |
_TOP_DOC_BUDGET_FRACTION = 0.40
_RANK_DECAY = 0.35
_MIN_CHUNKS_PER_DOC = 3
def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
"""Derive a character budget from the model's context window.
@ -206,18 +366,24 @@ def format_documents_for_context(
*,
max_chars: int = _MAX_TOOL_OUTPUT_CHARS,
max_chunk_chars: int = _MAX_CHUNK_CHARS,
max_chunks_per_doc: int = 0,
) -> str:
"""
Format retrieved documents into a readable context string for the LLM.
Documents are added in order (highest relevance first) until the character
budget is reached. Individual chunks are capped at ``max_chunk_chars`` so
a single oversized chunk cannot monopolize the output.
budget is reached. Individual chunks are capped at ``max_chunk_chars`` and
each document is limited to a dynamically computed chunk cap so a single
large document cannot monopolize the output while still maximising the use
of available context space.
Args:
documents: List of document dictionaries from connector search
max_chars: Approximate character budget for the entire output.
max_chunk_chars: Per-chunk character cap (content is tail-truncated).
max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means
auto-compute per document using a rank-adaptive formula so
higher-ranked documents receive more chunks.
Returns:
Formatted string with document contents and metadata
@ -340,7 +506,23 @@ def format_documents_for_context(
"<document_content>",
]
for ch in g["chunks"]:
# Rank-adaptive per-document chunk cap: top results get more chunks.
if max_chunks_per_doc > 0:
chunks_allowed = max_chunks_per_doc
else:
doc_fraction = _TOP_DOC_BUDGET_FRACTION / (1 + doc_idx * _RANK_DECAY)
max_doc_chars = int(max_chars * doc_fraction)
xml_overhead = 500
chunks_allowed = max(
(max_doc_chars - xml_overhead) // max(max_chunk_chars, 1),
_MIN_CHUNKS_PER_DOC,
)
chunks = g["chunks"]
if len(chunks) > chunks_allowed:
chunks = chunks[:chunks_allowed]
for ch in chunks:
ch_content = ch["content"]
if max_chunk_chars and len(ch_content) > max_chunk_chars:
ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)"
@ -357,9 +539,11 @@ def format_documents_for_context(
doc_xml = "\n".join(doc_lines)
doc_len = len(doc_xml)
# Always include at least the first document; afterwards enforce budget.
if doc_idx > 0 and total_chars + doc_len > max_chars:
if total_chars + doc_len > max_chars:
remaining = total_docs - doc_idx
if doc_idx == 0:
parts.append(doc_xml)
total_chars += doc_len
parts.append(
f"<!-- Output truncated: {remaining} more document(s) omitted "
f"(budget {max_chars} chars). Refine your query or reduce top_k "
@ -370,7 +554,15 @@ def format_documents_for_context(
parts.append(doc_xml)
total_chars += doc_len
return "\n".join(parts).strip()
result = "\n".join(parts).strip()
# Hard safety net: if the result is still over budget (e.g. a single massive
# first document), forcibly truncate with a closing comment.
if len(result) > max_chars:
truncation_msg = "\n<!-- ...output forcibly truncated to fit context window -->"
result = result[: max_chars - len(truncation_msg)] + truncation_msg
return result
# =============================================================================
@ -388,6 +580,7 @@ async def search_knowledge_base_async(
start_date: datetime | None = None,
end_date: datetime | None = None,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
max_input_tokens: int | None = None,
) -> str:
"""
@ -406,12 +599,18 @@ async def search_knowledge_base_async(
end_date: Optional end datetime (UTC) for filtering documents
available_connectors: Optional list of connectors actually available in the search space.
If provided, only these connectors will be searched.
available_document_types: Optional list of document types that actually have indexed
data. When provided, local connectors whose document type is
absent are skipped entirely (no embedding / DB round-trip).
max_input_tokens: Model context window size (tokens). Used to dynamically
size the output so it fits within the model's limits.
Returns:
Formatted string with search results
"""
perf = get_perf_logger()
t0 = time.perf_counter()
all_documents: list[dict[str, Any]] = []
# Resolve date range (default last 2 years)
@ -424,88 +623,169 @@ async def search_knowledge_base_async(
connectors = _normalize_connectors(connectors_to_search, available_connectors)
connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
"YOUTUBE_VIDEO": ("search_youtube", True, True, {}),
"EXTENSION": ("search_extension", True, True, {}),
"CRAWLED_URL": ("search_crawled_urls", True, True, {}),
"FILE": ("search_files", True, True, {}),
"SLACK_CONNECTOR": ("search_slack", True, True, {}),
"TEAMS_CONNECTOR": ("search_teams", True, True, {}),
"NOTION_CONNECTOR": ("search_notion", True, True, {}),
"GITHUB_CONNECTOR": ("search_github", True, True, {}),
"LINEAR_CONNECTOR": ("search_linear", True, True, {}),
# --- Optimization 1: skip local connectors that have zero indexed documents ---
if available_document_types:
doc_types_set = set(available_document_types)
before_count = len(connectors)
connectors = [
c for c in connectors if c in _LIVE_SEARCH_CONNECTORS or c in doc_types_set
]
skipped = before_count - len(connectors)
if skipped:
perf.info(
"[kb_search] skipped %d empty connectors (had %d, now %d)",
skipped,
before_count,
len(connectors),
)
perf.info(
"[kb_search] searching %d connectors: %s (space=%d, top_k=%d)",
len(connectors),
connectors[:5],
search_space_id,
top_k,
)
# --- Fast-path: degenerate queries (*, **, empty, etc.) ---
# Semantic embedding of '*' is noise and plainto_tsquery('english', '*')
# yields an empty tsquery, so both retrieval signals are useless.
# Fall back to a recency-ordered browse that returns diverse results.
if _is_degenerate_query(query):
perf.info(
"[kb_search] degenerate query %r detected - falling back to recency browse",
query,
)
local_connectors = [c for c in connectors if c not in _LIVE_SEARCH_CONNECTORS]
if not local_connectors:
local_connectors = [None] # type: ignore[list-item]
browse_results = await asyncio.gather(
*[
_browse_recent_documents(
search_space_id=search_space_id,
document_type=c,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
for c in local_connectors
]
)
for docs in browse_results:
all_documents.extend(docs)
# Skip dedup + formatting below (browse already returns unique docs)
# but still cap output budget.
output_budget = _compute_tool_output_budget(max_input_tokens)
result = format_documents_for_context(
all_documents,
max_chars=output_budget,
max_chunks_per_doc=_BROWSE_MAX_CHUNKS_PER_DOC,
)
perf.info(
"[kb_search] TOTAL (browse) in %.3fs total_docs=%d output_chars=%d "
"budget=%d space=%d",
time.perf_counter() - t0,
len(all_documents),
len(result),
output_budget,
search_space_id,
)
return result
# Specs for live-search connectors (external APIs, no local DB/embedding).
live_connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
"TAVILY_API": ("search_tavily", False, True, {}),
"SEARXNG_API": ("search_searxng", False, True, {}),
"LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}),
"BAIDU_SEARCH_API": ("search_baidu", False, True, {}),
"DISCORD_CONNECTOR": ("search_discord", True, True, {}),
"JIRA_CONNECTOR": ("search_jira", True, True, {}),
"GOOGLE_CALENDAR_CONNECTOR": ("search_google_calendar", True, True, {}),
"AIRTABLE_CONNECTOR": ("search_airtable", True, True, {}),
"GOOGLE_GMAIL_CONNECTOR": ("search_google_gmail", True, True, {}),
"GOOGLE_DRIVE_FILE": ("search_google_drive", True, True, {}),
"CONFLUENCE_CONNECTOR": ("search_confluence", True, True, {}),
"CLICKUP_CONNECTOR": ("search_clickup", True, True, {}),
"LUMA_CONNECTOR": ("search_luma", True, True, {}),
"ELASTICSEARCH_CONNECTOR": ("search_elasticsearch", True, True, {}),
"NOTE": ("search_notes", True, True, {}),
"BOOKSTACK_CONNECTOR": ("search_bookstack", True, True, {}),
"CIRCLEBACK": ("search_circleback", True, True, {}),
"OBSIDIAN_CONNECTOR": ("search_obsidian", True, True, {}),
# Composio connectors
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": (
"search_composio_google_drive",
True,
True,
{},
),
"COMPOSIO_GMAIL_CONNECTOR": ("search_composio_gmail", True, True, {}),
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": (
"search_composio_google_calendar",
True,
True,
{},
),
}
# Keep a conservative cap to avoid overloading DB/external services.
# --- Optimization 2: compute the query embedding once, share across all local searches ---
precomputed_embedding: list[float] | None = None
has_local_connectors = any(c not in _LIVE_SEARCH_CONNECTORS for c in connectors)
if has_local_connectors:
from app.config import config as app_config
t_embed = time.perf_counter()
precomputed_embedding = app_config.embedding_model_instance.embed(query)
perf.info(
"[kb_search] shared embedding computed in %.3fs",
time.perf_counter() - t_embed,
)
max_parallel_searches = 4
semaphore = asyncio.Semaphore(max_parallel_searches)
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
spec = connector_specs.get(connector)
if spec is None:
return []
is_live = connector in _LIVE_SEARCH_CONNECTORS
method_name, includes_date_range, includes_top_k, extra_kwargs = spec
kwargs: dict[str, Any] = {
"user_query": query,
"search_space_id": search_space_id,
**extra_kwargs,
}
if includes_top_k:
kwargs["top_k"] = top_k
if includes_date_range:
kwargs["start_date"] = resolved_start_date
kwargs["end_date"] = resolved_end_date
if is_live:
spec = live_connector_specs.get(connector)
if spec is None:
return []
method_name, includes_date_range, includes_top_k, extra_kwargs = spec
kwargs: dict[str, Any] = {
"user_query": query,
"search_space_id": search_space_id,
**extra_kwargs,
}
if includes_top_k:
kwargs["top_k"] = top_k
if includes_date_range:
kwargs["start_date"] = resolved_start_date
kwargs["end_date"] = resolved_end_date
try:
t_conn = time.perf_counter()
async with semaphore, shielded_async_session() as isolated_session:
svc = ConnectorService(isolated_session, search_space_id)
_, chunks = await getattr(svc, method_name)(**kwargs)
perf.info(
"[kb_search] connector=%s results=%d in %.3fs",
connector,
len(chunks),
time.perf_counter() - t_conn,
)
return chunks
except Exception as e:
perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
return []
# --- Optimization 3: call _combined_rrf_search directly with shared embedding ---
try:
# Use isolated session per connector. Shared AsyncSession cannot safely
# run concurrent DB operations.
async with semaphore, async_session_maker() as isolated_session:
isolated_connector_service = ConnectorService(
isolated_session, search_space_id
t_conn = time.perf_counter()
async with semaphore, shielded_async_session() as isolated_session:
svc = ConnectorService(isolated_session, search_space_id)
chunks = await svc._combined_rrf_search(
query_text=query,
search_space_id=search_space_id,
document_type=connector,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
query_embedding=precomputed_embedding,
)
perf.info(
"[kb_search] connector=%s results=%d in %.3fs",
connector,
len(chunks),
time.perf_counter() - t_conn,
)
connector_method = getattr(isolated_connector_service, method_name)
_, chunks = await connector_method(**kwargs)
return chunks
except Exception as e:
print(f"Error searching connector {connector}: {e}")
perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
return []
t_gather = time.perf_counter()
connector_results = await asyncio.gather(
*[_search_one_connector(connector) for connector in connectors]
)
perf.info(
"[kb_search] all connectors gathered in %.3fs",
time.perf_counter() - t_gather,
)
for chunks in connector_results:
all_documents.extend(chunks)
@ -552,7 +832,28 @@ async def search_knowledge_base_async(
deduplicated.append(doc)
output_budget = _compute_tool_output_budget(max_input_tokens)
return format_documents_for_context(deduplicated, max_chars=output_budget)
result = format_documents_for_context(deduplicated, max_chars=output_budget)
if len(result) > output_budget:
perf.warning(
"[kb_search] output STILL exceeds budget after format (%d > %d), "
"hard truncation should have fired",
len(result),
output_budget,
)
perf.info(
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d "
"budget=%d max_input_tokens=%s space=%d",
time.perf_counter() - t0,
len(all_documents),
len(deduplicated),
len(result),
output_budget,
max_input_tokens,
search_space_id,
)
return result
def _build_connector_docstring(available_connectors: list[str] | None) -> str:
@ -590,11 +891,15 @@ class SearchKnowledgeBaseInput(BaseModel):
"""Input schema for the search_knowledge_base tool."""
query: str = Field(
description="The search query - be specific and include key terms"
description=(
"The search query - use specific natural language terms. "
"NEVER use wildcards like '*' or '**'; instead describe what you want "
"(e.g. 'recent meeting notes' or 'project architecture overview')."
),
)
top_k: int = Field(
default=10,
description="Number of results to retrieve (default: 10)",
description="Number of results to retrieve (default: 10). Keep ≤20 for focused searches.",
)
start_date: str | None = Field(
default=None,
@ -657,6 +962,10 @@ Focus searches on these types for best results."""
Use this tool to find documents, notes, files, web pages, and other content that may help answer the user's question.
IMPORTANT:
- Always craft specific, descriptive search queries using natural language keywords.
Good: "quarterly sales report Q3", "Python API authentication design".
Bad: "*", "**", "everything", single characters. Wildcard/empty queries yield poor results.
- Prefer multiple focused searches over a single broad one with high top_k.
- If the user requests a specific source type (e.g. "my notes", "Slack messages"), pass `connectors_to_search=[...]` using the enums below.
- If `connectors_to_search` is omitted/empty, the system will search broadly.
- Only connectors that are enabled/configured for this search space are available.{doc_types_info}
@ -672,6 +981,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
# Capture for closure
_available_connectors = available_connectors
_available_document_types = available_document_types
async def _search_knowledge_base_impl(
query: str,
@ -701,6 +1011,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
start_date=parsed_start,
end_date=parsed_end,
available_connectors=_available_connectors,
available_document_types=_available_document_types,
max_input_tokens=max_input_tokens,
)

View file

@ -27,9 +27,24 @@ from app.db import SearchSourceConnector, SearchSourceConnectorType
logger = logging.getLogger(__name__)
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
_MCP_CACHE_MAX_SIZE = 50
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
def _evict_expired_mcp_cache() -> None:
"""Remove expired entries from the MCP tools cache to prevent unbounded growth."""
now = time.monotonic()
expired = [
k
for k, (ts, _) in _mcp_tools_cache.items()
if now - ts >= _MCP_CACHE_TTL_SECONDS
]
for k in expired:
del _mcp_tools_cache[k]
if expired:
logger.debug("Evicted %d expired MCP cache entries", len(expired))
def _create_dynamic_input_model_from_schema(
tool_name: str,
input_schema: dict[str, Any],
@ -392,6 +407,8 @@ async def load_mcp_tools(
List of LangChain StructuredTool instances
"""
_evict_expired_mcp_cache()
now = time.monotonic()
cached = _mcp_tools_cache.get(search_space_id)
if cached is not None:
@ -445,6 +462,11 @@ async def load_mcp_tools(
)
_mcp_tools_cache[search_space_id] = (now, tools)
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
del _mcp_tools_cache[oldest_key]
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
return tools

View file

@ -145,10 +145,12 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
thread_id=deps["thread_id"],
connector_service=deps.get("connector_service"),
available_connectors=deps.get("available_connectors"),
available_document_types=deps.get("available_document_types"),
),
requires=["search_space_id", "thread_id"],
# connector_service and available_connectors are optional —
# when missing, source_strategy="kb_search" degrades gracefully to "provided"
# connector_service, available_connectors, and available_document_types
# are optional — when missing, source_strategy="kb_search" degrades
# gracefully to "provided"
),
# Link preview tool - fetches Open Graph metadata for URLs
ToolDefinition(

View file

@ -33,7 +33,7 @@ from langchain_core.callbacks import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from app.db import Report, async_session_maker
from app.db import Report, shielded_async_session
from app.services.connector_service import ConnectorService
from app.services.llm_service import get_document_summary_llm
@ -559,6 +559,7 @@ def create_generate_report_tool(
thread_id: int | None = None,
connector_service: ConnectorService | None = None,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
):
"""
Factory function to create the generate_report tool with injected dependencies.
@ -716,7 +717,7 @@ def create_generate_report_tool(
async def _save_failed_report(error_msg: str) -> int | None:
"""Persist a failed report row using a short-lived session."""
try:
async with async_session_maker() as session:
async with shielded_async_session() as session:
failed_report = Report(
title=topic,
content=None,
@ -750,7 +751,7 @@ def create_generate_report_tool(
# ── Phase 1: READ (short-lived session) ──────────────────────
# Fetch parent report and LLM config, then close the session
# so no DB connection is held during the long LLM call.
async with async_session_maker() as read_session:
async with shielded_async_session() as read_session:
if parent_report_id:
parent_report = await read_session.get(Report, parent_report_id)
if parent_report:
@ -827,7 +828,7 @@ def create_generate_report_tool(
# Run all queries in parallel, each with its own session
async def _run_single_query(q: str) -> str:
async with async_session_maker() as kb_session:
async with shielded_async_session() as kb_session:
kb_connector_svc = ConnectorService(
kb_session, search_space_id
)
@ -838,6 +839,7 @@ def create_generate_report_tool(
connector_service=kb_connector_svc,
top_k=10,
available_connectors=available_connectors,
available_document_types=available_document_types,
)
kb_results = await asyncio.gather(
@ -1026,7 +1028,7 @@ def create_generate_report_tool(
# ── Phase 3: WRITE (short-lived session) ─────────────────────
# Save the report to the database, then close the session.
async with async_session_maker() as write_session:
async with shielded_async_session() as write_session:
report = Report(
title=topic,
content=report_content,

View file

@ -1,4 +1,5 @@
import asyncio
import gc
import logging
import time
from collections import defaultdict
@ -15,6 +16,9 @@ from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request as StarletteRequest
from starlette.responses import Response as StarletteResponse
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
from app.agents.new_chat.checkpointer import (
@ -28,6 +32,7 @@ from app.routes.auth_routes import router as auth_router
from app.schemas import UserCreate, UserRead, UserUpdate
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
from app.utils.perf import get_perf_logger, log_system_snapshot
rate_limit_logger = logging.getLogger("surfsense.rate_limit")
@ -99,22 +104,24 @@ def _check_rate_limit_memory(
now = time.monotonic()
with _memory_lock:
# Evict timestamps outside the current window
_memory_rate_limits[key] = [
t for t in _memory_rate_limits[key] if now - t < window_seconds
]
timestamps = [t for t in _memory_rate_limits[key] if now - t < window_seconds]
if len(_memory_rate_limits[key]) >= max_requests:
if not timestamps:
_memory_rate_limits.pop(key, None)
else:
_memory_rate_limits[key] = timestamps
if len(timestamps) >= max_requests:
rate_limit_logger.warning(
f"Rate limit exceeded (in-memory fallback) on {scope} for IP {client_ip} "
f"({len(_memory_rate_limits[key])}/{max_requests} in {window_seconds}s)"
f"({len(timestamps)}/{max_requests} in {window_seconds}s)"
)
raise HTTPException(
status_code=429,
detail="RATE_LIMIT_EXCEEDED",
)
_memory_rate_limits[key].append(now)
_memory_rate_limits[key] = [*timestamps, now]
def _check_rate_limit(
@ -206,18 +213,16 @@ def _enable_slow_callback_logging(threshold_sec: float = 0.5) -> None:
@asynccontextmanager
async def lifespan(app: FastAPI):
# Enable slow-callback detection (set PERF_DEBUG=1 env var to activate)
# Tune GC: lower gen-2 threshold so long-lived garbage is collected
# sooner (default 700/10/10 → 700/10/5). This reduces peak RSS
# with minimal CPU overhead.
gc.set_threshold(700, 10, 5)
_enable_slow_callback_logging(threshold_sec=0.5)
# Not needed if you setup a migration system like Alembic
await create_db_and_tables()
# Setup LangGraph checkpointer tables for conversation persistence
await setup_checkpointer_tables()
# Initialize LLM Router for Auto mode load balancing
initialize_llm_router()
# Initialize Image Generation Router for Auto mode load balancing
initialize_image_gen_router()
# Seed Surfsense documentation (with timeout so a slow embedding API
# doesn't block startup indefinitely and make the container unresponsive)
try:
await asyncio.wait_for(seed_surfsense_docs(), timeout=120)
except TimeoutError:
@ -225,8 +230,11 @@ async def lifespan(app: FastAPI):
"Surfsense docs seeding timed out after 120s — skipping. "
"Docs will be indexed on the next restart."
)
log_system_snapshot("startup_complete")
yield
# Cleanup: close checkpointer connection on shutdown
await close_checkpointer()
@ -244,6 +252,63 @@ app = FastAPI(lifespan=lifespan)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# ---------------------------------------------------------------------------
# Request-level performance middleware
# ---------------------------------------------------------------------------
# Logs wall-clock time, method, path, and status for every request so we can
# spot slow endpoints in production logs.
_PERF_SLOW_REQUEST_THRESHOLD = float(
__import__("os").environ.get("PERF_SLOW_REQUEST_MS", "2000")
)
class RequestPerfMiddleware(BaseHTTPMiddleware):
"""Middleware that logs per-request wall-clock time.
- ALL requests are logged at DEBUG level.
- Requests exceeding PERF_SLOW_REQUEST_MS (default 2000ms) are logged at
WARNING level with a system snapshot so we can correlate slow responses
with CPU/memory usage at that moment.
"""
async def dispatch(
self, request: StarletteRequest, call_next: RequestResponseEndpoint
) -> StarletteResponse:
perf = get_perf_logger()
t0 = time.perf_counter()
response = await call_next(request)
elapsed_ms = (time.perf_counter() - t0) * 1000
path = request.url.path
method = request.method
status = response.status_code
perf.debug(
"[request] %s %s -> %d in %.1fms",
method,
path,
status,
elapsed_ms,
)
if elapsed_ms > _PERF_SLOW_REQUEST_THRESHOLD:
perf.warning(
"[SLOW_REQUEST] %s %s -> %d in %.1fms (threshold=%.0fms)",
method,
path,
status,
elapsed_ms,
_PERF_SLOW_REQUEST_THRESHOLD,
)
log_system_snapshot("slow_request")
return response
app.add_middleware(RequestPerfMiddleware)
# Add SlowAPI middleware for automatic rate limiting
# Uses Starlette BaseHTTPMiddleware (not the raw ASGI variant) to avoid
# corrupting StreamingResponse — SlowAPIASGIMiddleware re-sends

View file

@ -1,7 +1,9 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from enum import StrEnum
import anyio
from fastapi import Depends
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
from pgvector.sqlalchemy import Vector
@ -1856,10 +1858,37 @@ class RefreshToken(Base, TimestampMixin):
return not self.is_expired and not self.is_revoked
engine = create_async_engine(DATABASE_URL)
engine = create_async_engine(
DATABASE_URL,
pool_size=30,
max_overflow=150,
pool_recycle=1800,
pool_pre_ping=True,
pool_timeout=30,
)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
@asynccontextmanager
async def shielded_async_session():
"""Cancellation-safe async session context manager.
Starlette's BaseHTTPMiddleware cancels the task via an anyio cancel
scope when a client disconnects. A plain ``async with async_session_maker()``
has its ``__aexit__`` (which awaits ``session.close()``) cancelled by the
scope, orphaning the underlying database connection.
This wrapper ensures ``session.close()`` always completes by running it
inside ``anyio.CancelScope(shield=True)``.
"""
session = async_session_maker()
try:
yield session
finally:
with anyio.CancelScope(shield=True):
await session.close()
async def setup_indexes():
async with engine.begin() as conn:
# Create indexes

View file

@ -1,4 +1,5 @@
import contextlib
import time
from datetime import UTC, datetime
from sqlalchemy import delete, select
@ -44,6 +45,7 @@ from app.indexing_pipeline.pipeline_logger import (
log_retryable_llm_error,
log_unexpected_error,
)
from app.utils.perf import get_perf_logger
class IndexingPipelineService:
@ -58,6 +60,9 @@ class IndexingPipelineService:
"""
Persist new documents and detect changes, returning only those that need indexing.
"""
perf = get_perf_logger()
t0 = time.perf_counter()
documents = []
seen_hashes: set[str] = set()
batch_ctx = PipelineLogContext(
@ -140,11 +145,14 @@ class IndexingPipelineService:
try:
await self.session.commit()
perf.info(
"[indexing] prepare_for_indexing in %.3fs input=%d output=%d",
time.perf_counter() - t0,
len(connector_docs),
len(documents),
)
return documents
except IntegrityError:
# A concurrent worker committed a document with the same content_hash
# or unique_identifier_hash between our check and our INSERT.
# The document already exists — roll back and let the next sync run handle it.
log_race_condition(batch_ctx)
await self.session.rollback()
return []
@ -165,26 +173,41 @@ class IndexingPipelineService:
unique_id=connector_doc.unique_id,
doc_id=document.id,
)
perf = get_perf_logger()
t_index = time.perf_counter()
try:
log_index_started(ctx)
document.status = DocumentStatus.processing()
await self.session.commit()
t_step = time.perf_counter()
if connector_doc.should_summarize and llm is not None:
content = await summarize_document(
connector_doc.source_markdown, llm, connector_doc.metadata
)
perf.info(
"[indexing] summarize_document doc=%d in %.3fs",
document.id,
time.perf_counter() - t_step,
)
elif connector_doc.should_summarize and connector_doc.fallback_summary:
content = connector_doc.fallback_summary
else:
content = connector_doc.source_markdown
t_step = time.perf_counter()
embedding = embed_text(content)
perf.debug(
"[indexing] embed_text (summary) doc=%d in %.3fs",
document.id,
time.perf_counter() - t_step,
)
await self.session.execute(
delete(Chunk).where(Chunk.document_id == document.id)
)
t_step = time.perf_counter()
chunks = [
Chunk(content=text, embedding=embed_text(text))
for text in chunk_text(
@ -192,6 +215,12 @@ class IndexingPipelineService:
use_code_chunker=connector_doc.should_use_code_chunker,
)
]
perf.info(
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
document.id,
len(chunks),
time.perf_counter() - t_step,
)
document.content = content
document.embedding = embedding
@ -199,6 +228,12 @@ class IndexingPipelineService:
document.updated_at = datetime.now(UTC)
document.status = DocumentStatus.ready()
await self.session.commit()
perf.info(
"[indexing] index TOTAL doc=%d chunks=%d in %.3fs",
document.id,
len(chunks),
time.perf_counter() - t_index,
)
log_index_success(ctx, chunk_count=len(chunks))
except RETRYABLE_LLM_ERRORS as e:

View file

@ -1,5 +1,10 @@
import time
from datetime import datetime
from app.utils.perf import get_perf_logger
_MAX_FETCH_CHUNKS_PER_DOC = 30
class ChucksHybridSearchRetriever:
def __init__(self, db_session):
@ -38,9 +43,17 @@ class ChucksHybridSearchRetriever:
from app.config import config
from app.db import Chunk, Document
perf = get_perf_logger()
t0 = time.perf_counter()
# Get embedding for the query
embedding_model = config.embedding_model_instance
t_embed = time.perf_counter()
query_embedding = embedding_model.embed(query_text)
perf.debug(
"[chunk_search] vector_search embedding in %.3fs",
time.perf_counter() - t_embed,
)
# Build the query filtered by search space
query = (
@ -60,8 +73,16 @@ class ChucksHybridSearchRetriever:
query = query.order_by(Chunk.embedding.op("<=>")(query_embedding)).limit(top_k)
# Execute the query
t_db = time.perf_counter()
result = await self.db_session.execute(query)
chunks = result.scalars().all()
perf.info(
"[chunk_search] vector_search DB query in %.3fs results=%d (total %.3fs) space=%d",
time.perf_counter() - t_db,
len(chunks),
time.perf_counter() - t0,
search_space_id,
)
return chunks
@ -91,6 +112,9 @@ class ChucksHybridSearchRetriever:
from app.db import Chunk, Document
perf = get_perf_logger()
t0 = time.perf_counter()
# Create tsvector and tsquery for PostgreSQL full-text search
tsvector = func.to_tsvector("english", Chunk.content)
tsquery = func.plainto_tsquery("english", query_text)
@ -118,6 +142,12 @@ class ChucksHybridSearchRetriever:
# Execute the query
result = await self.db_session.execute(query)
chunks = result.scalars().all()
perf.info(
"[chunk_search] full_text_search in %.3fs results=%d space=%d",
time.perf_counter() - t0,
len(chunks),
search_space_id,
)
return chunks
@ -129,6 +159,7 @@ class ChucksHybridSearchRetriever:
document_type: str | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
query_embedding: list | None = None,
) -> list:
"""
Hybrid search that returns **documents** (not individual chunks).
@ -143,6 +174,7 @@ class ChucksHybridSearchRetriever:
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
start_date: Optional start date for filtering documents by updated_at
end_date: Optional end date for filtering documents by updated_at
query_embedding: Pre-computed embedding vector. If None, will be computed here.
Returns:
List of dictionaries containing document data and relevance scores. Each dict contains:
@ -157,9 +189,17 @@ class ChucksHybridSearchRetriever:
from app.config import config
from app.db import Chunk, Document, DocumentType
# Get embedding for the query
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
perf = get_perf_logger()
t0 = time.perf_counter()
if query_embedding is None:
embedding_model = config.embedding_model_instance
t_embed = time.perf_counter()
query_embedding = embedding_model.embed(query_text)
perf.debug(
"[chunk_search] hybrid_search embedding in %.3fs",
time.perf_counter() - t_embed,
)
# RRF constants
k = 60
@ -254,9 +294,17 @@ class ChucksHybridSearchRetriever:
.limit(top_k)
)
# Execute the query
# Execute the RRF query
t_rrf = time.perf_counter()
result = await self.db_session.execute(final_query)
chunks_with_scores = result.all()
perf.info(
"[chunk_search] hybrid_search RRF query in %.3fs results=%d space=%d type=%s",
time.perf_counter() - t_rrf,
len(chunks_with_scores),
search_space_id,
document_type,
)
# If no results were found, return an empty list
if not chunks_with_scores:
@ -300,8 +348,9 @@ class ChucksHybridSearchRetriever:
if not doc_ids:
return []
# Fetch ALL chunks for selected documents in a single query so the final prompt can cite
# any chunk from those documents.
# Fetch chunks for selected documents. We cap per document to avoid
# loading hundreds of chunks for a single large file while still
# ensuring the chunks that matched the RRF query are always included.
chunk_query = (
select(Chunk)
.options(joinedload(Chunk.document))
@ -311,7 +360,20 @@ class ChucksHybridSearchRetriever:
.order_by(Chunk.document_id, Chunk.id)
)
chunks_result = await self.db_session.execute(chunk_query)
all_chunks = chunks_result.scalars().all()
raw_chunks = chunks_result.scalars().all()
matched_chunk_ids: set[int] = {
item["chunk_id"] for item in serialized_chunk_results
}
doc_chunk_counts: dict[int, int] = {}
all_chunks: list = []
for chunk in raw_chunks:
did = chunk.document_id
count = doc_chunk_counts.get(did, 0)
if chunk.id in matched_chunk_ids or count < _MAX_FETCH_CHUNKS_PER_DOC:
all_chunks.append(chunk)
doc_chunk_counts[did] = count + 1
# Assemble final doc-grouped results in the same order as doc_ids
doc_map: dict[int, dict] = {
@ -354,4 +416,11 @@ class ChucksHybridSearchRetriever:
)
final_docs.append(entry)
perf.info(
"[chunk_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
time.perf_counter() - t0,
len(final_docs),
search_space_id,
document_type,
)
return final_docs

View file

@ -1,5 +1,10 @@
import time
from datetime import datetime
from app.utils.perf import get_perf_logger
_MAX_FETCH_CHUNKS_PER_DOC = 30
class DocumentHybridSearchRetriever:
def __init__(self, db_session):
@ -38,6 +43,9 @@ class DocumentHybridSearchRetriever:
from app.config import config
from app.db import Document
perf = get_perf_logger()
t0 = time.perf_counter()
# Get embedding for the query
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
@ -63,6 +71,12 @@ class DocumentHybridSearchRetriever:
# Execute the query
result = await self.db_session.execute(query)
documents = result.scalars().all()
perf.info(
"[doc_search] vector_search in %.3fs results=%d space=%d",
time.perf_counter() - t0,
len(documents),
search_space_id,
)
return documents
@ -92,6 +106,9 @@ class DocumentHybridSearchRetriever:
from app.db import Document
perf = get_perf_logger()
t0 = time.perf_counter()
# Create tsvector and tsquery for PostgreSQL full-text search
tsvector = func.to_tsvector("english", Document.content)
tsquery = func.plainto_tsquery("english", query_text)
@ -118,6 +135,12 @@ class DocumentHybridSearchRetriever:
# Execute the query
result = await self.db_session.execute(query)
documents = result.scalars().all()
perf.info(
"[doc_search] full_text_search in %.3fs results=%d space=%d",
time.perf_counter() - t0,
len(documents),
search_space_id,
)
return documents
@ -129,6 +152,7 @@ class DocumentHybridSearchRetriever:
document_type: str | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
query_embedding: list | None = None,
) -> list:
"""
Hybrid search that returns **documents** (not individual chunks).
@ -143,7 +167,7 @@ class DocumentHybridSearchRetriever:
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
start_date: Optional start date for filtering documents by updated_at
end_date: Optional end date for filtering documents by updated_at
query_embedding: Pre-computed embedding vector. If None, will be computed here.
"""
from sqlalchemy import func, select, text
from sqlalchemy.orm import joinedload
@ -151,9 +175,12 @@ class DocumentHybridSearchRetriever:
from app.config import config
from app.db import Chunk, Document, DocumentType
# Get embedding for the query
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
perf = get_perf_logger()
t0 = time.perf_counter()
if query_embedding is None:
embedding_model = config.embedding_model_instance
query_embedding = embedding_model.embed(query_text)
# RRF constants
k = 60
@ -254,7 +281,8 @@ class DocumentHybridSearchRetriever:
# Collect document IDs for chunk fetching
doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores]
# Fetch ALL chunks for these documents in a single query
# Fetch chunks for these documents, capped per document to avoid
# loading hundreds of chunks for a single large file.
chunks_query = (
select(Chunk)
.options(joinedload(Chunk.document))
@ -262,7 +290,16 @@ class DocumentHybridSearchRetriever:
.order_by(Chunk.document_id, Chunk.id)
)
chunks_result = await self.db_session.execute(chunks_query)
chunks = chunks_result.scalars().all()
raw_chunks = chunks_result.scalars().all()
doc_chunk_counts: dict[int, int] = {}
chunks: list = []
for chunk in raw_chunks:
did = chunk.document_id
count = doc_chunk_counts.get(did, 0)
if count < _MAX_FETCH_CHUNKS_PER_DOC:
chunks.append(chunk)
doc_chunk_counts[did] = count + 1
# Assemble doc-grouped results
doc_map: dict[int, dict] = {
@ -303,4 +340,11 @@ class DocumentHybridSearchRetriever:
)
final_docs.append(entry)
perf.info(
"[doc_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
time.perf_counter() - t0,
len(final_docs),
search_space_id,
document_type,
)
return final_docs

View file

@ -7,6 +7,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, get_async_session
from app.schemas.chat_comments import (
CommentBatchRequest,
CommentBatchResponse,
CommentCreateRequest,
CommentListResponse,
CommentReplyResponse,
@ -19,6 +21,7 @@ from app.services.chat_comments_service import (
create_reply,
delete_comment,
get_comments_for_message,
get_comments_for_messages_batch,
get_user_mentions,
update_comment,
)
@ -27,6 +30,16 @@ from app.users import current_active_user
router = APIRouter()
@router.post("/messages/comments/batch", response_model=CommentBatchResponse)
async def batch_list_comments(
request: CommentBatchRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Batch-fetch comments for multiple messages in one request."""
return await get_comments_for_messages_batch(session, request.message_ids, user)
@router.get("/messages/{message_id}/comments", response_model=CommentListResponse)
async def list_comments(
message_id: int,

View file

@ -133,6 +133,8 @@ async def create_documents_file_upload(
Requires DOCUMENTS_CREATE permission.
"""
import os
import tempfile
from datetime import datetime
from app.db import DocumentStatus
@ -143,7 +145,6 @@ async def create_documents_file_upload(
from app.utils.document_converters import generate_unique_identifier_hash
try:
# Check permission
await check_permission(
session,
user,
@ -179,69 +180,64 @@ async def create_documents_file_upload(
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
)
created_documents: list[Document] = []
files_to_process: list[
tuple[Document, str, str]
] = [] # (document, temp_path, filename)
skipped_duplicates = 0
duplicate_document_ids: list[int] = []
actual_total_size = 0
# ===== Read all files concurrently to avoid blocking the event loop =====
async def _read_and_save(file: UploadFile) -> tuple[str, str, int]:
"""Read upload content and write to temp file off the event loop."""
content = await file.read()
file_size = len(content)
filename = file.filename or "unknown"
# ===== PHASE 1: Create pending documents for all files =====
# This makes ALL documents visible in the UI immediately with pending status
for file in files:
try:
import os
import tempfile
# Save file to temp location
with tempfile.NamedTemporaryFile(
delete=False, suffix=os.path.splitext(file.filename or "")[1]
) as temp_file:
temp_path = temp_file.name
content = await file.read()
file_size = len(content)
if file_size > MAX_FILE_SIZE_BYTES:
os.unlink(temp_path)
raise HTTPException(
status_code=413,
detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
)
actual_total_size += file_size
if actual_total_size > MAX_TOTAL_SIZE_BYTES:
os.unlink(temp_path)
raise HTTPException(
status_code=413,
detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
)
with open(temp_path, "wb") as f:
f.write(content)
# Generate unique identifier for deduplication check
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.FILE, file.filename or "unknown", search_space_id
if file_size > MAX_FILE_SIZE_BYTES:
raise HTTPException(
status_code=413,
detail=f"File '{filename}' ({file_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
)
def _write_temp() -> str:
with tempfile.NamedTemporaryFile(
delete=False, suffix=os.path.splitext(filename)[1]
) as tmp:
tmp.write(content)
return tmp.name
temp_path = await asyncio.to_thread(_write_temp)
return temp_path, filename, file_size
saved_files = await asyncio.gather(*(_read_and_save(f) for f in files))
actual_total_size = sum(size for _, _, size in saved_files)
if actual_total_size > MAX_TOTAL_SIZE_BYTES:
for temp_path, _, _ in saved_files:
os.unlink(temp_path)
raise HTTPException(
status_code=413,
detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) "
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
)
# ===== PHASE 1: Create pending documents for all files =====
created_documents: list[Document] = []
files_to_process: list[tuple[Document, str, str]] = []
skipped_duplicates = 0
duplicate_document_ids: list[int] = []
for temp_path, filename, file_size in saved_files:
try:
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.FILE, filename, search_space_id
)
# Check if document already exists (by unique identifier)
existing = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
if existing:
if DocumentStatus.is_state(existing.status, DocumentStatus.READY):
# True duplicate — content already indexed, skip
os.unlink(temp_path)
skipped_duplicates += 1
duplicate_document_ids.append(existing.id)
continue
# Existing document is stuck (failed/pending/processing)
# Reset it to pending and re-dispatch for processing
existing.status = DocumentStatus.pending()
existing.content = "Processing..."
existing.document_metadata = {
@ -251,50 +247,45 @@ async def create_documents_file_upload(
}
existing.updated_at = get_current_timestamp()
created_documents.append(existing)
files_to_process.append(
(existing, temp_path, file.filename or "unknown")
)
files_to_process.append((existing, temp_path, filename))
continue
# Create pending document (visible immediately in UI via ElectricSQL)
document = Document(
search_space_id=search_space_id,
title=file.filename or "Uploaded File",
title=filename if filename != "unknown" else "Uploaded File",
document_type=DocumentType.FILE,
document_metadata={
"FILE_NAME": file.filename,
"FILE_NAME": filename,
"file_size": file_size,
"upload_time": datetime.now().isoformat(),
},
content="Processing...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary, updated when ready
content="Processing...",
content_hash=unique_identifier_hash,
unique_identifier_hash=unique_identifier_hash,
embedding=None,
status=DocumentStatus.pending(), # Shows "pending" in UI
status=DocumentStatus.pending(),
updated_at=get_current_timestamp(),
created_by_id=str(user.id),
)
session.add(document)
created_documents.append(document)
files_to_process.append(
(document, temp_path, file.filename or "unknown")
)
files_to_process.append((document, temp_path, filename))
except HTTPException:
raise
except Exception as e:
os.unlink(temp_path)
raise HTTPException(
status_code=422,
detail=f"Failed to process file {file.filename}: {e!s}",
detail=f"Failed to process file {filename}: {e!s}",
) from e
# Commit all pending documents - they appear in UI immediately via ElectricSQL
if created_documents:
await session.commit()
# Refresh to get generated IDs
for doc in created_documents:
await session.refresh(doc)
# ===== PHASE 2: Dispatch tasks for each file =====
# Each task will update document status: pending → processing → ready/failed
for document, temp_path, filename in files_to_process:
await dispatcher.dispatch_file_processing(
document_id=document.id,

View file

@ -32,6 +32,7 @@ from app.db import (
SearchSpace,
User,
get_async_session,
shielded_async_session,
)
from app.schemas.new_chat import (
NewChatMessageAppend,
@ -1092,13 +1093,18 @@ async def handle_new_chat(
# on searchspaces/documents for the entire duration of the stream.
# expire_on_commit=False keeps loaded ORM attrs usable.
await session.commit()
# Close the dependency session now so its connection returns to
# the pool before streaming begins. Without this, Starlette's
# BaseHTTPMiddleware cancels the scope on client disconnect and
# the dependency generator's __aexit__ never runs, orphaning the
# connection (the "Exception terminating connection" errors).
await session.close()
return StreamingResponse(
stream_new_chat(
user_query=request.user_query,
search_space_id=request.search_space_id,
chat_id=request.chat_id,
session=session,
user_id=str(user.id),
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
@ -1323,6 +1329,7 @@ async def regenerate_response(
# on searchspaces/documents for the entire duration of the stream.
# expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
await session.commit()
await session.close()
# Create a wrapper generator that deletes messages only AFTER streaming succeeds
# This prevents data loss if streaming fails (network error, LLM error, etc.)
@ -1333,7 +1340,6 @@ async def regenerate_response(
user_query=user_query_to_use,
search_space_id=request.search_space_id,
chat_id=thread_id,
session=session,
user_id=str(user.id),
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
@ -1344,29 +1350,35 @@ async def regenerate_response(
current_user_display_name=user.display_name or "A team member",
):
yield chunk
# If we get here, streaming completed successfully
streaming_completed = True
finally:
# Only delete old messages if streaming completed successfully
# This ensures we don't lose data on streaming failures
if streaming_completed and messages_to_delete:
# Only delete old messages if streaming completed successfully.
# Uses a fresh session since stream_new_chat manages its own.
if streaming_completed and message_ids_to_delete:
try:
for msg in messages_to_delete:
await session.delete(msg)
await session.commit()
async with shielded_async_session() as cleanup_session:
for msg_id in message_ids_to_delete:
_res = await cleanup_session.execute(
select(NewChatMessage).filter(
NewChatMessage.id == msg_id
)
)
_msg = _res.scalars().first()
if _msg:
await cleanup_session.delete(_msg)
await cleanup_session.commit()
# Delete any public snapshots that contain the modified messages
from app.services.public_chat_service import (
delete_affected_snapshots,
)
from app.services.public_chat_service import (
delete_affected_snapshots,
)
await delete_affected_snapshots(
session, thread_id, message_ids_to_delete
)
await delete_affected_snapshots(
cleanup_session, thread_id, message_ids_to_delete
)
except Exception as cleanup_error:
# Log but don't fail - the new messages are already streamed
print(
f"[regenerate] Warning: Failed to delete old messages: {cleanup_error}"
_logger.warning(
"[regenerate] Failed to delete old messages: %s",
cleanup_error,
)
# Return streaming response with checkpoint_id for rewinding
@ -1440,13 +1452,13 @@ async def resume_chat(
# Release the read-transaction so we don't hold ACCESS SHARE locks
# on searchspaces/documents for the entire duration of the stream.
await session.commit()
await session.close()
return StreamingResponse(
stream_resume_chat(
chat_id=thread_id,
search_space_id=request.search_space_id,
decisions=decisions,
session=session,
user_id=str(user.id),
llm_config_id=llm_config_id,
thread_visibility=thread.visibility,

View file

@ -87,6 +87,18 @@ class CommentListResponse(BaseModel):
total_count: int
class CommentBatchRequest(BaseModel):
"""Request for batch-fetching comments for multiple messages."""
message_ids: list[int] = Field(..., min_length=1, max_length=200)
class CommentBatchResponse(BaseModel):
"""Batch response keyed by message_id."""
comments_by_message: dict[int, CommentListResponse]
# =============================================================================
# Mention Schemas
# =============================================================================

View file

@ -22,6 +22,7 @@ from app.db import (
)
from app.schemas.chat_comments import (
AuthorResponse,
CommentBatchResponse,
CommentListResponse,
CommentReplyResponse,
CommentResponse,
@ -264,6 +265,146 @@ async def get_comments_for_message(
)
async def get_comments_for_messages_batch(
session: AsyncSession,
message_ids: list[int],
user: User,
) -> CommentBatchResponse:
"""
Batch-fetch comments for multiple messages in a single DB round-trip.
Validates that all messages exist and belong to search spaces the user
can read comments in, then loads all comments with eager-loaded authors
and replies.
"""
if not message_ids:
return CommentBatchResponse(comments_by_message={})
unique_ids = list(set(message_ids))
result = await session.execute(
select(NewChatMessage)
.options(selectinload(NewChatMessage.thread))
.filter(NewChatMessage.id.in_(unique_ids))
)
messages = result.scalars().all()
msg_map = {m.id: m for m in messages}
search_space_ids = {m.thread.search_space_id for m in messages}
permissions_cache: dict[int, set] = {}
for ss_id in search_space_ids:
await check_permission(
session,
user,
ss_id,
Permission.COMMENTS_READ.value,
"You don't have permission to read comments in this search space",
)
permissions_cache[ss_id] = await get_user_permissions(session, user.id, ss_id)
result = await session.execute(
select(ChatComment)
.options(
selectinload(ChatComment.author),
selectinload(ChatComment.replies).selectinload(ChatComment.author),
)
.filter(
ChatComment.message_id.in_(unique_ids),
ChatComment.parent_id.is_(None),
)
.order_by(ChatComment.created_at)
)
top_level_comments = result.scalars().all()
all_mentioned_uuids: set[UUID] = set()
for comment in top_level_comments:
all_mentioned_uuids.update(parse_mentions(comment.content))
for reply in comment.replies:
all_mentioned_uuids.update(parse_mentions(reply.content))
user_names = await get_user_names_for_mentions(session, all_mentioned_uuids)
comments_by_msg: dict[int, list[ChatComment]] = {mid: [] for mid in unique_ids}
for comment in top_level_comments:
comments_by_msg.setdefault(comment.message_id, []).append(comment)
comments_by_message: dict[int, CommentListResponse] = {}
for mid in unique_ids:
msg = msg_map.get(mid)
if msg is None:
comments_by_message[mid] = CommentListResponse(comments=[], total_count=0)
continue
ss_id = msg.thread.search_space_id
user_perms = permissions_cache.get(ss_id, set())
can_delete_any = has_permission(user_perms, Permission.COMMENTS_DELETE.value)
comment_responses = []
for comment in comments_by_msg.get(mid, []):
author = None
if comment.author:
author = AuthorResponse(
id=comment.author.id,
display_name=comment.author.display_name,
avatar_url=comment.author.avatar_url,
email=comment.author.email,
)
replies = []
for reply in sorted(comment.replies, key=lambda r: r.created_at):
reply_author = None
if reply.author:
reply_author = AuthorResponse(
id=reply.author.id,
display_name=reply.author.display_name,
avatar_url=reply.author.avatar_url,
email=reply.author.email,
)
is_reply_author = (
reply.author_id == user.id if reply.author_id else False
)
replies.append(
CommentReplyResponse(
id=reply.id,
content=reply.content,
content_rendered=render_mentions(reply.content, user_names),
author=reply_author,
created_at=reply.created_at,
updated_at=reply.updated_at,
is_edited=reply.updated_at > reply.created_at,
can_edit=is_reply_author,
can_delete=is_reply_author or can_delete_any,
)
)
is_comment_author = (
comment.author_id == user.id if comment.author_id else False
)
comment_responses.append(
CommentResponse(
id=comment.id,
message_id=comment.message_id,
content=comment.content,
content_rendered=render_mentions(comment.content, user_names),
author=author,
created_at=comment.created_at,
updated_at=comment.updated_at,
is_edited=comment.updated_at > comment.created_at,
can_edit=is_comment_author,
can_delete=is_comment_author or can_delete_any,
reply_count=len(replies),
replies=replies,
)
)
comments_by_message[mid] = CommentListResponse(
comments=comment_responses,
total_count=len(comment_responses),
)
return CommentBatchResponse(comments_by_message=comments_by_message)
async def create_comment(
session: AsyncSession,
message_id: int,

View file

@ -1,4 +1,5 @@
import asyncio
import time
from datetime import datetime
from typing import Any
from urllib.parse import urljoin
@ -15,9 +16,11 @@ from app.db import (
Document,
SearchSourceConnector,
SearchSourceConnectorType,
async_session_maker,
)
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.retriever.documents_hybrid_search import DocumentHybridSearchRetriever
from app.utils.perf import get_perf_logger
class ConnectorService:
@ -221,6 +224,7 @@ class ConnectorService:
top_k: int = 20,
start_date: datetime | None = None,
end_date: datetime | None = None,
query_embedding: list[float] | None = None,
) -> list[dict[str, Any]]:
"""
Perform combined search using both chunk-based and document-based hybrid search,
@ -246,34 +250,60 @@ class ConnectorService:
Returns:
List of combined and deduplicated document results
"""
from app.config import config
perf = get_perf_logger()
t0 = time.perf_counter()
# RRF constant
k = 60
# Get more results from each retriever for better fusion
retriever_top_k = top_k * 2
# IMPORTANT:
# These retrievers share the same AsyncSession. AsyncSession does not permit
# concurrent awaits that require DB IO on the same session/connection.
# Running these in parallel can raise:
# "This session is provisioning a new connection; concurrent operations are not permitted"
#
# So we run them sequentially.
chunk_results = await self.chunk_retriever.hybrid_search(
query_text=query_text,
top_k=retriever_top_k,
search_space_id=search_space_id,
document_type=document_type,
start_date=start_date,
end_date=end_date,
# Reuse caller-provided embedding or compute once for both retrievers.
if query_embedding is None:
t_embed = time.perf_counter()
query_embedding = config.embedding_model_instance.embed(query_text)
perf.info(
"[connector_svc] _combined_rrf embedding in %.3fs type=%s",
time.perf_counter() - t_embed,
document_type,
)
search_kwargs = {
"query_text": query_text,
"top_k": retriever_top_k,
"search_space_id": search_space_id,
"document_type": document_type,
"start_date": start_date,
"end_date": end_date,
"query_embedding": query_embedding,
}
# Run chunk and document retrievers in parallel using separate DB sessions
# so they don't contend on a shared AsyncSession connection.
async def _run_chunk_search() -> list[dict[str, Any]]:
async with async_session_maker() as session:
retriever = ChucksHybridSearchRetriever(session)
return await retriever.hybrid_search(**search_kwargs)
async def _run_doc_search() -> list[dict[str, Any]]:
async with async_session_maker() as session:
retriever = DocumentHybridSearchRetriever(session)
return await retriever.hybrid_search(**search_kwargs)
t_parallel = time.perf_counter()
chunk_results, doc_results = await asyncio.gather(
_run_chunk_search(), _run_doc_search()
)
doc_results = await self.document_retriever.hybrid_search(
query_text=query_text,
top_k=retriever_top_k,
search_space_id=search_space_id,
document_type=document_type,
start_date=start_date,
end_date=end_date,
perf.info(
"[connector_svc] _combined_rrf parallel retrievers in %.3fs "
"chunk_results=%d doc_results=%d type=%s",
time.perf_counter() - t_parallel,
len(chunk_results),
len(doc_results),
document_type,
)
# Helper to extract document_id from our doc-grouped result
@ -335,6 +365,13 @@ class ConnectorService:
result["chunks"] = doc_data[did]["chunks"]
combined_results.append(result)
perf.info(
"[connector_svc] _combined_rrf_search TOTAL in %.3fs results=%d type=%s space=%d",
time.perf_counter() - t0,
len(combined_results),
document_type,
search_space_id,
)
return combined_results
def _get_doc_url(self, metadata: dict[str, Any]) -> str:

View file

@ -13,8 +13,10 @@ synchronous ChatLiteLLM-like interface and async methods.
import logging
import re
import time
from typing import Any
import litellm
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.exceptions import ContextOverflowError
from langchain_core.language_models import BaseChatModel
@ -26,6 +28,11 @@ from litellm.exceptions import (
ContextWindowExceededError,
)
from app.utils.perf import get_perf_logger
litellm.json_logs = False
litellm.store_audit_logs = False
logger = logging.getLogger(__name__)
_CONTEXT_OVERFLOW_PATTERNS = re.compile(
@ -152,26 +159,95 @@ class LLMRouterService:
# Merge with provided settings
final_settings = {**default_settings, **instance._router_settings}
# Build a "auto-large" fallback group with deployments whose context
# window exceeds the smallest deployment. This lets the router
# automatically fall back to a bigger-context model when gpt-4o (128K)
# hits ContextWindowExceededError.
full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list)
try:
instance._router = Router(
model_list=model_list,
routing_strategy=final_settings.get(
router_kwargs: dict[str, Any] = {
"model_list": full_model_list,
"routing_strategy": final_settings.get(
"routing_strategy", "usage-based-routing"
),
num_retries=final_settings.get("num_retries", 3),
allowed_fails=final_settings.get("allowed_fails", 3),
cooldown_time=final_settings.get("cooldown_time", 60),
set_verbose=False, # Disable verbose logging in production
)
"num_retries": final_settings.get("num_retries", 3),
"allowed_fails": final_settings.get("allowed_fails", 3),
"cooldown_time": final_settings.get("cooldown_time", 60),
"set_verbose": False,
}
if ctx_fallbacks:
router_kwargs["context_window_fallbacks"] = ctx_fallbacks
instance._router = Router(**router_kwargs)
instance._initialized = True
logger.info(
f"LLM Router initialized with {len(model_list)} deployments, "
f"strategy: {final_settings.get('routing_strategy')}"
"LLM Router initialized with %d deployments, "
"strategy: %s, context_window_fallbacks: %s",
len(model_list),
final_settings.get("routing_strategy"),
ctx_fallbacks or "none",
)
except Exception as e:
logger.error(f"Failed to initialize LLM Router: {e}")
instance._router = None
@classmethod
def _build_context_fallback_groups(
cls, model_list: list[dict]
) -> tuple[list[dict], list[dict[str, list[str]]] | None]:
"""Create an ``auto-large`` model group for context-window fallbacks.
Uses ``litellm.get_model_info`` to discover the context window of each
deployment. Deployments whose ``max_input_tokens`` exceeds the smallest
window are duplicated into an ``auto-large`` group. The returned
fallback config tells the Router: on ``ContextWindowExceededError`` for
``auto``, retry with ``auto-large``.
Returns:
(full_model_list, context_window_fallbacks) ``full_model_list``
contains the original entries plus any ``auto-large`` duplicates.
``context_window_fallbacks`` is ``None`` when every deployment has
the same context size (no useful fallback).
"""
from litellm import get_model_info
ctx_map: dict[str, int] = {}
for dep in model_list:
params = dep.get("litellm_params", {})
base_model = params.get("base_model") or params.get("model", "")
try:
info = get_model_info(base_model)
ctx = info.get("max_input_tokens")
if isinstance(ctx, int) and ctx > 0:
ctx_map[base_model] = ctx
except Exception:
continue
if not ctx_map:
return model_list, None
min_ctx = min(ctx_map.values())
large_deployments: list[dict] = []
for dep in model_list:
params = dep.get("litellm_params", {})
base_model = params.get("base_model") or params.get("model", "")
if ctx_map.get(base_model, 0) > min_ctx:
dup = {**dep, "model_name": "auto-large"}
large_deployments.append(dup)
if not large_deployments:
return model_list, None
logger.info(
"Context-window fallback: %d large-context deployments "
"(min_ctx=%d) added to 'auto-large' group",
len(large_deployments),
min_ctx,
)
return model_list + large_deployments, [{"auto": ["auto-large"]}]
@classmethod
def _config_to_deployment(cls, config: dict) -> dict | None:
"""
@ -247,6 +323,48 @@ class LLMRouterService:
return len(instance._model_list)
_cached_context_profile: dict | None = None
_cached_context_profile_computed: bool = False
# Cached singleton instances keyed by (streaming,) to avoid re-creating on every call
_router_instance_cache: dict[bool, "ChatLiteLLMRouter"] = {}
def _get_cached_context_profile(router: Router) -> dict | None:
"""Compute and cache the min context profile across all router deployments.
Called once on first ChatLiteLLMRouter creation; subsequent calls return
the cached value. This avoids calling litellm.get_model_info() for every
deployment on every request.
"""
global _cached_context_profile, _cached_context_profile_computed
if _cached_context_profile_computed:
return _cached_context_profile
from litellm import get_model_info
min_ctx: int | None = None
for deployment in router.model_list:
params = deployment.get("litellm_params", {})
base_model = params.get("base_model") or params.get("model", "")
try:
info = get_model_info(base_model)
ctx = info.get("max_input_tokens")
if isinstance(ctx, int) and ctx > 0 and (min_ctx is None or ctx < min_ctx):
min_ctx = ctx
except Exception:
continue
if min_ctx is not None:
logger.info("ChatLiteLLMRouter profile: max_input_tokens=%d", min_ctx)
_cached_context_profile = {"max_input_tokens": min_ctx}
else:
_cached_context_profile = None
_cached_context_profile_computed = True
return _cached_context_profile
class ChatLiteLLMRouter(BaseChatModel):
"""
A LangChain-compatible chat model that uses LiteLLM Router for load balancing.
@ -257,6 +375,10 @@ class ChatLiteLLMRouter(BaseChatModel):
Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context
window across all router deployments so that deepagents
SummarizationMiddleware can use fraction-based triggers.
**Singleton-ish**: Use ``get_auto_mode_llm()`` or call ``ChatLiteLLMRouter()``
directly instances without bound tools are cached per streaming flag to
avoid per-request re-initialization overhead and memory growth.
"""
# Use model_config for Pydantic v2 compatibility
@ -278,14 +400,6 @@ class ChatLiteLLMRouter(BaseChatModel):
tool_choice: str | dict | None = None,
**kwargs,
):
"""
Initialize the ChatLiteLLMRouter.
Args:
router: LiteLLM Router instance. If None, uses the global singleton.
bound_tools: Pre-bound tools for tool calling
tool_choice: Tool choice configuration
"""
try:
super().__init__(**kwargs)
resolved_router = router or LLMRouterService.get_router()
@ -297,51 +411,20 @@ class ChatLiteLLMRouter(BaseChatModel):
"LLM Router not initialized. Call LLMRouterService.initialize() first."
)
# Set profile so deepagents SummarizationMiddleware gets fraction-based triggers
computed_profile = self._compute_min_context_profile()
computed_profile = _get_cached_context_profile(self._router)
if computed_profile is not None:
object.__setattr__(self, "profile", computed_profile)
logger.info(
f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models"
logger.debug(
"ChatLiteLLMRouter ready (models=%d, streaming=%s, has_tools=%s)",
LLMRouterService.get_model_count(),
self.streaming,
bound_tools is not None,
)
except Exception as e:
logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
raise
def _compute_min_context_profile(self) -> dict | None:
"""Derive a profile dict with max_input_tokens from router deployments.
Uses litellm.get_model_info to look up each deployment's context window
and picks the *minimum* so that summarization triggers before ANY model
in the pool overflows.
"""
from litellm import get_model_info
if not self._router:
return None
min_ctx: int | None = None
for deployment in self._router.model_list:
params = deployment.get("litellm_params", {})
base_model = params.get("base_model") or params.get("model", "")
try:
info = get_model_info(base_model)
ctx = info.get("max_input_tokens")
if (
isinstance(ctx, int)
and ctx > 0
and (min_ctx is None or ctx < min_ctx)
):
min_ctx = ctx
except Exception:
continue
if min_ctx is not None:
logger.info(f"ChatLiteLLMRouter profile: max_input_tokens={min_ctx}")
return {"max_input_tokens": min_ctx}
return None
@property
def _llm_type(self) -> str:
return "litellm-router"
@ -410,6 +493,10 @@ class ChatLiteLLMRouter(BaseChatModel):
if not self._router:
raise ValueError("Router not initialized")
perf = get_perf_logger()
t0 = time.perf_counter()
msg_count = len(messages)
# Convert LangChain messages to OpenAI format
formatted_messages = self._convert_messages(messages)
@ -428,12 +515,30 @@ class ChatLiteLLMRouter(BaseChatModel):
**call_kwargs,
)
except ContextWindowExceededError as e:
perf.warning(
"[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
perf.warning(
"[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
raise
elapsed = time.perf_counter() - t0
perf.info(
"[llm_router] _generate completed msgs=%d tools=%d in %.3fs",
msg_count,
len(self._bound_tools) if self._bound_tools else 0,
elapsed,
)
# Convert response to ChatResult with potential tool calls
message = self._convert_response_to_message(response.choices[0].message)
generation = ChatGeneration(message=message)
@ -453,6 +558,10 @@ class ChatLiteLLMRouter(BaseChatModel):
if not self._router:
raise ValueError("Router not initialized")
perf = get_perf_logger()
t0 = time.perf_counter()
msg_count = len(messages)
# Convert LangChain messages to OpenAI format
formatted_messages = self._convert_messages(messages)
@ -471,12 +580,30 @@ class ChatLiteLLMRouter(BaseChatModel):
**call_kwargs,
)
except ContextWindowExceededError as e:
perf.warning(
"[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
perf.warning(
"[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
raise
elapsed = time.perf_counter() - t0
perf.info(
"[llm_router] _agenerate completed msgs=%d tools=%d in %.3fs",
msg_count,
len(self._bound_tools) if self._bound_tools else 0,
elapsed,
)
# Convert response to ChatResult with potential tool calls
message = self._convert_response_to_message(response.choices[0].message)
generation = ChatGeneration(message=message)
@ -541,6 +668,10 @@ class ChatLiteLLMRouter(BaseChatModel):
if not self._router:
raise ValueError("Router not initialized")
perf = get_perf_logger()
t0 = time.perf_counter()
msg_count = len(messages)
formatted_messages = self._convert_messages(messages)
# Add tools if bound
@ -559,20 +690,52 @@ class ChatLiteLLMRouter(BaseChatModel):
**call_kwargs,
)
except ContextWindowExceededError as e:
perf.warning(
"[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
except LiteLLMBadRequestError as e:
if _is_context_overflow_error(e):
perf.warning(
"[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
msg_count,
time.perf_counter() - t0,
)
raise ContextOverflowError(str(e)) from e
raise
# Yield chunks asynchronously
t_first_chunk = time.perf_counter()
perf.info(
"[llm_router] _astream connection established msgs=%d in %.3fs",
msg_count,
t_first_chunk - t0,
)
chunk_count = 0
first_chunk_logged = False
async for chunk in response:
if hasattr(chunk, "choices") and chunk.choices:
delta = chunk.choices[0].delta
chunk_msg = self._convert_delta_to_chunk(delta)
if chunk_msg:
chunk_count += 1
if not first_chunk_logged:
perf.info(
"[llm_router] _astream first chunk in %.3fs (total %.3fs from start)",
time.perf_counter() - t_first_chunk,
time.perf_counter() - t0,
)
first_chunk_logged = True
yield ChatGenerationChunk(message=chunk_msg)
perf.info(
"[llm_router] _astream completed chunks=%d total=%.3fs",
chunk_count,
time.perf_counter() - t0,
)
def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]:
"""Convert LangChain messages to OpenAI format."""
from langchain_core.messages import (
@ -687,19 +850,28 @@ class ChatLiteLLMRouter(BaseChatModel):
return None
def get_auto_mode_llm() -> ChatLiteLLMRouter | None:
"""
Get a ChatLiteLLMRouter instance for auto mode.
def get_auto_mode_llm(
*,
streaming: bool = True,
) -> ChatLiteLLMRouter | None:
"""Return a cached ChatLiteLLMRouter for auto mode.
Returns:
ChatLiteLLMRouter instance or None if router not initialized
Base (no tools) instances are cached per ``streaming`` flag so we
avoid re-constructing them on every request. ``bind_tools()`` still
returns a fresh instance because bound tools differ per agent.
"""
if not LLMRouterService.is_initialized():
logger.warning("LLM Router not initialized for auto mode")
return None
cached = _router_instance_cache.get(streaming)
if cached is not None:
return cached
try:
return ChatLiteLLMRouter()
instance = ChatLiteLLMRouter(streaming=streaming)
_router_instance_cache[streaming] = instance
return instance
except Exception as e:
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
return None

View file

@ -12,12 +12,20 @@ from app.services.llm_router_service import (
AUTO_MODE_ID,
ChatLiteLLMRouter,
LLMRouterService,
get_auto_mode_llm,
is_auto_mode,
)
# Configure litellm to automatically drop unsupported parameters
litellm.drop_params = True
# Memory controls: prevent unbounded internal accumulation
litellm.telemetry = False
litellm.cache = None
litellm.success_callback = []
litellm.failure_callback = []
litellm.input_callback = []
logger = logging.getLogger(__name__)
@ -221,7 +229,7 @@ async def get_search_space_llm_instance(
logger.debug(
f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}"
)
return ChatLiteLLMRouter(disable_streaming=disable_streaming)
return get_auto_mode_llm(streaming=not disable_streaming)
except Exception as e:
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
return None

View file

@ -1 +1,28 @@
"""Celery tasks package."""
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.config import config
_celery_engine = None
_celery_session_maker = None
def get_celery_session_maker() -> async_sessionmaker:
"""Return a shared async session maker for Celery tasks.
A single NullPool engine is created per worker process and reused
across all task invocations to avoid leaking engine objects.
"""
global _celery_engine, _celery_session_maker
if _celery_session_maker is None:
_celery_engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool,
echo=False,
)
_celery_session_maker = async_sessionmaker(
_celery_engine, expire_on_commit=False
)
return _celery_session_maker

View file

@ -3,11 +3,8 @@
import logging
import traceback
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
@ -42,20 +39,6 @@ def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> N
)
def get_celery_session_maker():
"""
Create a new async session maker for Celery tasks.
This is necessary because Celery tasks run in a new event loop,
and the default session maker is bound to the main app's event loop.
"""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool, # Don't use connection pooling for Celery tasks
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="index_slack_messages", bind=True)
def index_slack_messages_task(
self,

View file

@ -4,30 +4,18 @@ import logging
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import selectinload
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.db import Document
from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
def get_celery_session_maker():
"""Create async session maker for Celery tasks."""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool,
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="reindex_document", bind=True)
def reindex_document_task(self, document_id: int, user_id: str):
"""

View file

@ -5,13 +5,11 @@ import logging
import os
from uuid import UUID
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.services.notification_service import NotificationService
from app.services.task_logging_service import TaskLoggingService
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.document_processors import (
add_extension_received_document,
add_youtube_video_document,
@ -91,20 +89,6 @@ async def _run_heartbeat_loop(notification_id: int):
pass # Normal cancellation when task completes
def get_celery_session_maker():
"""
Create a new async session maker for Celery tasks.
This is necessary because Celery tasks run in a new event loop,
and the default session maker is bound to the main app's event loop.
"""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool, # Don't use connection pooling for Celery tasks
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="process_extension_document", bind=True)
def process_extension_document_task(
self, individual_document_dict, search_space_id: int, user_id: str

View file

@ -5,14 +5,13 @@ import logging
import sys
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app
from app.config import config
from app.db import Podcast, PodcastStatus
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
@ -25,20 +24,6 @@ if sys.platform.startswith("win"):
)
def get_celery_session_maker():
"""
Create a new async session maker for Celery tasks.
This is necessary because Celery tasks run in a new event loop,
and the default session maker is bound to the main app's event loop.
"""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool, # Don't use connection pooling for Celery tasks
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
# =============================================================================
# Content-based podcast generation (for new-chat)
# =============================================================================

View file

@ -3,28 +3,16 @@
import logging
from datetime import UTC, datetime
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.future import select
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
from app.tasks.celery_tasks import get_celery_session_maker
from app.utils.indexing_locks import is_connector_indexing_locked
logger = logging.getLogger(__name__)
def get_celery_session_maker():
"""Create async session maker for Celery tasks."""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool,
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="check_periodic_schedules")
def check_periodic_schedules_task():
"""

View file

@ -29,20 +29,17 @@ from datetime import UTC, datetime
import redis
from sqlalchemy import and_, or_, text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.future import select
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.db import Document, DocumentStatus, Notification
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
# Redis client for checking heartbeats
_redis_client: redis.Redis | None = None
# Error messages shown to users when tasks are interrupted
STALE_SYNC_ERROR_MESSAGE = "Sync was interrupted unexpectedly. Please retry."
STALE_PROCESSING_ERROR_MESSAGE = "Syncing was interrupted unexpectedly. Please retry."
@ -60,16 +57,6 @@ def _get_heartbeat_key(notification_id: int) -> str:
return f"indexing:heartbeat:{notification_id}"
def get_celery_session_maker():
"""Create async session maker for Celery tasks."""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool,
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="cleanup_stale_indexing_notifications")
def cleanup_stale_indexing_notifications_task():
"""

View file

@ -10,6 +10,8 @@ Supports loading LLM configurations from:
"""
import asyncio
import contextlib
import gc
import json
import logging
import re
@ -19,9 +21,9 @@ from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
import anyio
from langchain_core.messages import HumanMessage
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
@ -47,6 +49,7 @@ from app.db import (
SearchSourceConnectorType,
SurfsenseDocsDocument,
async_session_maker,
shielded_async_session,
)
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
from app.services.chat_session_state_service import (
@ -56,14 +59,9 @@ from app.services.chat_session_state_service import (
from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService
from app.utils.content_utils import bootstrap_history_from_db
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
_perf_log = logging.getLogger("surfsense.perf")
_perf_log.setLevel(logging.DEBUG)
if not _perf_log.handlers:
_h = logging.StreamHandler()
_h.setFormatter(logging.Formatter("%(asctime)s [PERF] %(message)s"))
_perf_log.addHandler(_h)
_perf_log.propagate = False
_perf_log = get_perf_logger()
_background_tasks: set[asyncio.Task] = set()
@ -1016,7 +1014,6 @@ async def stream_new_chat(
user_query: str,
search_space_id: int,
chat_id: int,
session: AsyncSession,
user_id: str | None = None,
llm_config_id: int = -1,
mentioned_document_ids: list[int] | None = None,
@ -1032,11 +1029,13 @@ async def stream_new_chat(
This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
The chat_id is used as LangGraph's thread_id for memory/checkpointing.
The function creates and manages its own database session to guarantee proper
cleanup even when Starlette's middleware cancels the task on client disconnect.
Args:
user_query: The user's query
search_space_id: The search space ID
chat_id: The chat ID (used as LangGraph thread_id for memory)
session: The database session
user_id: The current user's UUID string (for memory tools and session state)
llm_config_id: The LLM configuration ID (default: -1 for first global config)
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
@ -1050,7 +1049,9 @@ async def stream_new_chat(
streaming_service = VercelStreamingService()
stream_result = StreamResult()
_t_total = time.perf_counter()
log_system_snapshot("stream_new_chat_START")
session = async_session_maker()
try:
# Mark AI as responding to this user for live collaboration
if user_id:
@ -1286,6 +1287,12 @@ async def stream_new_chat(
# short-lived transactions (or use isolated sessions).
await session.commit()
# Detach heavy ORM objects (documents with chunks, reports, etc.)
# from the session identity map now that we've extracted the data
# we need. This prevents them from accumulating in memory for the
# entire duration of LLM streaming (which can be several minutes).
session.expunge_all()
_perf_log.info(
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
time.perf_counter() - _t_total,
@ -1353,6 +1360,12 @@ async def stream_new_chat(
items=initial_items,
)
# These ORM objects (with eagerly-loaded chunks) can be very large.
# They're only needed to build context strings already copied into
# final_query / langchain_messages — release them before streaming.
del mentioned_documents, mentioned_surfsense_docs, recent_reports
del langchain_messages, final_query
_t_stream_start = time.perf_counter()
_first_event_logged = False
async for sse in _stream_agent_events(
@ -1382,6 +1395,7 @@ async def stream_new_chat(
time.perf_counter() - _t_stream_start,
chat_id,
)
log_system_snapshot("stream_new_chat_END")
if stream_result.is_interrupted:
yield streaming_service.format_finish_step()
@ -1461,30 +1475,57 @@ async def stream_new_chat(
yield streaming_service.format_done()
finally:
# Clear AI responding state for live collaboration.
# The original session may be broken (client disconnect / CancelledError
# can corrupt the underlying DB connection), so we try a rollback first
# and fall back to a fresh session if the original is unusable.
try:
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
# Shield the ENTIRE async cleanup from anyio cancel-scope
# cancellation. Starlette's BaseHTTPMiddleware uses anyio task
# groups; on client disconnect, it cancels the scope with
# level-triggered cancellation — every unshielded `await` inside
# the cancelled scope raises CancelledError immediately. Without
# this shield the very first `await` (session.rollback) would
# raise CancelledError, `except Exception` wouldn't catch it
# (CancelledError is a BaseException), and the rest of the
# finally block — including session.close() — would never run.
with anyio.CancelScope(shield=True):
try:
async with async_session_maker() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
try:
async with shielded_async_session() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
with contextlib.suppress(Exception):
session.expunge_all()
with contextlib.suppress(Exception):
await session.close()
# Break circular refs held by the agent graph, tools, and LLM
# wrappers so the GC can reclaim them in a single pass.
agent = llm = connector_service = sandbox_backend = None
input_state = stream_result = None
session = None
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
if collected:
_perf_log.info(
"[stream_new_chat] gc.collect() reclaimed %d objects (chat_id=%s)",
collected,
chat_id,
)
trim_native_heap()
log_system_snapshot("stream_new_chat_END")
async def stream_resume_chat(
chat_id: int,
search_space_id: int,
decisions: list[dict],
session: AsyncSession,
user_id: str | None = None,
llm_config_id: int = -1,
thread_visibility: ChatVisibility | None = None,
@ -1493,6 +1534,7 @@ async def stream_resume_chat(
stream_result = StreamResult()
_t_total = time.perf_counter()
session = async_session_maker()
try:
if user_id:
await set_ai_responding(session, chat_id, UUID(user_id))
@ -1589,6 +1631,7 @@ async def stream_resume_chat(
# Release the transaction before streaming (same rationale as stream_new_chat).
await session.commit()
session.expunge_all()
_perf_log.info(
"[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
@ -1652,16 +1695,37 @@ async def stream_resume_chat(
yield streaming_service.format_done()
finally:
try:
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
with anyio.CancelScope(shield=True):
try:
async with async_session_maker() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
try:
async with shielded_async_session() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
with contextlib.suppress(Exception):
session.expunge_all()
with contextlib.suppress(Exception):
await session.close()
agent = llm = connector_service = sandbox_backend = None
stream_result = None
session = None
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
if collected:
_perf_log.info(
"[stream_resume] gc.collect() reclaimed %d objects (chat_id=%s)",
collected,
chat_id,
)
trim_native_heap()
log_system_snapshot("stream_resume_chat_END")

View file

@ -0,0 +1,174 @@
"""
Centralized performance monitoring for SurfSense backend.
Provides:
- A shared [PERF] logger used across all modules
- perf_timer context manager for timing code blocks
- perf_async_timer for async code blocks
- system_snapshot() for CPU/memory profiling
- RequestPerfMiddleware for per-request timing
"""
import gc
import logging
import os
import time
from contextlib import asynccontextmanager, contextmanager
from typing import Any
_perf_log: logging.Logger | None = None
_last_rss_mb: float = 0.0
def get_perf_logger() -> logging.Logger:
"""Return the singleton [PERF] logger, creating it once on first call."""
global _perf_log
if _perf_log is None:
_perf_log = logging.getLogger("surfsense.perf")
_perf_log.setLevel(logging.DEBUG)
if not _perf_log.handlers:
h = logging.StreamHandler()
h.setFormatter(logging.Formatter("%(asctime)s [PERF] %(message)s"))
_perf_log.addHandler(h)
_perf_log.propagate = False
return _perf_log
@contextmanager
def perf_timer(label: str, *, extra: dict[str, Any] | None = None):
"""Synchronous context manager that logs elapsed wall-clock time.
Usage:
with perf_timer("[my_func] heavy computation"):
...
"""
log = get_perf_logger()
t0 = time.perf_counter()
yield
elapsed = time.perf_counter() - t0
suffix = ""
if extra:
suffix = " " + " ".join(f"{k}={v}" for k, v in extra.items())
log.info("%s in %.3fs%s", label, elapsed, suffix)
@asynccontextmanager
async def perf_async_timer(label: str, *, extra: dict[str, Any] | None = None):
"""Async context manager that logs elapsed wall-clock time.
Usage:
async with perf_async_timer("[search] vector search"):
...
"""
log = get_perf_logger()
t0 = time.perf_counter()
yield
elapsed = time.perf_counter() - t0
suffix = ""
if extra:
suffix = " " + " ".join(f"{k}={v}" for k, v in extra.items())
log.info("%s in %.3fs%s", label, elapsed, suffix)
def system_snapshot() -> dict[str, Any]:
"""Capture a lightweight CPU + memory snapshot of the current process.
Returns a dict with:
- rss_mb: Resident Set Size in MB
- rss_delta_mb: Change in RSS since the last snapshot
- cpu_percent: CPU usage % since last call (per-process)
- threads: number of active threads
- open_fds: number of open file descriptors (Linux only)
- asyncio_tasks: number of asyncio tasks currently alive
- gc_counts: tuple of object counts per gc generation
"""
import asyncio
global _last_rss_mb
snapshot: dict[str, Any] = {}
try:
import psutil
proc = psutil.Process(os.getpid())
mem = proc.memory_info()
rss_mb = round(mem.rss / 1024 / 1024, 1)
snapshot["rss_mb"] = rss_mb
snapshot["rss_delta_mb"] = (
round(rss_mb - _last_rss_mb, 1) if _last_rss_mb else 0.0
)
_last_rss_mb = rss_mb
snapshot["cpu_percent"] = proc.cpu_percent(interval=None)
snapshot["threads"] = proc.num_threads()
try:
snapshot["open_fds"] = proc.num_fds()
except AttributeError:
snapshot["open_fds"] = -1
except ImportError:
snapshot["rss_mb"] = -1
snapshot["rss_delta_mb"] = 0.0
snapshot["cpu_percent"] = -1
snapshot["threads"] = -1
snapshot["open_fds"] = -1
try:
all_tasks = asyncio.all_tasks()
snapshot["asyncio_tasks"] = len(all_tasks)
except RuntimeError:
snapshot["asyncio_tasks"] = -1
snapshot["gc_counts"] = gc.get_count()
return snapshot
def log_system_snapshot(label: str = "system_snapshot") -> None:
"""Capture and log a system snapshot with memory delta tracking."""
snap = system_snapshot()
delta_str = ""
if snap["rss_delta_mb"]:
sign = "+" if snap["rss_delta_mb"] > 0 else ""
delta_str = f" delta={sign}{snap['rss_delta_mb']}MB"
get_perf_logger().info(
"[%s] rss=%.1fMB%s cpu=%.1f%% threads=%d fds=%d asyncio_tasks=%d gc=%s",
label,
snap["rss_mb"],
delta_str,
snap["cpu_percent"],
snap["threads"],
snap["open_fds"],
snap["asyncio_tasks"],
snap["gc_counts"],
)
if snap["rss_mb"] > 0 and snap["rss_delta_mb"] > 500:
get_perf_logger().warning(
"[MEMORY_SPIKE] %s: RSS jumped by %.1fMB (now %.1fMB). "
"Possible leak — check recent operations.",
label,
snap["rss_delta_mb"],
snap["rss_mb"],
)
def trim_native_heap() -> bool:
"""Ask glibc to return free heap pages to the OS via ``malloc_trim(0)``.
On Linux (glibc), ``free()`` does not release memory back to the OS if
it sits below the brk watermark. ``malloc_trim`` forces the allocator
to give back as many freed pages as possible.
Returns True if trimming was performed, False otherwise (non-Linux or
libc unavailable).
"""
import ctypes
import sys
if sys.platform != "linux":
return False
try:
libc = ctypes.CDLL("libc.so.6")
libc.malloc_trim(0)
return True
except (OSError, AttributeError):
return False

View file

@ -181,7 +181,7 @@ export default function MorePagesPage() {
</DialogHeader>
<div className="flex flex-col items-center gap-4 py-4">
<Link
href="https://calendly.com/eric-surfsense/surfsense-meeting"
href="https://cal.com/mod-rohan"
target="_blank"
rel="noopener noreferrer"
className="flex w-full items-center justify-center gap-2 rounded-lg bg-primary px-4 py-2.5 text-sm font-medium text-primary-foreground transition hover:bg-primary/90"
@ -195,11 +195,11 @@ export default function MorePagesPage() {
<span className="h-px w-8 bg-border" />
</div>
<Link
href="mailto:eric@surfsense.com"
href="mailto:rohan@surfsense.com"
className="flex items-center gap-2 text-sm text-muted-foreground transition hover:text-foreground"
>
<IconMailFilled className="h-4 w-4" />
eric@surfsense.com
rohan@surfsense.com
</Link>
</div>
</DialogContent>

View file

@ -235,4 +235,3 @@ button {
@source '../node_modules/streamdown/dist/*.js';
@source '../node_modules/@streamdown/code/dist/*.js';
@source '../node_modules/@streamdown/math/dist/*.js';

View file

@ -5,6 +5,7 @@ import { AlertTriangle, Cable, Settings } from "lucide-react";
import Link from "next/link";
import { useSearchParams } from "next/navigation";
import type { FC } from "react";
import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms";
import {
globalNewLLMConfigsAtom,
llmPreferencesAtom,
@ -19,7 +20,6 @@ import { Spinner } from "@/components/ui/spinner";
import { Tabs, TabsContent } from "@/components/ui/tabs";
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
import { useConnectorsElectric } from "@/hooks/use-connectors-electric";
import { useDocuments } from "@/hooks/use-documents";
import { useInbox } from "@/hooks/use-inbox";
import { cn } from "@/lib/utils";
import { ConnectorDialogHeader } from "./connector-popup/components/connector-dialog-header";
@ -62,10 +62,9 @@ export const ConnectorIndicator: FC<{ hideTrigger?: boolean }> = ({ hideTrigger
const llmConfigLoading = preferencesLoading || globalConfigsLoading;
// Fetch document type counts using Electric SQL + PGlite for real-time updates
const { typeCounts: documentTypeCounts, loading: documentTypesLoading } = useDocuments(
searchSpaceId ? Number(searchSpaceId) : null
);
// Fetch document type counts via the lightweight /type-counts endpoint (cached 10 min)
const { data: documentTypeCounts, isFetching: documentTypesLoading } =
useAtomValue(documentTypeCountsAtom);
// Fetch notifications to detect indexing failures
const { inboxItems = [] } = useInbox(

View file

@ -2,27 +2,28 @@ import { EnumConnectorName } from "@/contracts/enums/connector";
// OAuth Connectors (Quick Connect)
export const OAUTH_CONNECTORS = [
{
id: "google-drive-connector",
title: "Google Drive",
description: "Search your Drive files",
connectorType: EnumConnectorName.GOOGLE_DRIVE_CONNECTOR,
authEndpoint: "/api/v1/auth/google/drive/connector/add/",
},
{
id: "google-gmail-connector",
title: "Gmail",
description: "Search through your emails",
connectorType: EnumConnectorName.GOOGLE_GMAIL_CONNECTOR,
authEndpoint: "/api/v1/auth/google/gmail/connector/add/",
},
{
id: "google-calendar-connector",
title: "Google Calendar",
description: "Search through your events",
connectorType: EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR,
authEndpoint: "/api/v1/auth/google/calendar/connector/add/",
},
// // Uncomment for managed Google Connections
// {
// id: "google-drive-connector",
// title: "Google Drive",
// description: "Search your Drive files",
// connectorType: EnumConnectorName.GOOGLE_DRIVE_CONNECTOR,
// authEndpoint: "/api/v1/auth/google/drive/connector/add/",
// },
// {
// id: "google-gmail-connector",
// title: "Gmail",
// description: "Search through your emails",
// connectorType: EnumConnectorName.GOOGLE_GMAIL_CONNECTOR,
// authEndpoint: "/api/v1/auth/google/gmail/connector/add/",
// },
// {
// id: "google-calendar-connector",
// title: "Google Calendar",
// description: "Search through your events",
// connectorType: EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR,
// authEndpoint: "/api/v1/auth/google/calendar/connector/add/",
// },
{
id: "airtable-connector",
title: "Airtable",

View file

@ -65,6 +65,7 @@ import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking";
import { Button } from "@/components/ui/button";
import { Spinner } from "@/components/ui/spinner";
import type { Document } from "@/contracts/types/document.types";
import { useBatchCommentsPreload } from "@/hooks/use-comments";
import { useCommentsElectric } from "@/hooks/use-comments-electric";
import { documentsApiService } from "@/lib/apis/documents-api.service";
import { cn } from "@/lib/utils";
@ -309,6 +310,22 @@ const Composer: FC = () => {
// Sync comments for the entire thread via Electric SQL (one subscription per thread)
useCommentsElectric(threadId);
// Batch-prefetch comments for all assistant messages so individual useComments
// hooks never fire their own network requests (eliminates N+1 API calls).
// Return a primitive string from the selector so useSyncExternalStore can
// compare snapshots by value and avoid infinite re-render loops.
const assistantIdsKey = useAssistantState(({ thread }) =>
thread.messages
.filter((m) => m.role === "assistant" && m.id?.startsWith("msg-"))
.map((m) => m.id!.replace("msg-", ""))
.join(",")
);
const assistantDbMessageIds = useMemo(
() => (assistantIdsKey ? assistantIdsKey.split(",").map(Number) : []),
[assistantIdsKey]
);
useBatchCommentsPreload(assistantDbMessageIds);
// Auto-focus editor on new chat page after mount
useEffect(() => {
if (isThreadEmpty && !hasAutoFocusedRef.current && editorRef.current) {

View file

@ -23,12 +23,12 @@ export function ContactFormGridWithDetails() {
We'd love to hear from you!
</p>
<p className="mt-4 max-w-lg text-center text-base text-neutral-600 dark:text-neutral-400">
Schedule a meeting with our Head of Product, Eric Lammertsma, or send us an email.
Schedule a meeting with us, or send us an email.
</p>
<div className="mt-10 flex flex-col items-center gap-6">
<Link
href="https://calendly.com/eric-surfsense/surfsense-meeting"
href="https://cal.com/mod-rohan"
target="_blank"
rel="noopener noreferrer"
className="flex items-center gap-3 rounded-xl bg-gradient-to-b from-blue-500 to-blue-600 px-6 py-3 text-base font-medium text-white shadow-lg transition duration-200 hover:from-blue-600 hover:to-blue-700"
@ -44,11 +44,11 @@ export function ContactFormGridWithDetails() {
</div>
<Link
href="mailto:eric@surfsense.com"
href="mailto:rohan@surfsense.com"
className="flex items-center gap-2 text-base text-neutral-600 transition duration-200 hover:text-neutral-900 dark:text-neutral-400 dark:hover:text-neutral-200"
>
<IconMailFilled className="h-5 w-5" />
eric@surfsense.com
rohan@surfsense.com
</Link>
</div>

View file

@ -96,12 +96,9 @@ export function HeroSection() {
</div>
)}
</h2>
{/* // TODO:aCTUAL DESCRITION */}
<p className="relative z-50 mx-auto mt-4 max-w-xl px-4 text-center text-base/6 text-gray-600 dark:text-gray-200">
Connect any AI to your documents, Drive, Notion and more,
</p>
<p className="relative z-50 mx-auto mt-0 max-w-xl px-4 text-center text-base/6 text-gray-600 dark:text-gray-200">
then chat with it, generate podcasts and reports, or even invite your team.
<p className="relative z-50 mx-auto mt-4 max-w-lg px-6 text-center text-sm leading-relaxed text-gray-600 sm:text-base sm:leading-relaxed md:max-w-xl md:text-lg md:leading-relaxed dark:text-gray-200">
Connect any LLM to your internal knowledge sources and chat with it in real time alongside
your team.
</p>
<div className="mb-6 mt-6 flex w-full flex-col items-center justify-center gap-4 px-8 sm:flex-row md:mb-10">
<GetStartedButton />

View file

@ -4,7 +4,6 @@ import {
IconBrandGithub,
IconBrandReddit,
IconMenu2,
IconSpeakerphone,
IconX,
} from "@tabler/icons-react";
import { AnimatePresence, motion } from "motion/react";
@ -13,7 +12,6 @@ import { useEffect, useState } from "react";
import { SignInButton } from "@/components/auth/sign-in-button";
import { Logo } from "@/components/Logo";
import { ThemeTogglerComponent } from "@/components/theme/theme-toggle";
import { useAnnouncements } from "@/hooks/use-announcements";
import { useGithubStars } from "@/hooks/use-github-stars";
import { cn } from "@/lib/utils";
@ -49,11 +47,7 @@ export const Navbar = () => {
const DesktopNav = ({ navItems, isScrolled }: any) => {
const [hovered, setHovered] = useState<number | null>(null);
const [mounted, setMounted] = useState(false);
const { compactFormat: githubStars, loading: loadingGithubStars } = useGithubStars();
const { unreadCount } = useAnnouncements();
useEffect(() => setMounted(true), []);
return (
<motion.div
onMouseLeave={() => {
@ -124,17 +118,6 @@ const DesktopNav = ({ navItems, isScrolled }: any) => {
</span>
)}
</Link>
<Link
href="/announcements"
className="relative hidden rounded-full p-2 hover:bg-gray-100 dark:hover:bg-neutral-800 transition-colors md:flex items-center justify-center"
>
<IconSpeakerphone className="h-5 w-5 text-neutral-600 dark:text-neutral-300" />
{mounted && unreadCount > 0 && (
<span className="absolute -top-0.5 -right-0.5 flex h-4 min-w-4 items-center justify-center rounded-full bg-red-500 px-1 text-[10px] font-bold text-white">
{unreadCount > 99 ? "99+" : unreadCount}
</span>
)}
</Link>
<ThemeTogglerComponent />
<SignInButton variant="desktop" />
</div>
@ -144,11 +127,7 @@ const DesktopNav = ({ navItems, isScrolled }: any) => {
const MobileNav = ({ navItems, isScrolled }: any) => {
const [open, setOpen] = useState(false);
const [mounted, setMounted] = useState(false);
const { compactFormat: githubStars, loading: loadingGithubStars } = useGithubStars();
const { unreadCount } = useAnnouncements();
useEffect(() => setMounted(true), []);
return (
<motion.div
@ -233,17 +212,6 @@ const MobileNav = ({ navItems, isScrolled }: any) => {
</span>
)}
</Link>
<Link
href="/announcements"
className="relative flex items-center justify-center rounded-lg p-2 hover:bg-gray-100 dark:hover:bg-neutral-800 transition-colors touch-manipulation"
>
<IconSpeakerphone className="h-5 w-5 text-neutral-600 dark:text-neutral-300" />
{mounted && unreadCount > 0 && (
<span className="absolute -top-0.5 -right-0.5 flex h-4 min-w-4 items-center justify-center rounded-full bg-red-500 px-1 text-[10px] font-bold text-white">
{unreadCount > 99 ? "99+" : unreadCount}
</span>
)}
</Link>
<ThemeTogglerComponent />
</div>
<SignInButton variant="mobile" />

View file

@ -8,43 +8,34 @@ const demoPlans = [
price: "0",
yearlyPrice: "0",
period: "",
billingText: "Includes 30 day PRO trial",
billingText: "",
features: [
"Open source on GitHub",
"Self Hostable",
"Upload and chat with 300+ pages of content",
"Connects with 8 popular sources, like Drive and Notion",
"Includes limited access to ChatGPT, Claude, and DeepSeek models",
"Supports 100+ more LLMs, including Gemini, Llama and many more",
"50+ File extensions supported",
"Generate podcasts in seconds",
"Cross-Browser Extension for dynamic webpages including authenticated content",
"Includes access to ChatGPT text and audio models",
"Realtime Collaborative Group Chats with teammates",
"Community support on Discord",
],
description: "Powerful features with some limitations",
description: "",
buttonText: "Get Started",
href: "/",
href: "/login",
isPopular: false,
},
{
name: "PRO",
price: "10",
yearlyPrice: "10",
period: "user / month",
billingText: "billed annually",
price: "0",
yearlyPrice: "0",
period: "",
billingText: "Free during beta",
features: [
"Everything in Free",
"Upload and chat with 5,000+ pages of content per user",
"Connects with 15+ external sources, like Slack and Airtable",
"Includes extended access to ChatGPT, Claude, and DeepSeek models",
"Collaboration and commenting features",
"Shared BYOK (Bring Your Own Key)",
"Team and role management",
"Planned: Centralized billing",
"Priority support",
"Includes 6000+ pages of content",
"Access to more models and providers",
"Priority support on Discord",
],
description: "The AI knowledge base for individuals and teams",
buttonText: "Upgrade",
href: "/contact",
description: "",
buttonText: "Get Started",
href: "/login",
isPopular: true,
},
{
@ -55,12 +46,9 @@ const demoPlans = [
billingText: "",
features: [
"Everything in Pro",
"Connect and chat with virtually unlimited pages of content",
"Limit models and/or providers",
"On-prem or VPC deployment",
"Planned: Audit logs and compliance",
"Planned: SSO, OIDC & SAML",
"Planned: Role-based access control (RBAC)",
"Audit logs and compliance",
"SSO, OIDC & SAML",
"White-glove setup and deployment",
"Monthly managed updates and maintenance",
"SLA commitments",

View file

@ -111,8 +111,8 @@ const FILE_TYPE_CONFIG: Record<string, Record<string, string[]>> = {
const cardClass = "border border-border bg-slate-400/5 dark:bg-white/5";
// Upload limits
const MAX_FILES = 10;
// Upload limits — files are sent in batches of 5 to avoid proxy timeouts
const MAX_FILES = 50;
const MAX_TOTAL_SIZE_MB = 200;
const MAX_TOTAL_SIZE_BYTES = MAX_TOTAL_SIZE_MB * 1024 * 1024;

View file

@ -82,6 +82,22 @@ export const getCommentsResponse = z.object({
total_count: z.number(),
});
/**
* Batch-fetch comments for multiple messages
*/
export const getBatchCommentsRequest = z.object({
message_ids: z.array(z.number()).min(1).max(200),
});
export const commentListResponse = z.object({
comments: z.array(comment),
total_count: z.number(),
});
export const getBatchCommentsResponse = z.object({
comments_by_message: z.record(z.string(), commentListResponse),
});
/**
* Create comment
*/
@ -145,6 +161,8 @@ export type MentionComment = z.infer<typeof mentionComment>;
export type Mention = z.infer<typeof mention>;
export type GetCommentsRequest = z.infer<typeof getCommentsRequest>;
export type GetCommentsResponse = z.infer<typeof getCommentsResponse>;
export type GetBatchCommentsRequest = z.infer<typeof getBatchCommentsRequest>;
export type GetBatchCommentsResponse = z.infer<typeof getBatchCommentsResponse>;
export type CreateCommentRequest = z.infer<typeof createCommentRequest>;
export type CreateCommentResponse = z.infer<typeof createCommentResponse>;
export type CreateReplyRequest = z.infer<typeof createReplyRequest>;

View file

@ -1,4 +1,5 @@
import { useQuery } from "@tanstack/react-query";
import { useQuery, useQueryClient } from "@tanstack/react-query";
import { useEffect, useRef } from "react";
import { chatCommentsApiService } from "@/lib/apis/chat-comments-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
@ -7,12 +8,109 @@ interface UseCommentsOptions {
enabled?: boolean;
}
// ---------------------------------------------------------------------------
// Module-level coordination: when a batch request is in-flight, individual
// useComments queryFns piggy-back on it instead of making their own requests.
//
// _batchReady is a promise that resolves once the batch useEffect has had a
// chance to set _batchInflight. Individual queryFns await this gate before
// deciding whether to piggy-back or fetch on their own, eliminating the
// previous race where setTimeout(0) was not enough.
// ---------------------------------------------------------------------------
let _batchInflight: Promise<void> | null = null;
let _batchTargetIds = new Set<number>();
let _batchReady: Promise<void> | null = null;
let _resolveBatchReady: (() => void) | null = null;
function resetBatchGate() {
_batchReady = new Promise<void>((r) => {
_resolveBatchReady = r;
});
}
// Open the initial gate immediately (no batch pending yet)
resetBatchGate();
_resolveBatchReady?.();
export function useComments({ messageId, enabled = true }: UseCommentsOptions) {
const queryClient = useQueryClient();
return useQuery({
queryKey: cacheKeys.comments.byMessage(messageId),
queryFn: async () => {
// Wait for the batch gate so the useEffect in useBatchCommentsPreload
// has a chance to set _batchInflight before we decide.
if (_batchReady) {
await _batchReady;
}
if (_batchInflight && _batchTargetIds.has(messageId)) {
await _batchInflight;
const cached = queryClient.getQueryData(cacheKeys.comments.byMessage(messageId));
if (cached) return cached;
}
return chatCommentsApiService.getComments({ message_id: messageId });
},
enabled: enabled && !!messageId,
staleTime: 30_000,
});
}
/**
* Batch-fetch comments for all given message IDs in a single request, then
* seed the per-message React Query cache so individual useComments hooks
* resolve from cache instead of firing their own requests.
*/
export function useBatchCommentsPreload(messageIds: number[]) {
const queryClient = useQueryClient();
const prevKeyRef = useRef<string>("");
useEffect(() => {
if (!messageIds.length) return;
const key = messageIds
.slice()
.sort((a, b) => a - b)
.join(",");
if (key === prevKeyRef.current) return;
prevKeyRef.current = key;
// Open a new gate so individual queryFns wait for us
resetBatchGate();
_batchTargetIds = new Set(messageIds);
let cancelled = false;
const promise = chatCommentsApiService
.getBatchComments({ message_ids: messageIds })
.then((data) => {
if (cancelled) return;
for (const [msgIdStr, commentList] of Object.entries(data.comments_by_message)) {
queryClient.setQueryData(cacheKeys.comments.byMessage(Number(msgIdStr)), commentList);
}
})
.catch(() => {
// Batch failed; individual queryFns will fall through to their own fetch
})
.finally(() => {
if (_batchInflight === promise) {
_batchInflight = null;
_batchTargetIds = new Set();
}
});
_batchInflight = promise;
// Release the gate — individual queryFns can now check _batchInflight
_resolveBatchReady?.();
return () => {
cancelled = true;
if (_batchInflight === promise) {
_batchInflight = null;
_batchTargetIds = new Set();
}
};
}, [messageIds, queryClient]);
}

View file

@ -1,5 +1,6 @@
"use client";
import { useQuery } from "@tanstack/react-query";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import type { DocumentTypeEnum } from "@/contracts/types/document.types";
import { documentsApiService } from "@/lib/apis/documents-api.service";
@ -183,56 +184,47 @@ export function useDocuments(
[]
);
// EFFECT 1: Load ALL documents from API (PRIMARY source of truth)
// No type filter — always fetches everything so typeCounts stay complete
// STEP 1: Load ALL documents from API (PRIMARY source of truth).
// Uses React Query for automatic deduplication, caching, and staleTime so
// multiple components mounting useDocuments(sameId) share a single request.
const {
data: apiResponse,
isLoading: apiLoading,
error: apiError,
} = useQuery({
queryKey: ["documents", "all", searchSpaceId],
queryFn: () =>
documentsApiService.getDocuments({
queryParams: {
search_space_id: searchSpaceId!,
page: 0,
page_size: -1,
},
}),
enabled: !!searchSpaceId,
staleTime: 30_000,
});
// Seed local state from API response (runs once per fresh fetch)
useEffect(() => {
if (!searchSpaceId) {
setLoading(false);
return;
if (!apiResponse) return;
populateUserCache(apiResponse.items);
const docs = apiResponse.items.map(apiToDisplayDoc);
setAllDocuments(docs);
apiLoadedRef.current = true;
setError(null);
}, [apiResponse, populateUserCache, apiToDisplayDoc]);
// Propagate loading / error from React Query
useEffect(() => {
setLoading(apiLoading);
}, [apiLoading]);
useEffect(() => {
if (apiError) {
setError(apiError instanceof Error ? apiError : new Error("Failed to load documents"));
}
// Capture validated value for async closure
const spaceId = searchSpaceId;
let mounted = true;
apiLoadedRef.current = false;
async function loadFromApi() {
try {
setLoading(true);
console.log("[useDocuments] Loading from API (source of truth):", spaceId);
const response = await documentsApiService.getDocuments({
queryParams: {
search_space_id: spaceId,
page: 0,
page_size: -1, // Fetch all documents (unfiltered)
},
});
if (!mounted) return;
populateUserCache(response.items);
const docs = response.items.map(apiToDisplayDoc);
setAllDocuments(docs);
apiLoadedRef.current = true;
setError(null);
console.log("[useDocuments] API loaded", docs.length, "documents");
} catch (err) {
if (!mounted) return;
console.error("[useDocuments] API load failed:", err);
setError(err instanceof Error ? err : new Error("Failed to load documents"));
} finally {
if (mounted) setLoading(false);
}
}
loadFromApi();
return () => {
mounted = false;
};
}, [searchSpaceId, populateUserCache, apiToDisplayDoc]);
}, [apiError]);
// EFFECT 2: Start Electric sync + live query for real-time updates
// No type filter — syncs and queries ALL documents; filtering is client-side

View file

@ -8,8 +8,11 @@ import {
type DeleteCommentRequest,
deleteCommentRequest,
deleteCommentResponse,
type GetBatchCommentsRequest,
type GetCommentsRequest,
type GetMentionsRequest,
getBatchCommentsRequest,
getBatchCommentsResponse,
getCommentsRequest,
getCommentsResponse,
getMentionsRequest,
@ -22,6 +25,22 @@ import { ValidationError } from "@/lib/error";
import { baseApiService } from "./base-api.service";
class ChatCommentsApiService {
/**
* Batch-fetch comments for multiple messages in one request
*/
getBatchComments = async (request: GetBatchCommentsRequest) => {
const parsed = getBatchCommentsRequest.safeParse(request);
if (!parsed.success) {
const errorMessage = parsed.error.issues.map((issue) => issue.message).join(", ");
throw new ValidationError(`Invalid request: ${errorMessage}`);
}
return baseApiService.post("/api/v1/messages/comments/batch", getBatchCommentsResponse, {
body: { message_ids: parsed.data.message_ids },
});
};
/**
* Get comments for a message
*/

View file

@ -109,7 +109,9 @@ class DocumentsApiService {
};
/**
* Upload document files
* Upload document files in batches to avoid proxy/LB timeouts.
* Files are split into chunks of UPLOAD_BATCH_SIZE and sent as separate
* requests. Results are aggregated into a single response.
*/
uploadDocument = async (request: UploadDocumentRequest) => {
const parsedRequest = uploadDocumentRequest.safeParse(request);
@ -121,17 +123,54 @@ class DocumentsApiService {
throw new ValidationError(`Invalid request: ${errorMessage}`);
}
// Create FormData for file upload
const formData = new FormData();
parsedRequest.data.files.forEach((file) => {
formData.append("files", file);
});
formData.append("search_space_id", String(parsedRequest.data.search_space_id));
formData.append("should_summarize", String(parsedRequest.data.should_summarize));
const { files, search_space_id, should_summarize } = parsedRequest.data;
const UPLOAD_BATCH_SIZE = 5;
return baseApiService.postFormData(`/api/v1/documents/fileupload`, uploadDocumentResponse, {
body: formData,
});
const batches: File[][] = [];
for (let i = 0; i < files.length; i += UPLOAD_BATCH_SIZE) {
batches.push(files.slice(i, i + UPLOAD_BATCH_SIZE));
}
const allDocumentIds: number[] = [];
const allDuplicateIds: number[] = [];
let totalFiles = 0;
let pendingFiles = 0;
let skippedDuplicates = 0;
for (const batch of batches) {
const formData = new FormData();
batch.forEach((file) => formData.append("files", file));
formData.append("search_space_id", String(search_space_id));
formData.append("should_summarize", String(should_summarize));
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), 120_000);
try {
const result = await baseApiService.postFormData(
`/api/v1/documents/fileupload`,
uploadDocumentResponse,
{ body: formData, signal: controller.signal }
);
allDocumentIds.push(...(result.document_ids ?? []));
allDuplicateIds.push(...(result.duplicate_document_ids ?? []));
totalFiles += result.total_files ?? batch.length;
pendingFiles += result.pending_files ?? 0;
skippedDuplicates += result.skipped_duplicates ?? 0;
} finally {
clearTimeout(timeoutId);
}
}
return {
message: "Files uploaded for processing" as const,
document_ids: allDocumentIds,
duplicate_document_ids: allDuplicateIds,
total_files: totalFiles,
pending_files: pendingFiles,
skipped_duplicates: skippedDuplicates,
};
};
/**

View file

@ -1,3 +1,10 @@
import { QueryClient } from "@tanstack/react-query";
export const queryClient = new QueryClient();
export const queryClient = new QueryClient({
defaultOptions: {
queries: {
staleTime: 30_000,
refetchOnWindowFocus: false,
},
},
});

View file

@ -1,13 +1,12 @@
import { loader } from "fumadocs-core/source";
import { docs } from "@/.source/server";
import { icons } from "lucide-react";
import { createElement } from "react";
import { docs } from "@/.source/server";
export const source = loader({
baseUrl: "/docs",
source: docs.toFumadocsSource(),
icon(icon) {
if (icon && icon in icons)
return createElement(icons[icon as keyof typeof icons]);
if (icon && icon in icons) return createElement(icons[icon as keyof typeof icons]);
},
});