mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-10 20:35:17 +02:00
Merge remote-tracking branch 'upstream/dev' into refactor/upload-document-adapter-class
This commit is contained in:
commit
6d00b0debf
47 changed files with 1880 additions and 578 deletions
|
|
@ -88,6 +88,13 @@ ENV TMPDIR=/shared_tmp
|
|||
ENV PYTHONPATH=/app
|
||||
ENV UVICORN_LOOP=asyncio
|
||||
|
||||
# Tune glibc malloc to return freed memory to the OS more aggressively.
|
||||
# Without these, Python's gc.collect() frees objects but the underlying
|
||||
# C heap pages stay mapped (RSS never drops) due to sbrk fragmentation.
|
||||
ENV MALLOC_MMAP_THRESHOLD_=65536
|
||||
ENV MALLOC_TRIM_THRESHOLD_=131072
|
||||
ENV MALLOC_MMAP_MAX_=65536
|
||||
|
||||
# SERVICE_ROLE controls which process this container runs:
|
||||
# api – FastAPI backend only (runs migrations on startup)
|
||||
# worker – Celery worker only
|
||||
|
|
|
|||
|
|
@ -28,8 +28,9 @@ from app.agents.new_chat.system_prompt import (
|
|||
from app.agents.new_chat.tools.registry import build_tools_async
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_perf_log = logging.getLogger("surfsense.perf")
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
# =============================================================================
|
||||
# Connector Type Mapping
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from app.services.llm_router_service import (
|
|||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
|
||||
|
|
@ -389,7 +390,7 @@ def create_chat_litellm_from_agent_config(
|
|||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
try:
|
||||
return ChatLiteLLMRouter()
|
||||
return get_auto_mode_llm()
|
||||
except Exception as e:
|
||||
print(f"Error creating ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ class _TimeoutAwareSandbox(DaytonaSandbox):
|
|||
|
||||
_daytona_client: Daytona | None = None
|
||||
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
|
||||
_SANDBOX_CACHE_MAX_SIZE = 20
|
||||
THREAD_LABEL_KEY = "surfsense_thread"
|
||||
|
||||
|
||||
|
|
@ -144,6 +145,12 @@ async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
|
|||
return cached
|
||||
sandbox = await asyncio.to_thread(_find_or_create, key)
|
||||
_sandbox_cache[key] = sandbox
|
||||
|
||||
if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
|
||||
oldest_key = next(iter(_sandbox_cache))
|
||||
_sandbox_cache.pop(oldest_key, None)
|
||||
logger.debug("Evicted oldest sandbox cache entry: %s", oldest_key)
|
||||
|
||||
return sandbox
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ This module provides:
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -17,8 +19,152 @@ from langchain_core.tools import StructuredTool
|
|||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import async_session_maker
|
||||
from app.db import shielded_async_session
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
# Connectors that call external live-search APIs (no local DB / embedding needed).
|
||||
# These are never filtered by available_document_types.
|
||||
_LIVE_SEARCH_CONNECTORS: set[str] = {
|
||||
"TAVILY_API",
|
||||
"SEARXNG_API",
|
||||
"LINKUP_API",
|
||||
"BAIDU_SEARCH_API",
|
||||
}
|
||||
|
||||
# Patterns that indicate the query has no meaningful search signal.
|
||||
# plainto_tsquery('english', '*') produces an empty tsquery and an embedding
|
||||
# of '*' is random noise, so both keyword and semantic search degrade to
|
||||
# arbitrary ordering — large documents (many chunks) dominate by chance.
|
||||
_DEGENERATE_QUERY_RE = re.compile(
|
||||
r"^[\s*?_.#@!\-/\\]+$" # only wildcards, punctuation, whitespace
|
||||
)
|
||||
|
||||
# Max chunks per document when doing a recency-based browse instead of
|
||||
# a real search. We want breadth (many docs) over depth (many chunks).
|
||||
_BROWSE_MAX_CHUNKS_PER_DOC = 5
|
||||
|
||||
|
||||
def _is_degenerate_query(query: str) -> bool:
|
||||
"""Return True when the query carries no meaningful search signal.
|
||||
|
||||
Catches wildcard patterns (``*``, ``**``), empty / whitespace-only
|
||||
strings, and single-character non-word tokens. These queries cause
|
||||
both keyword search (empty tsquery) and semantic search (meaningless
|
||||
embedding) to return effectively random results.
|
||||
"""
|
||||
stripped = query.strip()
|
||||
if not stripped:
|
||||
return True
|
||||
return bool(_DEGENERATE_QUERY_RE.match(stripped))
|
||||
|
||||
|
||||
async def _browse_recent_documents(
|
||||
search_space_id: int,
|
||||
document_type: str | None,
|
||||
top_k: int,
|
||||
start_date: datetime | None,
|
||||
end_date: datetime | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return the most-recent documents (recency-ordered, no search ranking).
|
||||
|
||||
Used as a fallback when the search query is degenerate (e.g. ``*``) and
|
||||
semantic / keyword search would produce arbitrary results. Returns
|
||||
document-grouped dicts in the same shape as ``_combined_rrf_search``
|
||||
so the rest of the pipeline works unchanged.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.db import Chunk, Document, DocumentType
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
base_conditions = [Document.search_space_id == search_space_id]
|
||||
|
||||
if document_type is not None:
|
||||
if isinstance(document_type, str):
|
||||
try:
|
||||
doc_type_enum = DocumentType[document_type]
|
||||
base_conditions.append(Document.document_type == doc_type_enum)
|
||||
except KeyError:
|
||||
return []
|
||||
else:
|
||||
base_conditions.append(Document.document_type == document_type)
|
||||
|
||||
if start_date is not None:
|
||||
base_conditions.append(Document.updated_at >= start_date)
|
||||
if end_date is not None:
|
||||
base_conditions.append(Document.updated_at <= end_date)
|
||||
|
||||
async with shielded_async_session() as session:
|
||||
doc_query = (
|
||||
select(Document)
|
||||
.options(joinedload(Document.search_space))
|
||||
.where(*base_conditions)
|
||||
.order_by(Document.updated_at.desc())
|
||||
.limit(top_k)
|
||||
)
|
||||
result = await session.execute(doc_query)
|
||||
documents = result.scalars().unique().all()
|
||||
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
doc_ids = [d.id for d in documents]
|
||||
|
||||
chunk_query = (
|
||||
select(Chunk)
|
||||
.where(Chunk.document_id.in_(doc_ids))
|
||||
.order_by(Chunk.document_id, Chunk.id)
|
||||
)
|
||||
chunk_result = await session.execute(chunk_query)
|
||||
raw_chunks = chunk_result.scalars().all()
|
||||
|
||||
doc_chunk_counts: dict[int, int] = {}
|
||||
doc_chunks: dict[int, list[dict]] = {d.id: [] for d in documents}
|
||||
for chunk in raw_chunks:
|
||||
did = chunk.document_id
|
||||
count = doc_chunk_counts.get(did, 0)
|
||||
if count < _BROWSE_MAX_CHUNKS_PER_DOC:
|
||||
doc_chunks[did].append({"chunk_id": chunk.id, "content": chunk.content})
|
||||
doc_chunk_counts[did] = count + 1
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
for doc in documents:
|
||||
chunks_list = doc_chunks.get(doc.id, [])
|
||||
results.append(
|
||||
{
|
||||
"document_id": doc.id,
|
||||
"content": "\n\n".join(
|
||||
c["content"] for c in chunks_list if c.get("content")
|
||||
),
|
||||
"score": 0.0,
|
||||
"chunks": chunks_list,
|
||||
"document": {
|
||||
"id": doc.id,
|
||||
"title": doc.title,
|
||||
"document_type": doc.document_type.value
|
||||
if getattr(doc, "document_type", None)
|
||||
else None,
|
||||
"metadata": doc.document_metadata or {},
|
||||
},
|
||||
"source": doc.document_type.value
|
||||
if getattr(doc, "document_type", None)
|
||||
else None,
|
||||
}
|
||||
)
|
||||
|
||||
perf.info(
|
||||
"[kb_browse] recency browse in %.3fs docs=%d space=%d type=%s",
|
||||
time.perf_counter() - t0,
|
||||
len(results),
|
||||
search_space_id,
|
||||
document_type,
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Connector Constants and Normalization
|
||||
|
|
@ -182,9 +328,23 @@ _CHARS_PER_TOKEN = 4
|
|||
# Hard-floor / ceiling so the budget is always sensible regardless of what
|
||||
# the model reports.
|
||||
_MIN_TOOL_OUTPUT_CHARS = 20_000 # ~5K tokens
|
||||
_MAX_TOOL_OUTPUT_CHARS = 400_000 # ~100K tokens
|
||||
_MAX_TOOL_OUTPUT_CHARS = 200_000 # ~50K tokens
|
||||
_MAX_CHUNK_CHARS = 8_000
|
||||
|
||||
# Rank-adaptive per-document budget allocation.
|
||||
# Top-ranked (most relevant) documents get a larger share of the budget so
|
||||
# we pack as much high-quality context as possible.
|
||||
#
|
||||
# fraction(rank) = _TOP_DOC_BUDGET_FRACTION / (1 + rank * _RANK_DECAY)
|
||||
#
|
||||
# Examples (128K budget, 8K chunk cap):
|
||||
# rank 0 → 40% → 6 chunks | rank 3 → 19% → 3 chunks
|
||||
# rank 1 → 30% → 4 chunks | rank 10 → 10% → 3 chunks (floor)
|
||||
# rank 2 → 24% → 3 chunks |
|
||||
_TOP_DOC_BUDGET_FRACTION = 0.40
|
||||
_RANK_DECAY = 0.35
|
||||
_MIN_CHUNKS_PER_DOC = 3
|
||||
|
||||
|
||||
def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
|
||||
"""Derive a character budget from the model's context window.
|
||||
|
|
@ -206,18 +366,24 @@ def format_documents_for_context(
|
|||
*,
|
||||
max_chars: int = _MAX_TOOL_OUTPUT_CHARS,
|
||||
max_chunk_chars: int = _MAX_CHUNK_CHARS,
|
||||
max_chunks_per_doc: int = 0,
|
||||
) -> str:
|
||||
"""
|
||||
Format retrieved documents into a readable context string for the LLM.
|
||||
|
||||
Documents are added in order (highest relevance first) until the character
|
||||
budget is reached. Individual chunks are capped at ``max_chunk_chars`` so
|
||||
a single oversized chunk cannot monopolize the output.
|
||||
budget is reached. Individual chunks are capped at ``max_chunk_chars`` and
|
||||
each document is limited to a dynamically computed chunk cap so a single
|
||||
large document cannot monopolize the output while still maximising the use
|
||||
of available context space.
|
||||
|
||||
Args:
|
||||
documents: List of document dictionaries from connector search
|
||||
max_chars: Approximate character budget for the entire output.
|
||||
max_chunk_chars: Per-chunk character cap (content is tail-truncated).
|
||||
max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means
|
||||
auto-compute per document using a rank-adaptive formula so
|
||||
higher-ranked documents receive more chunks.
|
||||
|
||||
Returns:
|
||||
Formatted string with document contents and metadata
|
||||
|
|
@ -340,7 +506,23 @@ def format_documents_for_context(
|
|||
"<document_content>",
|
||||
]
|
||||
|
||||
for ch in g["chunks"]:
|
||||
# Rank-adaptive per-document chunk cap: top results get more chunks.
|
||||
if max_chunks_per_doc > 0:
|
||||
chunks_allowed = max_chunks_per_doc
|
||||
else:
|
||||
doc_fraction = _TOP_DOC_BUDGET_FRACTION / (1 + doc_idx * _RANK_DECAY)
|
||||
max_doc_chars = int(max_chars * doc_fraction)
|
||||
xml_overhead = 500
|
||||
chunks_allowed = max(
|
||||
(max_doc_chars - xml_overhead) // max(max_chunk_chars, 1),
|
||||
_MIN_CHUNKS_PER_DOC,
|
||||
)
|
||||
|
||||
chunks = g["chunks"]
|
||||
if len(chunks) > chunks_allowed:
|
||||
chunks = chunks[:chunks_allowed]
|
||||
|
||||
for ch in chunks:
|
||||
ch_content = ch["content"]
|
||||
if max_chunk_chars and len(ch_content) > max_chunk_chars:
|
||||
ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)"
|
||||
|
|
@ -357,9 +539,11 @@ def format_documents_for_context(
|
|||
doc_xml = "\n".join(doc_lines)
|
||||
doc_len = len(doc_xml)
|
||||
|
||||
# Always include at least the first document; afterwards enforce budget.
|
||||
if doc_idx > 0 and total_chars + doc_len > max_chars:
|
||||
if total_chars + doc_len > max_chars:
|
||||
remaining = total_docs - doc_idx
|
||||
if doc_idx == 0:
|
||||
parts.append(doc_xml)
|
||||
total_chars += doc_len
|
||||
parts.append(
|
||||
f"<!-- Output truncated: {remaining} more document(s) omitted "
|
||||
f"(budget {max_chars} chars). Refine your query or reduce top_k "
|
||||
|
|
@ -370,7 +554,15 @@ def format_documents_for_context(
|
|||
parts.append(doc_xml)
|
||||
total_chars += doc_len
|
||||
|
||||
return "\n".join(parts).strip()
|
||||
result = "\n".join(parts).strip()
|
||||
|
||||
# Hard safety net: if the result is still over budget (e.g. a single massive
|
||||
# first document), forcibly truncate with a closing comment.
|
||||
if len(result) > max_chars:
|
||||
truncation_msg = "\n<!-- ...output forcibly truncated to fit context window -->"
|
||||
result = result[: max_chars - len(truncation_msg)] + truncation_msg
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -388,6 +580,7 @@ async def search_knowledge_base_async(
|
|||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
available_connectors: list[str] | None = None,
|
||||
available_document_types: list[str] | None = None,
|
||||
max_input_tokens: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
|
|
@ -406,12 +599,18 @@ async def search_knowledge_base_async(
|
|||
end_date: Optional end datetime (UTC) for filtering documents
|
||||
available_connectors: Optional list of connectors actually available in the search space.
|
||||
If provided, only these connectors will be searched.
|
||||
available_document_types: Optional list of document types that actually have indexed
|
||||
data. When provided, local connectors whose document type is
|
||||
absent are skipped entirely (no embedding / DB round-trip).
|
||||
max_input_tokens: Model context window size (tokens). Used to dynamically
|
||||
size the output so it fits within the model's limits.
|
||||
|
||||
Returns:
|
||||
Formatted string with search results
|
||||
"""
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
all_documents: list[dict[str, Any]] = []
|
||||
|
||||
# Resolve date range (default last 2 years)
|
||||
|
|
@ -424,88 +623,169 @@ async def search_knowledge_base_async(
|
|||
|
||||
connectors = _normalize_connectors(connectors_to_search, available_connectors)
|
||||
|
||||
connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
|
||||
"YOUTUBE_VIDEO": ("search_youtube", True, True, {}),
|
||||
"EXTENSION": ("search_extension", True, True, {}),
|
||||
"CRAWLED_URL": ("search_crawled_urls", True, True, {}),
|
||||
"FILE": ("search_files", True, True, {}),
|
||||
"SLACK_CONNECTOR": ("search_slack", True, True, {}),
|
||||
"TEAMS_CONNECTOR": ("search_teams", True, True, {}),
|
||||
"NOTION_CONNECTOR": ("search_notion", True, True, {}),
|
||||
"GITHUB_CONNECTOR": ("search_github", True, True, {}),
|
||||
"LINEAR_CONNECTOR": ("search_linear", True, True, {}),
|
||||
# --- Optimization 1: skip local connectors that have zero indexed documents ---
|
||||
if available_document_types:
|
||||
doc_types_set = set(available_document_types)
|
||||
before_count = len(connectors)
|
||||
connectors = [
|
||||
c for c in connectors if c in _LIVE_SEARCH_CONNECTORS or c in doc_types_set
|
||||
]
|
||||
skipped = before_count - len(connectors)
|
||||
if skipped:
|
||||
perf.info(
|
||||
"[kb_search] skipped %d empty connectors (had %d, now %d)",
|
||||
skipped,
|
||||
before_count,
|
||||
len(connectors),
|
||||
)
|
||||
|
||||
perf.info(
|
||||
"[kb_search] searching %d connectors: %s (space=%d, top_k=%d)",
|
||||
len(connectors),
|
||||
connectors[:5],
|
||||
search_space_id,
|
||||
top_k,
|
||||
)
|
||||
|
||||
# --- Fast-path: degenerate queries (*, **, empty, etc.) ---
|
||||
# Semantic embedding of '*' is noise and plainto_tsquery('english', '*')
|
||||
# yields an empty tsquery, so both retrieval signals are useless.
|
||||
# Fall back to a recency-ordered browse that returns diverse results.
|
||||
if _is_degenerate_query(query):
|
||||
perf.info(
|
||||
"[kb_search] degenerate query %r detected - falling back to recency browse",
|
||||
query,
|
||||
)
|
||||
local_connectors = [c for c in connectors if c not in _LIVE_SEARCH_CONNECTORS]
|
||||
if not local_connectors:
|
||||
local_connectors = [None] # type: ignore[list-item]
|
||||
|
||||
browse_results = await asyncio.gather(
|
||||
*[
|
||||
_browse_recent_documents(
|
||||
search_space_id=search_space_id,
|
||||
document_type=c,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
)
|
||||
for c in local_connectors
|
||||
]
|
||||
)
|
||||
for docs in browse_results:
|
||||
all_documents.extend(docs)
|
||||
|
||||
# Skip dedup + formatting below (browse already returns unique docs)
|
||||
# but still cap output budget.
|
||||
output_budget = _compute_tool_output_budget(max_input_tokens)
|
||||
result = format_documents_for_context(
|
||||
all_documents,
|
||||
max_chars=output_budget,
|
||||
max_chunks_per_doc=_BROWSE_MAX_CHUNKS_PER_DOC,
|
||||
)
|
||||
perf.info(
|
||||
"[kb_search] TOTAL (browse) in %.3fs total_docs=%d output_chars=%d "
|
||||
"budget=%d space=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(all_documents),
|
||||
len(result),
|
||||
output_budget,
|
||||
search_space_id,
|
||||
)
|
||||
return result
|
||||
|
||||
# Specs for live-search connectors (external APIs, no local DB/embedding).
|
||||
live_connector_specs: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
|
||||
"TAVILY_API": ("search_tavily", False, True, {}),
|
||||
"SEARXNG_API": ("search_searxng", False, True, {}),
|
||||
"LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}),
|
||||
"BAIDU_SEARCH_API": ("search_baidu", False, True, {}),
|
||||
"DISCORD_CONNECTOR": ("search_discord", True, True, {}),
|
||||
"JIRA_CONNECTOR": ("search_jira", True, True, {}),
|
||||
"GOOGLE_CALENDAR_CONNECTOR": ("search_google_calendar", True, True, {}),
|
||||
"AIRTABLE_CONNECTOR": ("search_airtable", True, True, {}),
|
||||
"GOOGLE_GMAIL_CONNECTOR": ("search_google_gmail", True, True, {}),
|
||||
"GOOGLE_DRIVE_FILE": ("search_google_drive", True, True, {}),
|
||||
"CONFLUENCE_CONNECTOR": ("search_confluence", True, True, {}),
|
||||
"CLICKUP_CONNECTOR": ("search_clickup", True, True, {}),
|
||||
"LUMA_CONNECTOR": ("search_luma", True, True, {}),
|
||||
"ELASTICSEARCH_CONNECTOR": ("search_elasticsearch", True, True, {}),
|
||||
"NOTE": ("search_notes", True, True, {}),
|
||||
"BOOKSTACK_CONNECTOR": ("search_bookstack", True, True, {}),
|
||||
"CIRCLEBACK": ("search_circleback", True, True, {}),
|
||||
"OBSIDIAN_CONNECTOR": ("search_obsidian", True, True, {}),
|
||||
# Composio connectors
|
||||
"COMPOSIO_GOOGLE_DRIVE_CONNECTOR": (
|
||||
"search_composio_google_drive",
|
||||
True,
|
||||
True,
|
||||
{},
|
||||
),
|
||||
"COMPOSIO_GMAIL_CONNECTOR": ("search_composio_gmail", True, True, {}),
|
||||
"COMPOSIO_GOOGLE_CALENDAR_CONNECTOR": (
|
||||
"search_composio_google_calendar",
|
||||
True,
|
||||
True,
|
||||
{},
|
||||
),
|
||||
}
|
||||
|
||||
# Keep a conservative cap to avoid overloading DB/external services.
|
||||
# --- Optimization 2: compute the query embedding once, share across all local searches ---
|
||||
precomputed_embedding: list[float] | None = None
|
||||
has_local_connectors = any(c not in _LIVE_SEARCH_CONNECTORS for c in connectors)
|
||||
if has_local_connectors:
|
||||
from app.config import config as app_config
|
||||
|
||||
t_embed = time.perf_counter()
|
||||
precomputed_embedding = app_config.embedding_model_instance.embed(query)
|
||||
perf.info(
|
||||
"[kb_search] shared embedding computed in %.3fs",
|
||||
time.perf_counter() - t_embed,
|
||||
)
|
||||
|
||||
max_parallel_searches = 4
|
||||
semaphore = asyncio.Semaphore(max_parallel_searches)
|
||||
|
||||
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
|
||||
spec = connector_specs.get(connector)
|
||||
if spec is None:
|
||||
return []
|
||||
is_live = connector in _LIVE_SEARCH_CONNECTORS
|
||||
|
||||
method_name, includes_date_range, includes_top_k, extra_kwargs = spec
|
||||
kwargs: dict[str, Any] = {
|
||||
"user_query": query,
|
||||
"search_space_id": search_space_id,
|
||||
**extra_kwargs,
|
||||
}
|
||||
if includes_top_k:
|
||||
kwargs["top_k"] = top_k
|
||||
if includes_date_range:
|
||||
kwargs["start_date"] = resolved_start_date
|
||||
kwargs["end_date"] = resolved_end_date
|
||||
if is_live:
|
||||
spec = live_connector_specs.get(connector)
|
||||
if spec is None:
|
||||
return []
|
||||
method_name, includes_date_range, includes_top_k, extra_kwargs = spec
|
||||
kwargs: dict[str, Any] = {
|
||||
"user_query": query,
|
||||
"search_space_id": search_space_id,
|
||||
**extra_kwargs,
|
||||
}
|
||||
if includes_top_k:
|
||||
kwargs["top_k"] = top_k
|
||||
if includes_date_range:
|
||||
kwargs["start_date"] = resolved_start_date
|
||||
kwargs["end_date"] = resolved_end_date
|
||||
|
||||
try:
|
||||
t_conn = time.perf_counter()
|
||||
async with semaphore, shielded_async_session() as isolated_session:
|
||||
svc = ConnectorService(isolated_session, search_space_id)
|
||||
_, chunks = await getattr(svc, method_name)(**kwargs)
|
||||
perf.info(
|
||||
"[kb_search] connector=%s results=%d in %.3fs",
|
||||
connector,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_conn,
|
||||
)
|
||||
return chunks
|
||||
except Exception as e:
|
||||
perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
|
||||
return []
|
||||
|
||||
# --- Optimization 3: call _combined_rrf_search directly with shared embedding ---
|
||||
try:
|
||||
# Use isolated session per connector. Shared AsyncSession cannot safely
|
||||
# run concurrent DB operations.
|
||||
async with semaphore, async_session_maker() as isolated_session:
|
||||
isolated_connector_service = ConnectorService(
|
||||
isolated_session, search_space_id
|
||||
t_conn = time.perf_counter()
|
||||
async with semaphore, shielded_async_session() as isolated_session:
|
||||
svc = ConnectorService(isolated_session, search_space_id)
|
||||
chunks = await svc._combined_rrf_search(
|
||||
query_text=query,
|
||||
search_space_id=search_space_id,
|
||||
document_type=connector,
|
||||
top_k=top_k,
|
||||
start_date=resolved_start_date,
|
||||
end_date=resolved_end_date,
|
||||
query_embedding=precomputed_embedding,
|
||||
)
|
||||
perf.info(
|
||||
"[kb_search] connector=%s results=%d in %.3fs",
|
||||
connector,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_conn,
|
||||
)
|
||||
connector_method = getattr(isolated_connector_service, method_name)
|
||||
_, chunks = await connector_method(**kwargs)
|
||||
return chunks
|
||||
except Exception as e:
|
||||
print(f"Error searching connector {connector}: {e}")
|
||||
perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
|
||||
return []
|
||||
|
||||
t_gather = time.perf_counter()
|
||||
connector_results = await asyncio.gather(
|
||||
*[_search_one_connector(connector) for connector in connectors]
|
||||
)
|
||||
perf.info(
|
||||
"[kb_search] all connectors gathered in %.3fs",
|
||||
time.perf_counter() - t_gather,
|
||||
)
|
||||
for chunks in connector_results:
|
||||
all_documents.extend(chunks)
|
||||
|
||||
|
|
@ -552,7 +832,28 @@ async def search_knowledge_base_async(
|
|||
deduplicated.append(doc)
|
||||
|
||||
output_budget = _compute_tool_output_budget(max_input_tokens)
|
||||
return format_documents_for_context(deduplicated, max_chars=output_budget)
|
||||
result = format_documents_for_context(deduplicated, max_chars=output_budget)
|
||||
|
||||
if len(result) > output_budget:
|
||||
perf.warning(
|
||||
"[kb_search] output STILL exceeds budget after format (%d > %d), "
|
||||
"hard truncation should have fired",
|
||||
len(result),
|
||||
output_budget,
|
||||
)
|
||||
|
||||
perf.info(
|
||||
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d "
|
||||
"budget=%d max_input_tokens=%s space=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(all_documents),
|
||||
len(deduplicated),
|
||||
len(result),
|
||||
output_budget,
|
||||
max_input_tokens,
|
||||
search_space_id,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _build_connector_docstring(available_connectors: list[str] | None) -> str:
|
||||
|
|
@ -590,11 +891,15 @@ class SearchKnowledgeBaseInput(BaseModel):
|
|||
"""Input schema for the search_knowledge_base tool."""
|
||||
|
||||
query: str = Field(
|
||||
description="The search query - be specific and include key terms"
|
||||
description=(
|
||||
"The search query - use specific natural language terms. "
|
||||
"NEVER use wildcards like '*' or '**'; instead describe what you want "
|
||||
"(e.g. 'recent meeting notes' or 'project architecture overview')."
|
||||
),
|
||||
)
|
||||
top_k: int = Field(
|
||||
default=10,
|
||||
description="Number of results to retrieve (default: 10)",
|
||||
description="Number of results to retrieve (default: 10). Keep ≤20 for focused searches.",
|
||||
)
|
||||
start_date: str | None = Field(
|
||||
default=None,
|
||||
|
|
@ -657,6 +962,10 @@ Focus searches on these types for best results."""
|
|||
Use this tool to find documents, notes, files, web pages, and other content that may help answer the user's question.
|
||||
|
||||
IMPORTANT:
|
||||
- Always craft specific, descriptive search queries using natural language keywords.
|
||||
Good: "quarterly sales report Q3", "Python API authentication design".
|
||||
Bad: "*", "**", "everything", single characters. Wildcard/empty queries yield poor results.
|
||||
- Prefer multiple focused searches over a single broad one with high top_k.
|
||||
- If the user requests a specific source type (e.g. "my notes", "Slack messages"), pass `connectors_to_search=[...]` using the enums below.
|
||||
- If `connectors_to_search` is omitted/empty, the system will search broadly.
|
||||
- Only connectors that are enabled/configured for this search space are available.{doc_types_info}
|
||||
|
|
@ -672,6 +981,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
|
|||
|
||||
# Capture for closure
|
||||
_available_connectors = available_connectors
|
||||
_available_document_types = available_document_types
|
||||
|
||||
async def _search_knowledge_base_impl(
|
||||
query: str,
|
||||
|
|
@ -701,6 +1011,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
|
|||
start_date=parsed_start,
|
||||
end_date=parsed_end,
|
||||
available_connectors=_available_connectors,
|
||||
available_document_types=_available_document_types,
|
||||
max_input_tokens=max_input_tokens,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,9 +27,24 @@ from app.db import SearchSourceConnector, SearchSourceConnectorType
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
|
||||
_MCP_CACHE_MAX_SIZE = 50
|
||||
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}
|
||||
|
||||
|
||||
def _evict_expired_mcp_cache() -> None:
|
||||
"""Remove expired entries from the MCP tools cache to prevent unbounded growth."""
|
||||
now = time.monotonic()
|
||||
expired = [
|
||||
k
|
||||
for k, (ts, _) in _mcp_tools_cache.items()
|
||||
if now - ts >= _MCP_CACHE_TTL_SECONDS
|
||||
]
|
||||
for k in expired:
|
||||
del _mcp_tools_cache[k]
|
||||
if expired:
|
||||
logger.debug("Evicted %d expired MCP cache entries", len(expired))
|
||||
|
||||
|
||||
def _create_dynamic_input_model_from_schema(
|
||||
tool_name: str,
|
||||
input_schema: dict[str, Any],
|
||||
|
|
@ -392,6 +407,8 @@ async def load_mcp_tools(
|
|||
List of LangChain StructuredTool instances
|
||||
|
||||
"""
|
||||
_evict_expired_mcp_cache()
|
||||
|
||||
now = time.monotonic()
|
||||
cached = _mcp_tools_cache.get(search_space_id)
|
||||
if cached is not None:
|
||||
|
|
@ -445,6 +462,11 @@ async def load_mcp_tools(
|
|||
)
|
||||
|
||||
_mcp_tools_cache[search_space_id] = (now, tools)
|
||||
|
||||
if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
|
||||
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
|
||||
del _mcp_tools_cache[oldest_key]
|
||||
|
||||
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
|
||||
return tools
|
||||
|
||||
|
|
|
|||
|
|
@ -145,10 +145,12 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
thread_id=deps["thread_id"],
|
||||
connector_service=deps.get("connector_service"),
|
||||
available_connectors=deps.get("available_connectors"),
|
||||
available_document_types=deps.get("available_document_types"),
|
||||
),
|
||||
requires=["search_space_id", "thread_id"],
|
||||
# connector_service and available_connectors are optional —
|
||||
# when missing, source_strategy="kb_search" degrades gracefully to "provided"
|
||||
# connector_service, available_connectors, and available_document_types
|
||||
# are optional — when missing, source_strategy="kb_search" degrades
|
||||
# gracefully to "provided"
|
||||
),
|
||||
# Link preview tool - fetches Open Graph metadata for URLs
|
||||
ToolDefinition(
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from langchain_core.callbacks import dispatch_custom_event
|
|||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.db import Report, async_session_maker
|
||||
from app.db import Report, shielded_async_session
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.llm_service import get_document_summary_llm
|
||||
|
||||
|
|
@ -559,6 +559,7 @@ def create_generate_report_tool(
|
|||
thread_id: int | None = None,
|
||||
connector_service: ConnectorService | None = None,
|
||||
available_connectors: list[str] | None = None,
|
||||
available_document_types: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Factory function to create the generate_report tool with injected dependencies.
|
||||
|
|
@ -716,7 +717,7 @@ def create_generate_report_tool(
|
|||
async def _save_failed_report(error_msg: str) -> int | None:
|
||||
"""Persist a failed report row using a short-lived session."""
|
||||
try:
|
||||
async with async_session_maker() as session:
|
||||
async with shielded_async_session() as session:
|
||||
failed_report = Report(
|
||||
title=topic,
|
||||
content=None,
|
||||
|
|
@ -750,7 +751,7 @@ def create_generate_report_tool(
|
|||
# ── Phase 1: READ (short-lived session) ──────────────────────
|
||||
# Fetch parent report and LLM config, then close the session
|
||||
# so no DB connection is held during the long LLM call.
|
||||
async with async_session_maker() as read_session:
|
||||
async with shielded_async_session() as read_session:
|
||||
if parent_report_id:
|
||||
parent_report = await read_session.get(Report, parent_report_id)
|
||||
if parent_report:
|
||||
|
|
@ -827,7 +828,7 @@ def create_generate_report_tool(
|
|||
|
||||
# Run all queries in parallel, each with its own session
|
||||
async def _run_single_query(q: str) -> str:
|
||||
async with async_session_maker() as kb_session:
|
||||
async with shielded_async_session() as kb_session:
|
||||
kb_connector_svc = ConnectorService(
|
||||
kb_session, search_space_id
|
||||
)
|
||||
|
|
@ -838,6 +839,7 @@ def create_generate_report_tool(
|
|||
connector_service=kb_connector_svc,
|
||||
top_k=10,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
)
|
||||
|
||||
kb_results = await asyncio.gather(
|
||||
|
|
@ -1026,7 +1028,7 @@ def create_generate_report_tool(
|
|||
|
||||
# ── Phase 3: WRITE (short-lived session) ─────────────────────
|
||||
# Save the report to the database, then close the session.
|
||||
async with async_session_maker() as write_session:
|
||||
async with shielded_async_session() as write_session:
|
||||
report = Report(
|
||||
title=topic,
|
||||
content=report_content,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
|
@ -15,6 +16,9 @@ from slowapi.errors import RateLimitExceeded
|
|||
from slowapi.middleware import SlowAPIMiddleware
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.requests import Request as StarletteRequest
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
|
||||
|
||||
from app.agents.new_chat.checkpointer import (
|
||||
|
|
@ -28,6 +32,7 @@ from app.routes.auth_routes import router as auth_router
|
|||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
|
||||
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot
|
||||
|
||||
rate_limit_logger = logging.getLogger("surfsense.rate_limit")
|
||||
|
||||
|
|
@ -99,22 +104,24 @@ def _check_rate_limit_memory(
|
|||
now = time.monotonic()
|
||||
|
||||
with _memory_lock:
|
||||
# Evict timestamps outside the current window
|
||||
_memory_rate_limits[key] = [
|
||||
t for t in _memory_rate_limits[key] if now - t < window_seconds
|
||||
]
|
||||
timestamps = [t for t in _memory_rate_limits[key] if now - t < window_seconds]
|
||||
|
||||
if len(_memory_rate_limits[key]) >= max_requests:
|
||||
if not timestamps:
|
||||
_memory_rate_limits.pop(key, None)
|
||||
else:
|
||||
_memory_rate_limits[key] = timestamps
|
||||
|
||||
if len(timestamps) >= max_requests:
|
||||
rate_limit_logger.warning(
|
||||
f"Rate limit exceeded (in-memory fallback) on {scope} for IP {client_ip} "
|
||||
f"({len(_memory_rate_limits[key])}/{max_requests} in {window_seconds}s)"
|
||||
f"({len(timestamps)}/{max_requests} in {window_seconds}s)"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="RATE_LIMIT_EXCEEDED",
|
||||
)
|
||||
|
||||
_memory_rate_limits[key].append(now)
|
||||
_memory_rate_limits[key] = [*timestamps, now]
|
||||
|
||||
|
||||
def _check_rate_limit(
|
||||
|
|
@ -206,18 +213,16 @@ def _enable_slow_callback_logging(threshold_sec: float = 0.5) -> None:
|
|||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Enable slow-callback detection (set PERF_DEBUG=1 env var to activate)
|
||||
# Tune GC: lower gen-2 threshold so long-lived garbage is collected
|
||||
# sooner (default 700/10/10 → 700/10/5). This reduces peak RSS
|
||||
# with minimal CPU overhead.
|
||||
gc.set_threshold(700, 10, 5)
|
||||
|
||||
_enable_slow_callback_logging(threshold_sec=0.5)
|
||||
# Not needed if you setup a migration system like Alembic
|
||||
await create_db_and_tables()
|
||||
# Setup LangGraph checkpointer tables for conversation persistence
|
||||
await setup_checkpointer_tables()
|
||||
# Initialize LLM Router for Auto mode load balancing
|
||||
initialize_llm_router()
|
||||
# Initialize Image Generation Router for Auto mode load balancing
|
||||
initialize_image_gen_router()
|
||||
# Seed Surfsense documentation (with timeout so a slow embedding API
|
||||
# doesn't block startup indefinitely and make the container unresponsive)
|
||||
try:
|
||||
await asyncio.wait_for(seed_surfsense_docs(), timeout=120)
|
||||
except TimeoutError:
|
||||
|
|
@ -225,8 +230,11 @@ async def lifespan(app: FastAPI):
|
|||
"Surfsense docs seeding timed out after 120s — skipping. "
|
||||
"Docs will be indexed on the next restart."
|
||||
)
|
||||
|
||||
log_system_snapshot("startup_complete")
|
||||
|
||||
yield
|
||||
# Cleanup: close checkpointer connection on shutdown
|
||||
|
||||
await close_checkpointer()
|
||||
|
||||
|
||||
|
|
@ -244,6 +252,63 @@ app = FastAPI(lifespan=lifespan)
|
|||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request-level performance middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logs wall-clock time, method, path, and status for every request so we can
|
||||
# spot slow endpoints in production logs.
|
||||
|
||||
_PERF_SLOW_REQUEST_THRESHOLD = float(
|
||||
__import__("os").environ.get("PERF_SLOW_REQUEST_MS", "2000")
|
||||
)
|
||||
|
||||
|
||||
class RequestPerfMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware that logs per-request wall-clock time.
|
||||
|
||||
- ALL requests are logged at DEBUG level.
|
||||
- Requests exceeding PERF_SLOW_REQUEST_MS (default 2000ms) are logged at
|
||||
WARNING level with a system snapshot so we can correlate slow responses
|
||||
with CPU/memory usage at that moment.
|
||||
"""
|
||||
|
||||
async def dispatch(
|
||||
self, request: StarletteRequest, call_next: RequestResponseEndpoint
|
||||
) -> StarletteResponse:
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
response = await call_next(request)
|
||||
elapsed_ms = (time.perf_counter() - t0) * 1000
|
||||
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
status = response.status_code
|
||||
|
||||
perf.debug(
|
||||
"[request] %s %s -> %d in %.1fms",
|
||||
method,
|
||||
path,
|
||||
status,
|
||||
elapsed_ms,
|
||||
)
|
||||
|
||||
if elapsed_ms > _PERF_SLOW_REQUEST_THRESHOLD:
|
||||
perf.warning(
|
||||
"[SLOW_REQUEST] %s %s -> %d in %.1fms (threshold=%.0fms)",
|
||||
method,
|
||||
path,
|
||||
status,
|
||||
elapsed_ms,
|
||||
_PERF_SLOW_REQUEST_THRESHOLD,
|
||||
)
|
||||
log_system_snapshot("slow_request")
|
||||
|
||||
return response
|
||||
|
||||
|
||||
app.add_middleware(RequestPerfMiddleware)
|
||||
|
||||
# Add SlowAPI middleware for automatic rate limiting
|
||||
# Uses Starlette BaseHTTPMiddleware (not the raw ASGI variant) to avoid
|
||||
# corrupting StreamingResponse — SlowAPIASGIMiddleware re-sends
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
|
||||
import anyio
|
||||
from fastapi import Depends
|
||||
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
|
@ -1856,10 +1858,37 @@ class RefreshToken(Base, TimestampMixin):
|
|||
return not self.is_expired and not self.is_revoked
|
||||
|
||||
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
engine = create_async_engine(
|
||||
DATABASE_URL,
|
||||
pool_size=30,
|
||||
max_overflow=150,
|
||||
pool_recycle=1800,
|
||||
pool_pre_ping=True,
|
||||
pool_timeout=30,
|
||||
)
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def shielded_async_session():
|
||||
"""Cancellation-safe async session context manager.
|
||||
|
||||
Starlette's BaseHTTPMiddleware cancels the task via an anyio cancel
|
||||
scope when a client disconnects. A plain ``async with async_session_maker()``
|
||||
has its ``__aexit__`` (which awaits ``session.close()``) cancelled by the
|
||||
scope, orphaning the underlying database connection.
|
||||
|
||||
This wrapper ensures ``session.close()`` always completes by running it
|
||||
inside ``anyio.CancelScope(shield=True)``.
|
||||
"""
|
||||
session = async_session_maker()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
with anyio.CancelScope(shield=True):
|
||||
await session.close()
|
||||
|
||||
|
||||
async def setup_indexes():
|
||||
async with engine.begin() as conn:
|
||||
# Create indexes
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import contextlib
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
|
@ -44,6 +45,7 @@ from app.indexing_pipeline.pipeline_logger import (
|
|||
log_retryable_llm_error,
|
||||
log_unexpected_error,
|
||||
)
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
|
||||
class IndexingPipelineService:
|
||||
|
|
@ -58,6 +60,9 @@ class IndexingPipelineService:
|
|||
"""
|
||||
Persist new documents and detect changes, returning only those that need indexing.
|
||||
"""
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
documents = []
|
||||
seen_hashes: set[str] = set()
|
||||
batch_ctx = PipelineLogContext(
|
||||
|
|
@ -140,11 +145,14 @@ class IndexingPipelineService:
|
|||
|
||||
try:
|
||||
await self.session.commit()
|
||||
perf.info(
|
||||
"[indexing] prepare_for_indexing in %.3fs input=%d output=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(connector_docs),
|
||||
len(documents),
|
||||
)
|
||||
return documents
|
||||
except IntegrityError:
|
||||
# A concurrent worker committed a document with the same content_hash
|
||||
# or unique_identifier_hash between our check and our INSERT.
|
||||
# The document already exists — roll back and let the next sync run handle it.
|
||||
log_race_condition(batch_ctx)
|
||||
await self.session.rollback()
|
||||
return []
|
||||
|
|
@ -165,26 +173,41 @@ class IndexingPipelineService:
|
|||
unique_id=connector_doc.unique_id,
|
||||
doc_id=document.id,
|
||||
)
|
||||
perf = get_perf_logger()
|
||||
t_index = time.perf_counter()
|
||||
try:
|
||||
log_index_started(ctx)
|
||||
document.status = DocumentStatus.processing()
|
||||
await self.session.commit()
|
||||
|
||||
t_step = time.perf_counter()
|
||||
if connector_doc.should_summarize and llm is not None:
|
||||
content = await summarize_document(
|
||||
connector_doc.source_markdown, llm, connector_doc.metadata
|
||||
)
|
||||
perf.info(
|
||||
"[indexing] summarize_document doc=%d in %.3fs",
|
||||
document.id,
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
elif connector_doc.should_summarize and connector_doc.fallback_summary:
|
||||
content = connector_doc.fallback_summary
|
||||
else:
|
||||
content = connector_doc.source_markdown
|
||||
|
||||
t_step = time.perf_counter()
|
||||
embedding = embed_text(content)
|
||||
perf.debug(
|
||||
"[indexing] embed_text (summary) doc=%d in %.3fs",
|
||||
document.id,
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
|
||||
await self.session.execute(
|
||||
delete(Chunk).where(Chunk.document_id == document.id)
|
||||
)
|
||||
|
||||
t_step = time.perf_counter()
|
||||
chunks = [
|
||||
Chunk(content=text, embedding=embed_text(text))
|
||||
for text in chunk_text(
|
||||
|
|
@ -192,6 +215,12 @@ class IndexingPipelineService:
|
|||
use_code_chunker=connector_doc.should_use_code_chunker,
|
||||
)
|
||||
]
|
||||
perf.info(
|
||||
"[indexing] chunk+embed doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_step,
|
||||
)
|
||||
|
||||
document.content = content
|
||||
document.embedding = embedding
|
||||
|
|
@ -199,6 +228,12 @@ class IndexingPipelineService:
|
|||
document.updated_at = datetime.now(UTC)
|
||||
document.status = DocumentStatus.ready()
|
||||
await self.session.commit()
|
||||
perf.info(
|
||||
"[indexing] index TOTAL doc=%d chunks=%d in %.3fs",
|
||||
document.id,
|
||||
len(chunks),
|
||||
time.perf_counter() - t_index,
|
||||
)
|
||||
log_index_success(ctx, chunk_count=len(chunks))
|
||||
|
||||
except RETRYABLE_LLM_ERRORS as e:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_MAX_FETCH_CHUNKS_PER_DOC = 30
|
||||
|
||||
|
||||
class ChucksHybridSearchRetriever:
|
||||
def __init__(self, db_session):
|
||||
|
|
@ -38,9 +43,17 @@ class ChucksHybridSearchRetriever:
|
|||
from app.config import config
|
||||
from app.db import Chunk, Document
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
t_embed = time.perf_counter()
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
perf.debug(
|
||||
"[chunk_search] vector_search embedding in %.3fs",
|
||||
time.perf_counter() - t_embed,
|
||||
)
|
||||
|
||||
# Build the query filtered by search space
|
||||
query = (
|
||||
|
|
@ -60,8 +73,16 @@ class ChucksHybridSearchRetriever:
|
|||
query = query.order_by(Chunk.embedding.op("<=>")(query_embedding)).limit(top_k)
|
||||
|
||||
# Execute the query
|
||||
t_db = time.perf_counter()
|
||||
result = await self.db_session.execute(query)
|
||||
chunks = result.scalars().all()
|
||||
perf.info(
|
||||
"[chunk_search] vector_search DB query in %.3fs results=%d (total %.3fs) space=%d",
|
||||
time.perf_counter() - t_db,
|
||||
len(chunks),
|
||||
time.perf_counter() - t0,
|
||||
search_space_id,
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
|
|
@ -91,6 +112,9 @@ class ChucksHybridSearchRetriever:
|
|||
|
||||
from app.db import Chunk, Document
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector("english", Chunk.content)
|
||||
tsquery = func.plainto_tsquery("english", query_text)
|
||||
|
|
@ -118,6 +142,12 @@ class ChucksHybridSearchRetriever:
|
|||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
chunks = result.scalars().all()
|
||||
perf.info(
|
||||
"[chunk_search] full_text_search in %.3fs results=%d space=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(chunks),
|
||||
search_space_id,
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
|
|
@ -129,6 +159,7 @@ class ChucksHybridSearchRetriever:
|
|||
document_type: str | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
query_embedding: list | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Hybrid search that returns **documents** (not individual chunks).
|
||||
|
|
@ -143,6 +174,7 @@ class ChucksHybridSearchRetriever:
|
|||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||
start_date: Optional start date for filtering documents by updated_at
|
||||
end_date: Optional end date for filtering documents by updated_at
|
||||
query_embedding: Pre-computed embedding vector. If None, will be computed here.
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing document data and relevance scores. Each dict contains:
|
||||
|
|
@ -157,9 +189,17 @@ class ChucksHybridSearchRetriever:
|
|||
from app.config import config
|
||||
from app.db import Chunk, Document, DocumentType
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
if query_embedding is None:
|
||||
embedding_model = config.embedding_model_instance
|
||||
t_embed = time.perf_counter()
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
perf.debug(
|
||||
"[chunk_search] hybrid_search embedding in %.3fs",
|
||||
time.perf_counter() - t_embed,
|
||||
)
|
||||
|
||||
# RRF constants
|
||||
k = 60
|
||||
|
|
@ -254,9 +294,17 @@ class ChucksHybridSearchRetriever:
|
|||
.limit(top_k)
|
||||
)
|
||||
|
||||
# Execute the query
|
||||
# Execute the RRF query
|
||||
t_rrf = time.perf_counter()
|
||||
result = await self.db_session.execute(final_query)
|
||||
chunks_with_scores = result.all()
|
||||
perf.info(
|
||||
"[chunk_search] hybrid_search RRF query in %.3fs results=%d space=%d type=%s",
|
||||
time.perf_counter() - t_rrf,
|
||||
len(chunks_with_scores),
|
||||
search_space_id,
|
||||
document_type,
|
||||
)
|
||||
|
||||
# If no results were found, return an empty list
|
||||
if not chunks_with_scores:
|
||||
|
|
@ -300,8 +348,9 @@ class ChucksHybridSearchRetriever:
|
|||
if not doc_ids:
|
||||
return []
|
||||
|
||||
# Fetch ALL chunks for selected documents in a single query so the final prompt can cite
|
||||
# any chunk from those documents.
|
||||
# Fetch chunks for selected documents. We cap per document to avoid
|
||||
# loading hundreds of chunks for a single large file while still
|
||||
# ensuring the chunks that matched the RRF query are always included.
|
||||
chunk_query = (
|
||||
select(Chunk)
|
||||
.options(joinedload(Chunk.document))
|
||||
|
|
@ -311,7 +360,20 @@ class ChucksHybridSearchRetriever:
|
|||
.order_by(Chunk.document_id, Chunk.id)
|
||||
)
|
||||
chunks_result = await self.db_session.execute(chunk_query)
|
||||
all_chunks = chunks_result.scalars().all()
|
||||
raw_chunks = chunks_result.scalars().all()
|
||||
|
||||
matched_chunk_ids: set[int] = {
|
||||
item["chunk_id"] for item in serialized_chunk_results
|
||||
}
|
||||
|
||||
doc_chunk_counts: dict[int, int] = {}
|
||||
all_chunks: list = []
|
||||
for chunk in raw_chunks:
|
||||
did = chunk.document_id
|
||||
count = doc_chunk_counts.get(did, 0)
|
||||
if chunk.id in matched_chunk_ids or count < _MAX_FETCH_CHUNKS_PER_DOC:
|
||||
all_chunks.append(chunk)
|
||||
doc_chunk_counts[did] = count + 1
|
||||
|
||||
# Assemble final doc-grouped results in the same order as doc_ids
|
||||
doc_map: dict[int, dict] = {
|
||||
|
|
@ -354,4 +416,11 @@ class ChucksHybridSearchRetriever:
|
|||
)
|
||||
final_docs.append(entry)
|
||||
|
||||
perf.info(
|
||||
"[chunk_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
|
||||
time.perf_counter() - t0,
|
||||
len(final_docs),
|
||||
search_space_id,
|
||||
document_type,
|
||||
)
|
||||
return final_docs
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_MAX_FETCH_CHUNKS_PER_DOC = 30
|
||||
|
||||
|
||||
class DocumentHybridSearchRetriever:
|
||||
def __init__(self, db_session):
|
||||
|
|
@ -38,6 +43,9 @@ class DocumentHybridSearchRetriever:
|
|||
from app.config import config
|
||||
from app.db import Document
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
|
@ -63,6 +71,12 @@ class DocumentHybridSearchRetriever:
|
|||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
documents = result.scalars().all()
|
||||
perf.info(
|
||||
"[doc_search] vector_search in %.3fs results=%d space=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(documents),
|
||||
search_space_id,
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
|
|
@ -92,6 +106,9 @@ class DocumentHybridSearchRetriever:
|
|||
|
||||
from app.db import Document
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector("english", Document.content)
|
||||
tsquery = func.plainto_tsquery("english", query_text)
|
||||
|
|
@ -118,6 +135,12 @@ class DocumentHybridSearchRetriever:
|
|||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
documents = result.scalars().all()
|
||||
perf.info(
|
||||
"[doc_search] full_text_search in %.3fs results=%d space=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(documents),
|
||||
search_space_id,
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
|
|
@ -129,6 +152,7 @@ class DocumentHybridSearchRetriever:
|
|||
document_type: str | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
query_embedding: list | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Hybrid search that returns **documents** (not individual chunks).
|
||||
|
|
@ -143,7 +167,7 @@ class DocumentHybridSearchRetriever:
|
|||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||
start_date: Optional start date for filtering documents by updated_at
|
||||
end_date: Optional end date for filtering documents by updated_at
|
||||
|
||||
query_embedding: Pre-computed embedding vector. If None, will be computed here.
|
||||
"""
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
|
@ -151,9 +175,12 @@ class DocumentHybridSearchRetriever:
|
|||
from app.config import config
|
||||
from app.db import Chunk, Document, DocumentType
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
if query_embedding is None:
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
# RRF constants
|
||||
k = 60
|
||||
|
|
@ -254,7 +281,8 @@ class DocumentHybridSearchRetriever:
|
|||
# Collect document IDs for chunk fetching
|
||||
doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores]
|
||||
|
||||
# Fetch ALL chunks for these documents in a single query
|
||||
# Fetch chunks for these documents, capped per document to avoid
|
||||
# loading hundreds of chunks for a single large file.
|
||||
chunks_query = (
|
||||
select(Chunk)
|
||||
.options(joinedload(Chunk.document))
|
||||
|
|
@ -262,7 +290,16 @@ class DocumentHybridSearchRetriever:
|
|||
.order_by(Chunk.document_id, Chunk.id)
|
||||
)
|
||||
chunks_result = await self.db_session.execute(chunks_query)
|
||||
chunks = chunks_result.scalars().all()
|
||||
raw_chunks = chunks_result.scalars().all()
|
||||
|
||||
doc_chunk_counts: dict[int, int] = {}
|
||||
chunks: list = []
|
||||
for chunk in raw_chunks:
|
||||
did = chunk.document_id
|
||||
count = doc_chunk_counts.get(did, 0)
|
||||
if count < _MAX_FETCH_CHUNKS_PER_DOC:
|
||||
chunks.append(chunk)
|
||||
doc_chunk_counts[did] = count + 1
|
||||
|
||||
# Assemble doc-grouped results
|
||||
doc_map: dict[int, dict] = {
|
||||
|
|
@ -303,4 +340,11 @@ class DocumentHybridSearchRetriever:
|
|||
)
|
||||
final_docs.append(entry)
|
||||
|
||||
perf.info(
|
||||
"[doc_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
|
||||
time.perf_counter() - t0,
|
||||
len(final_docs),
|
||||
search_space_id,
|
||||
document_type,
|
||||
)
|
||||
return final_docs
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.db import User, get_async_session
|
||||
from app.schemas.chat_comments import (
|
||||
CommentBatchRequest,
|
||||
CommentBatchResponse,
|
||||
CommentCreateRequest,
|
||||
CommentListResponse,
|
||||
CommentReplyResponse,
|
||||
|
|
@ -19,6 +21,7 @@ from app.services.chat_comments_service import (
|
|||
create_reply,
|
||||
delete_comment,
|
||||
get_comments_for_message,
|
||||
get_comments_for_messages_batch,
|
||||
get_user_mentions,
|
||||
update_comment,
|
||||
)
|
||||
|
|
@ -27,6 +30,16 @@ from app.users import current_active_user
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/messages/comments/batch", response_model=CommentBatchResponse)
|
||||
async def batch_list_comments(
|
||||
request: CommentBatchRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Batch-fetch comments for multiple messages in one request."""
|
||||
return await get_comments_for_messages_batch(session, request.message_ids, user)
|
||||
|
||||
|
||||
@router.get("/messages/{message_id}/comments", response_model=CommentListResponse)
|
||||
async def list_comments(
|
||||
message_id: int,
|
||||
|
|
|
|||
|
|
@ -133,6 +133,8 @@ async def create_documents_file_upload(
|
|||
|
||||
Requires DOCUMENTS_CREATE permission.
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
|
||||
from app.db import DocumentStatus
|
||||
|
|
@ -143,7 +145,6 @@ async def create_documents_file_upload(
|
|||
from app.utils.document_converters import generate_unique_identifier_hash
|
||||
|
||||
try:
|
||||
# Check permission
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
|
|
@ -179,69 +180,64 @@ async def create_documents_file_upload(
|
|||
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
|
||||
)
|
||||
|
||||
created_documents: list[Document] = []
|
||||
files_to_process: list[
|
||||
tuple[Document, str, str]
|
||||
] = [] # (document, temp_path, filename)
|
||||
skipped_duplicates = 0
|
||||
duplicate_document_ids: list[int] = []
|
||||
actual_total_size = 0
|
||||
# ===== Read all files concurrently to avoid blocking the event loop =====
|
||||
async def _read_and_save(file: UploadFile) -> tuple[str, str, int]:
|
||||
"""Read upload content and write to temp file off the event loop."""
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
filename = file.filename or "unknown"
|
||||
|
||||
# ===== PHASE 1: Create pending documents for all files =====
|
||||
# This makes ALL documents visible in the UI immediately with pending status
|
||||
for file in files:
|
||||
try:
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
# Save file to temp location
|
||||
with tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=os.path.splitext(file.filename or "")[1]
|
||||
) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
if file_size > MAX_FILE_SIZE_BYTES:
|
||||
os.unlink(temp_path)
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
|
||||
)
|
||||
|
||||
actual_total_size += file_size
|
||||
if actual_total_size > MAX_TOTAL_SIZE_BYTES:
|
||||
os.unlink(temp_path)
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
|
||||
)
|
||||
|
||||
with open(temp_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Generate unique identifier for deduplication check
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.FILE, file.filename or "unknown", search_space_id
|
||||
if file_size > MAX_FILE_SIZE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File '{filename}' ({file_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
|
||||
)
|
||||
|
||||
def _write_temp() -> str:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=os.path.splitext(filename)[1]
|
||||
) as tmp:
|
||||
tmp.write(content)
|
||||
return tmp.name
|
||||
|
||||
temp_path = await asyncio.to_thread(_write_temp)
|
||||
return temp_path, filename, file_size
|
||||
|
||||
saved_files = await asyncio.gather(*(_read_and_save(f) for f in files))
|
||||
|
||||
actual_total_size = sum(size for _, _, size in saved_files)
|
||||
if actual_total_size > MAX_TOTAL_SIZE_BYTES:
|
||||
for temp_path, _, _ in saved_files:
|
||||
os.unlink(temp_path)
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
|
||||
)
|
||||
|
||||
# ===== PHASE 1: Create pending documents for all files =====
|
||||
created_documents: list[Document] = []
|
||||
files_to_process: list[tuple[Document, str, str]] = []
|
||||
skipped_duplicates = 0
|
||||
duplicate_document_ids: list[int] = []
|
||||
|
||||
for temp_path, filename, file_size in saved_files:
|
||||
try:
|
||||
unique_identifier_hash = generate_unique_identifier_hash(
|
||||
DocumentType.FILE, filename, search_space_id
|
||||
)
|
||||
|
||||
# Check if document already exists (by unique identifier)
|
||||
existing = await check_document_by_unique_identifier(
|
||||
session, unique_identifier_hash
|
||||
)
|
||||
if existing:
|
||||
if DocumentStatus.is_state(existing.status, DocumentStatus.READY):
|
||||
# True duplicate — content already indexed, skip
|
||||
os.unlink(temp_path)
|
||||
skipped_duplicates += 1
|
||||
duplicate_document_ids.append(existing.id)
|
||||
continue
|
||||
|
||||
# Existing document is stuck (failed/pending/processing)
|
||||
# Reset it to pending and re-dispatch for processing
|
||||
existing.status = DocumentStatus.pending()
|
||||
existing.content = "Processing..."
|
||||
existing.document_metadata = {
|
||||
|
|
@ -251,50 +247,45 @@ async def create_documents_file_upload(
|
|||
}
|
||||
existing.updated_at = get_current_timestamp()
|
||||
created_documents.append(existing)
|
||||
files_to_process.append(
|
||||
(existing, temp_path, file.filename or "unknown")
|
||||
)
|
||||
files_to_process.append((existing, temp_path, filename))
|
||||
continue
|
||||
|
||||
# Create pending document (visible immediately in UI via ElectricSQL)
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=file.filename or "Uploaded File",
|
||||
title=filename if filename != "unknown" else "Uploaded File",
|
||||
document_type=DocumentType.FILE,
|
||||
document_metadata={
|
||||
"FILE_NAME": file.filename,
|
||||
"FILE_NAME": filename,
|
||||
"file_size": file_size,
|
||||
"upload_time": datetime.now().isoformat(),
|
||||
},
|
||||
content="Processing...", # Placeholder until processed
|
||||
content_hash=unique_identifier_hash, # Temporary, updated when ready
|
||||
content="Processing...",
|
||||
content_hash=unique_identifier_hash,
|
||||
unique_identifier_hash=unique_identifier_hash,
|
||||
embedding=None,
|
||||
status=DocumentStatus.pending(), # Shows "pending" in UI
|
||||
status=DocumentStatus.pending(),
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=str(user.id),
|
||||
)
|
||||
session.add(document)
|
||||
created_documents.append(document)
|
||||
files_to_process.append(
|
||||
(document, temp_path, file.filename or "unknown")
|
||||
)
|
||||
files_to_process.append((document, temp_path, filename))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
os.unlink(temp_path)
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Failed to process file {file.filename}: {e!s}",
|
||||
detail=f"Failed to process file {filename}: {e!s}",
|
||||
) from e
|
||||
|
||||
# Commit all pending documents - they appear in UI immediately via ElectricSQL
|
||||
if created_documents:
|
||||
await session.commit()
|
||||
# Refresh to get generated IDs
|
||||
for doc in created_documents:
|
||||
await session.refresh(doc)
|
||||
|
||||
# ===== PHASE 2: Dispatch tasks for each file =====
|
||||
# Each task will update document status: pending → processing → ready/failed
|
||||
for document, temp_path, filename in files_to_process:
|
||||
await dispatcher.dispatch_file_processing(
|
||||
document_id=document.id,
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ from app.db import (
|
|||
SearchSpace,
|
||||
User,
|
||||
get_async_session,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.schemas.new_chat import (
|
||||
NewChatMessageAppend,
|
||||
|
|
@ -1092,13 +1093,18 @@ async def handle_new_chat(
|
|||
# on searchspaces/documents for the entire duration of the stream.
|
||||
# expire_on_commit=False keeps loaded ORM attrs usable.
|
||||
await session.commit()
|
||||
# Close the dependency session now so its connection returns to
|
||||
# the pool before streaming begins. Without this, Starlette's
|
||||
# BaseHTTPMiddleware cancels the scope on client disconnect and
|
||||
# the dependency generator's __aexit__ never runs, orphaning the
|
||||
# connection (the "Exception terminating connection" errors).
|
||||
await session.close()
|
||||
|
||||
return StreamingResponse(
|
||||
stream_new_chat(
|
||||
user_query=request.user_query,
|
||||
search_space_id=request.search_space_id,
|
||||
chat_id=request.chat_id,
|
||||
session=session,
|
||||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
|
|
@ -1323,6 +1329,7 @@ async def regenerate_response(
|
|||
# on searchspaces/documents for the entire duration of the stream.
|
||||
# expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
|
||||
await session.commit()
|
||||
await session.close()
|
||||
|
||||
# Create a wrapper generator that deletes messages only AFTER streaming succeeds
|
||||
# This prevents data loss if streaming fails (network error, LLM error, etc.)
|
||||
|
|
@ -1333,7 +1340,6 @@ async def regenerate_response(
|
|||
user_query=user_query_to_use,
|
||||
search_space_id=request.search_space_id,
|
||||
chat_id=thread_id,
|
||||
session=session,
|
||||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
|
|
@ -1344,29 +1350,35 @@ async def regenerate_response(
|
|||
current_user_display_name=user.display_name or "A team member",
|
||||
):
|
||||
yield chunk
|
||||
# If we get here, streaming completed successfully
|
||||
streaming_completed = True
|
||||
finally:
|
||||
# Only delete old messages if streaming completed successfully
|
||||
# This ensures we don't lose data on streaming failures
|
||||
if streaming_completed and messages_to_delete:
|
||||
# Only delete old messages if streaming completed successfully.
|
||||
# Uses a fresh session since stream_new_chat manages its own.
|
||||
if streaming_completed and message_ids_to_delete:
|
||||
try:
|
||||
for msg in messages_to_delete:
|
||||
await session.delete(msg)
|
||||
await session.commit()
|
||||
async with shielded_async_session() as cleanup_session:
|
||||
for msg_id in message_ids_to_delete:
|
||||
_res = await cleanup_session.execute(
|
||||
select(NewChatMessage).filter(
|
||||
NewChatMessage.id == msg_id
|
||||
)
|
||||
)
|
||||
_msg = _res.scalars().first()
|
||||
if _msg:
|
||||
await cleanup_session.delete(_msg)
|
||||
await cleanup_session.commit()
|
||||
|
||||
# Delete any public snapshots that contain the modified messages
|
||||
from app.services.public_chat_service import (
|
||||
delete_affected_snapshots,
|
||||
)
|
||||
from app.services.public_chat_service import (
|
||||
delete_affected_snapshots,
|
||||
)
|
||||
|
||||
await delete_affected_snapshots(
|
||||
session, thread_id, message_ids_to_delete
|
||||
)
|
||||
await delete_affected_snapshots(
|
||||
cleanup_session, thread_id, message_ids_to_delete
|
||||
)
|
||||
except Exception as cleanup_error:
|
||||
# Log but don't fail - the new messages are already streamed
|
||||
print(
|
||||
f"[regenerate] Warning: Failed to delete old messages: {cleanup_error}"
|
||||
_logger.warning(
|
||||
"[regenerate] Failed to delete old messages: %s",
|
||||
cleanup_error,
|
||||
)
|
||||
|
||||
# Return streaming response with checkpoint_id for rewinding
|
||||
|
|
@ -1440,13 +1452,13 @@ async def resume_chat(
|
|||
# Release the read-transaction so we don't hold ACCESS SHARE locks
|
||||
# on searchspaces/documents for the entire duration of the stream.
|
||||
await session.commit()
|
||||
await session.close()
|
||||
|
||||
return StreamingResponse(
|
||||
stream_resume_chat(
|
||||
chat_id=thread_id,
|
||||
search_space_id=request.search_space_id,
|
||||
decisions=decisions,
|
||||
session=session,
|
||||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
thread_visibility=thread.visibility,
|
||||
|
|
|
|||
|
|
@ -87,6 +87,18 @@ class CommentListResponse(BaseModel):
|
|||
total_count: int
|
||||
|
||||
|
||||
class CommentBatchRequest(BaseModel):
|
||||
"""Request for batch-fetching comments for multiple messages."""
|
||||
|
||||
message_ids: list[int] = Field(..., min_length=1, max_length=200)
|
||||
|
||||
|
||||
class CommentBatchResponse(BaseModel):
|
||||
"""Batch response keyed by message_id."""
|
||||
|
||||
comments_by_message: dict[int, CommentListResponse]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Mention Schemas
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from app.db import (
|
|||
)
|
||||
from app.schemas.chat_comments import (
|
||||
AuthorResponse,
|
||||
CommentBatchResponse,
|
||||
CommentListResponse,
|
||||
CommentReplyResponse,
|
||||
CommentResponse,
|
||||
|
|
@ -264,6 +265,146 @@ async def get_comments_for_message(
|
|||
)
|
||||
|
||||
|
||||
async def get_comments_for_messages_batch(
|
||||
session: AsyncSession,
|
||||
message_ids: list[int],
|
||||
user: User,
|
||||
) -> CommentBatchResponse:
|
||||
"""
|
||||
Batch-fetch comments for multiple messages in a single DB round-trip.
|
||||
|
||||
Validates that all messages exist and belong to search spaces the user
|
||||
can read comments in, then loads all comments with eager-loaded authors
|
||||
and replies.
|
||||
"""
|
||||
if not message_ids:
|
||||
return CommentBatchResponse(comments_by_message={})
|
||||
|
||||
unique_ids = list(set(message_ids))
|
||||
|
||||
result = await session.execute(
|
||||
select(NewChatMessage)
|
||||
.options(selectinload(NewChatMessage.thread))
|
||||
.filter(NewChatMessage.id.in_(unique_ids))
|
||||
)
|
||||
messages = result.scalars().all()
|
||||
msg_map = {m.id: m for m in messages}
|
||||
|
||||
search_space_ids = {m.thread.search_space_id for m in messages}
|
||||
permissions_cache: dict[int, set] = {}
|
||||
for ss_id in search_space_ids:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
ss_id,
|
||||
Permission.COMMENTS_READ.value,
|
||||
"You don't have permission to read comments in this search space",
|
||||
)
|
||||
permissions_cache[ss_id] = await get_user_permissions(session, user.id, ss_id)
|
||||
|
||||
result = await session.execute(
|
||||
select(ChatComment)
|
||||
.options(
|
||||
selectinload(ChatComment.author),
|
||||
selectinload(ChatComment.replies).selectinload(ChatComment.author),
|
||||
)
|
||||
.filter(
|
||||
ChatComment.message_id.in_(unique_ids),
|
||||
ChatComment.parent_id.is_(None),
|
||||
)
|
||||
.order_by(ChatComment.created_at)
|
||||
)
|
||||
top_level_comments = result.scalars().all()
|
||||
|
||||
all_mentioned_uuids: set[UUID] = set()
|
||||
for comment in top_level_comments:
|
||||
all_mentioned_uuids.update(parse_mentions(comment.content))
|
||||
for reply in comment.replies:
|
||||
all_mentioned_uuids.update(parse_mentions(reply.content))
|
||||
|
||||
user_names = await get_user_names_for_mentions(session, all_mentioned_uuids)
|
||||
|
||||
comments_by_msg: dict[int, list[ChatComment]] = {mid: [] for mid in unique_ids}
|
||||
for comment in top_level_comments:
|
||||
comments_by_msg.setdefault(comment.message_id, []).append(comment)
|
||||
|
||||
comments_by_message: dict[int, CommentListResponse] = {}
|
||||
for mid in unique_ids:
|
||||
msg = msg_map.get(mid)
|
||||
if msg is None:
|
||||
comments_by_message[mid] = CommentListResponse(comments=[], total_count=0)
|
||||
continue
|
||||
|
||||
ss_id = msg.thread.search_space_id
|
||||
user_perms = permissions_cache.get(ss_id, set())
|
||||
can_delete_any = has_permission(user_perms, Permission.COMMENTS_DELETE.value)
|
||||
|
||||
comment_responses = []
|
||||
for comment in comments_by_msg.get(mid, []):
|
||||
author = None
|
||||
if comment.author:
|
||||
author = AuthorResponse(
|
||||
id=comment.author.id,
|
||||
display_name=comment.author.display_name,
|
||||
avatar_url=comment.author.avatar_url,
|
||||
email=comment.author.email,
|
||||
)
|
||||
|
||||
replies = []
|
||||
for reply in sorted(comment.replies, key=lambda r: r.created_at):
|
||||
reply_author = None
|
||||
if reply.author:
|
||||
reply_author = AuthorResponse(
|
||||
id=reply.author.id,
|
||||
display_name=reply.author.display_name,
|
||||
avatar_url=reply.author.avatar_url,
|
||||
email=reply.author.email,
|
||||
)
|
||||
is_reply_author = (
|
||||
reply.author_id == user.id if reply.author_id else False
|
||||
)
|
||||
replies.append(
|
||||
CommentReplyResponse(
|
||||
id=reply.id,
|
||||
content=reply.content,
|
||||
content_rendered=render_mentions(reply.content, user_names),
|
||||
author=reply_author,
|
||||
created_at=reply.created_at,
|
||||
updated_at=reply.updated_at,
|
||||
is_edited=reply.updated_at > reply.created_at,
|
||||
can_edit=is_reply_author,
|
||||
can_delete=is_reply_author or can_delete_any,
|
||||
)
|
||||
)
|
||||
|
||||
is_comment_author = (
|
||||
comment.author_id == user.id if comment.author_id else False
|
||||
)
|
||||
comment_responses.append(
|
||||
CommentResponse(
|
||||
id=comment.id,
|
||||
message_id=comment.message_id,
|
||||
content=comment.content,
|
||||
content_rendered=render_mentions(comment.content, user_names),
|
||||
author=author,
|
||||
created_at=comment.created_at,
|
||||
updated_at=comment.updated_at,
|
||||
is_edited=comment.updated_at > comment.created_at,
|
||||
can_edit=is_comment_author,
|
||||
can_delete=is_comment_author or can_delete_any,
|
||||
reply_count=len(replies),
|
||||
replies=replies,
|
||||
)
|
||||
)
|
||||
|
||||
comments_by_message[mid] = CommentListResponse(
|
||||
comments=comment_responses,
|
||||
total_count=len(comment_responses),
|
||||
)
|
||||
|
||||
return CommentBatchResponse(comments_by_message=comments_by_message)
|
||||
|
||||
|
||||
async def create_comment(
|
||||
session: AsyncSession,
|
||||
message_id: int,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
|
@ -15,9 +16,11 @@ from app.db import (
|
|||
Document,
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
async_session_maker,
|
||||
)
|
||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
from app.retriever.documents_hybrid_search import DocumentHybridSearchRetriever
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
|
||||
class ConnectorService:
|
||||
|
|
@ -221,6 +224,7 @@ class ConnectorService:
|
|||
top_k: int = 20,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
query_embedding: list[float] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Perform combined search using both chunk-based and document-based hybrid search,
|
||||
|
|
@ -246,34 +250,60 @@ class ConnectorService:
|
|||
Returns:
|
||||
List of combined and deduplicated document results
|
||||
"""
|
||||
from app.config import config
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# RRF constant
|
||||
k = 60
|
||||
|
||||
# Get more results from each retriever for better fusion
|
||||
retriever_top_k = top_k * 2
|
||||
|
||||
# IMPORTANT:
|
||||
# These retrievers share the same AsyncSession. AsyncSession does not permit
|
||||
# concurrent awaits that require DB IO on the same session/connection.
|
||||
# Running these in parallel can raise:
|
||||
# "This session is provisioning a new connection; concurrent operations are not permitted"
|
||||
#
|
||||
# So we run them sequentially.
|
||||
chunk_results = await self.chunk_retriever.hybrid_search(
|
||||
query_text=query_text,
|
||||
top_k=retriever_top_k,
|
||||
search_space_id=search_space_id,
|
||||
document_type=document_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
# Reuse caller-provided embedding or compute once for both retrievers.
|
||||
if query_embedding is None:
|
||||
t_embed = time.perf_counter()
|
||||
query_embedding = config.embedding_model_instance.embed(query_text)
|
||||
perf.info(
|
||||
"[connector_svc] _combined_rrf embedding in %.3fs type=%s",
|
||||
time.perf_counter() - t_embed,
|
||||
document_type,
|
||||
)
|
||||
|
||||
search_kwargs = {
|
||||
"query_text": query_text,
|
||||
"top_k": retriever_top_k,
|
||||
"search_space_id": search_space_id,
|
||||
"document_type": document_type,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"query_embedding": query_embedding,
|
||||
}
|
||||
|
||||
# Run chunk and document retrievers in parallel using separate DB sessions
|
||||
# so they don't contend on a shared AsyncSession connection.
|
||||
async def _run_chunk_search() -> list[dict[str, Any]]:
|
||||
async with async_session_maker() as session:
|
||||
retriever = ChucksHybridSearchRetriever(session)
|
||||
return await retriever.hybrid_search(**search_kwargs)
|
||||
|
||||
async def _run_doc_search() -> list[dict[str, Any]]:
|
||||
async with async_session_maker() as session:
|
||||
retriever = DocumentHybridSearchRetriever(session)
|
||||
return await retriever.hybrid_search(**search_kwargs)
|
||||
|
||||
t_parallel = time.perf_counter()
|
||||
chunk_results, doc_results = await asyncio.gather(
|
||||
_run_chunk_search(), _run_doc_search()
|
||||
)
|
||||
doc_results = await self.document_retriever.hybrid_search(
|
||||
query_text=query_text,
|
||||
top_k=retriever_top_k,
|
||||
search_space_id=search_space_id,
|
||||
document_type=document_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
perf.info(
|
||||
"[connector_svc] _combined_rrf parallel retrievers in %.3fs "
|
||||
"chunk_results=%d doc_results=%d type=%s",
|
||||
time.perf_counter() - t_parallel,
|
||||
len(chunk_results),
|
||||
len(doc_results),
|
||||
document_type,
|
||||
)
|
||||
|
||||
# Helper to extract document_id from our doc-grouped result
|
||||
|
|
@ -335,6 +365,13 @@ class ConnectorService:
|
|||
result["chunks"] = doc_data[did]["chunks"]
|
||||
combined_results.append(result)
|
||||
|
||||
perf.info(
|
||||
"[connector_svc] _combined_rrf_search TOTAL in %.3fs results=%d type=%s space=%d",
|
||||
time.perf_counter() - t0,
|
||||
len(combined_results),
|
||||
document_type,
|
||||
search_space_id,
|
||||
)
|
||||
return combined_results
|
||||
|
||||
def _get_doc_url(self, metadata: dict[str, Any]) -> str:
|
||||
|
|
|
|||
|
|
@ -13,8 +13,10 @@ synchronous ChatLiteLLM-like interface and async methods.
|
|||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.exceptions import ContextOverflowError
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
|
@ -26,6 +28,11 @@ from litellm.exceptions import (
|
|||
ContextWindowExceededError,
|
||||
)
|
||||
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
litellm.json_logs = False
|
||||
litellm.store_audit_logs = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CONTEXT_OVERFLOW_PATTERNS = re.compile(
|
||||
|
|
@ -152,26 +159,95 @@ class LLMRouterService:
|
|||
# Merge with provided settings
|
||||
final_settings = {**default_settings, **instance._router_settings}
|
||||
|
||||
# Build a "auto-large" fallback group with deployments whose context
|
||||
# window exceeds the smallest deployment. This lets the router
|
||||
# automatically fall back to a bigger-context model when gpt-4o (128K)
|
||||
# hits ContextWindowExceededError.
|
||||
full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list)
|
||||
|
||||
try:
|
||||
instance._router = Router(
|
||||
model_list=model_list,
|
||||
routing_strategy=final_settings.get(
|
||||
router_kwargs: dict[str, Any] = {
|
||||
"model_list": full_model_list,
|
||||
"routing_strategy": final_settings.get(
|
||||
"routing_strategy", "usage-based-routing"
|
||||
),
|
||||
num_retries=final_settings.get("num_retries", 3),
|
||||
allowed_fails=final_settings.get("allowed_fails", 3),
|
||||
cooldown_time=final_settings.get("cooldown_time", 60),
|
||||
set_verbose=False, # Disable verbose logging in production
|
||||
)
|
||||
"num_retries": final_settings.get("num_retries", 3),
|
||||
"allowed_fails": final_settings.get("allowed_fails", 3),
|
||||
"cooldown_time": final_settings.get("cooldown_time", 60),
|
||||
"set_verbose": False,
|
||||
}
|
||||
if ctx_fallbacks:
|
||||
router_kwargs["context_window_fallbacks"] = ctx_fallbacks
|
||||
|
||||
instance._router = Router(**router_kwargs)
|
||||
instance._initialized = True
|
||||
logger.info(
|
||||
f"LLM Router initialized with {len(model_list)} deployments, "
|
||||
f"strategy: {final_settings.get('routing_strategy')}"
|
||||
"LLM Router initialized with %d deployments, "
|
||||
"strategy: %s, context_window_fallbacks: %s",
|
||||
len(model_list),
|
||||
final_settings.get("routing_strategy"),
|
||||
ctx_fallbacks or "none",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize LLM Router: {e}")
|
||||
instance._router = None
|
||||
|
||||
@classmethod
|
||||
def _build_context_fallback_groups(
|
||||
cls, model_list: list[dict]
|
||||
) -> tuple[list[dict], list[dict[str, list[str]]] | None]:
|
||||
"""Create an ``auto-large`` model group for context-window fallbacks.
|
||||
|
||||
Uses ``litellm.get_model_info`` to discover the context window of each
|
||||
deployment. Deployments whose ``max_input_tokens`` exceeds the smallest
|
||||
window are duplicated into an ``auto-large`` group. The returned
|
||||
fallback config tells the Router: on ``ContextWindowExceededError`` for
|
||||
``auto``, retry with ``auto-large``.
|
||||
|
||||
Returns:
|
||||
(full_model_list, context_window_fallbacks) — ``full_model_list``
|
||||
contains the original entries plus any ``auto-large`` duplicates.
|
||||
``context_window_fallbacks`` is ``None`` when every deployment has
|
||||
the same context size (no useful fallback).
|
||||
"""
|
||||
from litellm import get_model_info
|
||||
|
||||
ctx_map: dict[str, int] = {}
|
||||
for dep in model_list:
|
||||
params = dep.get("litellm_params", {})
|
||||
base_model = params.get("base_model") or params.get("model", "")
|
||||
try:
|
||||
info = get_model_info(base_model)
|
||||
ctx = info.get("max_input_tokens")
|
||||
if isinstance(ctx, int) and ctx > 0:
|
||||
ctx_map[base_model] = ctx
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not ctx_map:
|
||||
return model_list, None
|
||||
|
||||
min_ctx = min(ctx_map.values())
|
||||
|
||||
large_deployments: list[dict] = []
|
||||
for dep in model_list:
|
||||
params = dep.get("litellm_params", {})
|
||||
base_model = params.get("base_model") or params.get("model", "")
|
||||
if ctx_map.get(base_model, 0) > min_ctx:
|
||||
dup = {**dep, "model_name": "auto-large"}
|
||||
large_deployments.append(dup)
|
||||
|
||||
if not large_deployments:
|
||||
return model_list, None
|
||||
|
||||
logger.info(
|
||||
"Context-window fallback: %d large-context deployments "
|
||||
"(min_ctx=%d) added to 'auto-large' group",
|
||||
len(large_deployments),
|
||||
min_ctx,
|
||||
)
|
||||
return model_list + large_deployments, [{"auto": ["auto-large"]}]
|
||||
|
||||
@classmethod
|
||||
def _config_to_deployment(cls, config: dict) -> dict | None:
|
||||
"""
|
||||
|
|
@ -247,6 +323,48 @@ class LLMRouterService:
|
|||
return len(instance._model_list)
|
||||
|
||||
|
||||
_cached_context_profile: dict | None = None
|
||||
_cached_context_profile_computed: bool = False
|
||||
|
||||
# Cached singleton instances keyed by (streaming,) to avoid re-creating on every call
|
||||
_router_instance_cache: dict[bool, "ChatLiteLLMRouter"] = {}
|
||||
|
||||
|
||||
def _get_cached_context_profile(router: Router) -> dict | None:
|
||||
"""Compute and cache the min context profile across all router deployments.
|
||||
|
||||
Called once on first ChatLiteLLMRouter creation; subsequent calls return
|
||||
the cached value. This avoids calling litellm.get_model_info() for every
|
||||
deployment on every request.
|
||||
"""
|
||||
global _cached_context_profile, _cached_context_profile_computed
|
||||
if _cached_context_profile_computed:
|
||||
return _cached_context_profile
|
||||
|
||||
from litellm import get_model_info
|
||||
|
||||
min_ctx: int | None = None
|
||||
for deployment in router.model_list:
|
||||
params = deployment.get("litellm_params", {})
|
||||
base_model = params.get("base_model") or params.get("model", "")
|
||||
try:
|
||||
info = get_model_info(base_model)
|
||||
ctx = info.get("max_input_tokens")
|
||||
if isinstance(ctx, int) and ctx > 0 and (min_ctx is None or ctx < min_ctx):
|
||||
min_ctx = ctx
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if min_ctx is not None:
|
||||
logger.info("ChatLiteLLMRouter profile: max_input_tokens=%d", min_ctx)
|
||||
_cached_context_profile = {"max_input_tokens": min_ctx}
|
||||
else:
|
||||
_cached_context_profile = None
|
||||
|
||||
_cached_context_profile_computed = True
|
||||
return _cached_context_profile
|
||||
|
||||
|
||||
class ChatLiteLLMRouter(BaseChatModel):
|
||||
"""
|
||||
A LangChain-compatible chat model that uses LiteLLM Router for load balancing.
|
||||
|
|
@ -257,6 +375,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context
|
||||
window across all router deployments so that deepagents
|
||||
SummarizationMiddleware can use fraction-based triggers.
|
||||
|
||||
**Singleton-ish**: Use ``get_auto_mode_llm()`` or call ``ChatLiteLLMRouter()``
|
||||
directly — instances without bound tools are cached per streaming flag to
|
||||
avoid per-request re-initialization overhead and memory growth.
|
||||
"""
|
||||
|
||||
# Use model_config for Pydantic v2 compatibility
|
||||
|
|
@ -278,14 +400,6 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
tool_choice: str | dict | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the ChatLiteLLMRouter.
|
||||
|
||||
Args:
|
||||
router: LiteLLM Router instance. If None, uses the global singleton.
|
||||
bound_tools: Pre-bound tools for tool calling
|
||||
tool_choice: Tool choice configuration
|
||||
"""
|
||||
try:
|
||||
super().__init__(**kwargs)
|
||||
resolved_router = router or LLMRouterService.get_router()
|
||||
|
|
@ -297,51 +411,20 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
"LLM Router not initialized. Call LLMRouterService.initialize() first."
|
||||
)
|
||||
|
||||
# Set profile so deepagents SummarizationMiddleware gets fraction-based triggers
|
||||
computed_profile = self._compute_min_context_profile()
|
||||
computed_profile = _get_cached_context_profile(self._router)
|
||||
if computed_profile is not None:
|
||||
object.__setattr__(self, "profile", computed_profile)
|
||||
|
||||
logger.info(
|
||||
f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models"
|
||||
logger.debug(
|
||||
"ChatLiteLLMRouter ready (models=%d, streaming=%s, has_tools=%s)",
|
||||
LLMRouterService.get_model_count(),
|
||||
self.streaming,
|
||||
bound_tools is not None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
|
||||
raise
|
||||
|
||||
def _compute_min_context_profile(self) -> dict | None:
|
||||
"""Derive a profile dict with max_input_tokens from router deployments.
|
||||
|
||||
Uses litellm.get_model_info to look up each deployment's context window
|
||||
and picks the *minimum* so that summarization triggers before ANY model
|
||||
in the pool overflows.
|
||||
"""
|
||||
from litellm import get_model_info
|
||||
|
||||
if not self._router:
|
||||
return None
|
||||
|
||||
min_ctx: int | None = None
|
||||
for deployment in self._router.model_list:
|
||||
params = deployment.get("litellm_params", {})
|
||||
base_model = params.get("base_model") or params.get("model", "")
|
||||
try:
|
||||
info = get_model_info(base_model)
|
||||
ctx = info.get("max_input_tokens")
|
||||
if (
|
||||
isinstance(ctx, int)
|
||||
and ctx > 0
|
||||
and (min_ctx is None or ctx < min_ctx)
|
||||
):
|
||||
min_ctx = ctx
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if min_ctx is not None:
|
||||
logger.info(f"ChatLiteLLMRouter profile: max_input_tokens={min_ctx}")
|
||||
return {"max_input_tokens": min_ctx}
|
||||
return None
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "litellm-router"
|
||||
|
|
@ -410,6 +493,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
if not self._router:
|
||||
raise ValueError("Router not initialized")
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
msg_count = len(messages)
|
||||
|
||||
# Convert LangChain messages to OpenAI format
|
||||
formatted_messages = self._convert_messages(messages)
|
||||
|
||||
|
|
@ -428,12 +515,30 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
**call_kwargs,
|
||||
)
|
||||
except ContextWindowExceededError as e:
|
||||
perf.warning(
|
||||
"[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
except LiteLLMBadRequestError as e:
|
||||
if _is_context_overflow_error(e):
|
||||
perf.warning(
|
||||
"[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
raise
|
||||
|
||||
elapsed = time.perf_counter() - t0
|
||||
perf.info(
|
||||
"[llm_router] _generate completed msgs=%d tools=%d in %.3fs",
|
||||
msg_count,
|
||||
len(self._bound_tools) if self._bound_tools else 0,
|
||||
elapsed,
|
||||
)
|
||||
|
||||
# Convert response to ChatResult with potential tool calls
|
||||
message = self._convert_response_to_message(response.choices[0].message)
|
||||
generation = ChatGeneration(message=message)
|
||||
|
|
@ -453,6 +558,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
if not self._router:
|
||||
raise ValueError("Router not initialized")
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
msg_count = len(messages)
|
||||
|
||||
# Convert LangChain messages to OpenAI format
|
||||
formatted_messages = self._convert_messages(messages)
|
||||
|
||||
|
|
@ -471,12 +580,30 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
**call_kwargs,
|
||||
)
|
||||
except ContextWindowExceededError as e:
|
||||
perf.warning(
|
||||
"[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
except LiteLLMBadRequestError as e:
|
||||
if _is_context_overflow_error(e):
|
||||
perf.warning(
|
||||
"[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
raise
|
||||
|
||||
elapsed = time.perf_counter() - t0
|
||||
perf.info(
|
||||
"[llm_router] _agenerate completed msgs=%d tools=%d in %.3fs",
|
||||
msg_count,
|
||||
len(self._bound_tools) if self._bound_tools else 0,
|
||||
elapsed,
|
||||
)
|
||||
|
||||
# Convert response to ChatResult with potential tool calls
|
||||
message = self._convert_response_to_message(response.choices[0].message)
|
||||
generation = ChatGeneration(message=message)
|
||||
|
|
@ -541,6 +668,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
if not self._router:
|
||||
raise ValueError("Router not initialized")
|
||||
|
||||
perf = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
msg_count = len(messages)
|
||||
|
||||
formatted_messages = self._convert_messages(messages)
|
||||
|
||||
# Add tools if bound
|
||||
|
|
@ -559,20 +690,52 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
**call_kwargs,
|
||||
)
|
||||
except ContextWindowExceededError as e:
|
||||
perf.warning(
|
||||
"[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
except LiteLLMBadRequestError as e:
|
||||
if _is_context_overflow_error(e):
|
||||
perf.warning(
|
||||
"[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
|
||||
msg_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
raise ContextOverflowError(str(e)) from e
|
||||
raise
|
||||
|
||||
# Yield chunks asynchronously
|
||||
t_first_chunk = time.perf_counter()
|
||||
perf.info(
|
||||
"[llm_router] _astream connection established msgs=%d in %.3fs",
|
||||
msg_count,
|
||||
t_first_chunk - t0,
|
||||
)
|
||||
|
||||
chunk_count = 0
|
||||
first_chunk_logged = False
|
||||
async for chunk in response:
|
||||
if hasattr(chunk, "choices") and chunk.choices:
|
||||
delta = chunk.choices[0].delta
|
||||
chunk_msg = self._convert_delta_to_chunk(delta)
|
||||
if chunk_msg:
|
||||
chunk_count += 1
|
||||
if not first_chunk_logged:
|
||||
perf.info(
|
||||
"[llm_router] _astream first chunk in %.3fs (total %.3fs from start)",
|
||||
time.perf_counter() - t_first_chunk,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
first_chunk_logged = True
|
||||
yield ChatGenerationChunk(message=chunk_msg)
|
||||
|
||||
perf.info(
|
||||
"[llm_router] _astream completed chunks=%d total=%.3fs",
|
||||
chunk_count,
|
||||
time.perf_counter() - t0,
|
||||
)
|
||||
|
||||
def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]:
|
||||
"""Convert LangChain messages to OpenAI format."""
|
||||
from langchain_core.messages import (
|
||||
|
|
@ -687,19 +850,28 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
return None
|
||||
|
||||
|
||||
def get_auto_mode_llm() -> ChatLiteLLMRouter | None:
|
||||
"""
|
||||
Get a ChatLiteLLMRouter instance for auto mode.
|
||||
def get_auto_mode_llm(
|
||||
*,
|
||||
streaming: bool = True,
|
||||
) -> ChatLiteLLMRouter | None:
|
||||
"""Return a cached ChatLiteLLMRouter for auto mode.
|
||||
|
||||
Returns:
|
||||
ChatLiteLLMRouter instance or None if router not initialized
|
||||
Base (no tools) instances are cached per ``streaming`` flag so we
|
||||
avoid re-constructing them on every request. ``bind_tools()`` still
|
||||
returns a fresh instance because bound tools differ per agent.
|
||||
"""
|
||||
if not LLMRouterService.is_initialized():
|
||||
logger.warning("LLM Router not initialized for auto mode")
|
||||
return None
|
||||
|
||||
cached = _router_instance_cache.get(streaming)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
try:
|
||||
return ChatLiteLLMRouter()
|
||||
instance = ChatLiteLLMRouter(streaming=streaming)
|
||||
_router_instance_cache[streaming] = instance
|
||||
return instance
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -12,12 +12,20 @@ from app.services.llm_router_service import (
|
|||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
|
||||
# Configure litellm to automatically drop unsupported parameters
|
||||
litellm.drop_params = True
|
||||
|
||||
# Memory controls: prevent unbounded internal accumulation
|
||||
litellm.telemetry = False
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm.failure_callback = []
|
||||
litellm.input_callback = []
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -221,7 +229,7 @@ async def get_search_space_llm_instance(
|
|||
logger.debug(
|
||||
f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}"
|
||||
)
|
||||
return ChatLiteLLMRouter(disable_streaming=disable_streaming)
|
||||
return get_auto_mode_llm(streaming=not disable_streaming)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1 +1,28 @@
|
|||
"""Celery tasks package."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.config import config
|
||||
|
||||
_celery_engine = None
|
||||
_celery_session_maker = None
|
||||
|
||||
|
||||
def get_celery_session_maker() -> async_sessionmaker:
|
||||
"""Return a shared async session maker for Celery tasks.
|
||||
|
||||
A single NullPool engine is created per worker process and reused
|
||||
across all task invocations to avoid leaking engine objects.
|
||||
"""
|
||||
global _celery_engine, _celery_session_maker
|
||||
if _celery_session_maker is None:
|
||||
_celery_engine = create_async_engine(
|
||||
config.DATABASE_URL,
|
||||
poolclass=NullPool,
|
||||
echo=False,
|
||||
)
|
||||
_celery_session_maker = async_sessionmaker(
|
||||
_celery_engine, expire_on_commit=False
|
||||
)
|
||||
return _celery_session_maker
|
||||
|
|
|
|||
|
|
@ -3,11 +3,8 @@
|
|||
import logging
|
||||
import traceback
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -42,20 +39,6 @@ def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> N
|
|||
)
|
||||
|
||||
|
||||
def get_celery_session_maker():
|
||||
"""
|
||||
Create a new async session maker for Celery tasks.
|
||||
This is necessary because Celery tasks run in a new event loop,
|
||||
and the default session maker is bound to the main app's event loop.
|
||||
"""
|
||||
engine = create_async_engine(
|
||||
config.DATABASE_URL,
|
||||
poolclass=NullPool, # Don't use connection pooling for Celery tasks
|
||||
echo=False,
|
||||
)
|
||||
return async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
@celery_app.task(name="index_slack_messages", bind=True)
|
||||
def index_slack_messages_task(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -4,30 +4,18 @@ import logging
|
|||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import Document
|
||||
from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_celery_session_maker():
|
||||
"""Create async session maker for Celery tasks."""
|
||||
engine = create_async_engine(
|
||||
config.DATABASE_URL,
|
||||
poolclass=NullPool,
|
||||
echo=False,
|
||||
)
|
||||
return async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
@celery_app.task(name="reindex_document", bind=True)
|
||||
def reindex_document_task(self, document_id: int, user_id: str):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -5,13 +5,11 @@ import logging
|
|||
import os
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.services.notification_service import NotificationService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.document_processors import (
|
||||
add_extension_received_document,
|
||||
add_youtube_video_document,
|
||||
|
|
@ -91,20 +89,6 @@ async def _run_heartbeat_loop(notification_id: int):
|
|||
pass # Normal cancellation when task completes
|
||||
|
||||
|
||||
def get_celery_session_maker():
|
||||
"""
|
||||
Create a new async session maker for Celery tasks.
|
||||
This is necessary because Celery tasks run in a new event loop,
|
||||
and the default session maker is bound to the main app's event loop.
|
||||
"""
|
||||
engine = create_async_engine(
|
||||
config.DATABASE_URL,
|
||||
poolclass=NullPool, # Don't use connection pooling for Celery tasks
|
||||
echo=False,
|
||||
)
|
||||
return async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
@celery_app.task(name="process_extension_document", bind=True)
|
||||
def process_extension_document_task(
|
||||
self, individual_document_dict, search_space_id: int, user_id: str
|
||||
|
|
|
|||
|
|
@ -5,14 +5,13 @@ import logging
|
|||
import sys
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||
from app.agents.podcaster.state import State as PodcasterState
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import Podcast, PodcastStatus
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -25,20 +24,6 @@ if sys.platform.startswith("win"):
|
|||
)
|
||||
|
||||
|
||||
def get_celery_session_maker():
|
||||
"""
|
||||
Create a new async session maker for Celery tasks.
|
||||
This is necessary because Celery tasks run in a new event loop,
|
||||
and the default session maker is bound to the main app's event loop.
|
||||
"""
|
||||
engine = create_async_engine(
|
||||
config.DATABASE_URL,
|
||||
poolclass=NullPool, # Don't use connection pooling for Celery tasks
|
||||
echo=False,
|
||||
)
|
||||
return async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Content-based podcast generation (for new-chat)
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -3,28 +3,16 @@
|
|||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.utils.indexing_locks import is_connector_indexing_locked
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_celery_session_maker():
|
||||
"""Create async session maker for Celery tasks."""
|
||||
engine = create_async_engine(
|
||||
config.DATABASE_URL,
|
||||
poolclass=NullPool,
|
||||
echo=False,
|
||||
)
|
||||
return async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
@celery_app.task(name="check_periodic_schedules")
|
||||
def check_periodic_schedules_task():
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -29,20 +29,17 @@ from datetime import UTC, datetime
|
|||
|
||||
import redis
|
||||
from sqlalchemy import and_, or_, text
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import Document, DocumentStatus, Notification
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis client for checking heartbeats
|
||||
_redis_client: redis.Redis | None = None
|
||||
|
||||
# Error messages shown to users when tasks are interrupted
|
||||
STALE_SYNC_ERROR_MESSAGE = "Sync was interrupted unexpectedly. Please retry."
|
||||
STALE_PROCESSING_ERROR_MESSAGE = "Syncing was interrupted unexpectedly. Please retry."
|
||||
|
||||
|
|
@ -60,16 +57,6 @@ def _get_heartbeat_key(notification_id: int) -> str:
|
|||
return f"indexing:heartbeat:{notification_id}"
|
||||
|
||||
|
||||
def get_celery_session_maker():
|
||||
"""Create async session maker for Celery tasks."""
|
||||
engine = create_async_engine(
|
||||
config.DATABASE_URL,
|
||||
poolclass=NullPool,
|
||||
echo=False,
|
||||
)
|
||||
return async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
@celery_app.task(name="cleanup_stale_indexing_notifications")
|
||||
def cleanup_stale_indexing_notifications_task():
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ Supports loading LLM configurations from:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import gc
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
|
@ -19,9 +21,9 @@ from dataclasses import dataclass, field
|
|||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import anyio
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
|
@ -47,6 +49,7 @@ from app.db import (
|
|||
SearchSourceConnectorType,
|
||||
SurfsenseDocsDocument,
|
||||
async_session_maker,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.services.chat_session_state_service import (
|
||||
|
|
@ -56,14 +59,9 @@ from app.services.chat_session_state_service import (
|
|||
from app.services.connector_service import ConnectorService
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
|
||||
|
||||
_perf_log = logging.getLogger("surfsense.perf")
|
||||
_perf_log.setLevel(logging.DEBUG)
|
||||
if not _perf_log.handlers:
|
||||
_h = logging.StreamHandler()
|
||||
_h.setFormatter(logging.Formatter("%(asctime)s [PERF] %(message)s"))
|
||||
_perf_log.addHandler(_h)
|
||||
_perf_log.propagate = False
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
|
@ -1016,7 +1014,6 @@ async def stream_new_chat(
|
|||
user_query: str,
|
||||
search_space_id: int,
|
||||
chat_id: int,
|
||||
session: AsyncSession,
|
||||
user_id: str | None = None,
|
||||
llm_config_id: int = -1,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
|
|
@ -1032,11 +1029,13 @@ async def stream_new_chat(
|
|||
This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
|
||||
The chat_id is used as LangGraph's thread_id for memory/checkpointing.
|
||||
|
||||
The function creates and manages its own database session to guarantee proper
|
||||
cleanup even when Starlette's middleware cancels the task on client disconnect.
|
||||
|
||||
Args:
|
||||
user_query: The user's query
|
||||
search_space_id: The search space ID
|
||||
chat_id: The chat ID (used as LangGraph thread_id for memory)
|
||||
session: The database session
|
||||
user_id: The current user's UUID string (for memory tools and session state)
|
||||
llm_config_id: The LLM configuration ID (default: -1 for first global config)
|
||||
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
|
||||
|
|
@ -1050,7 +1049,9 @@ async def stream_new_chat(
|
|||
streaming_service = VercelStreamingService()
|
||||
stream_result = StreamResult()
|
||||
_t_total = time.perf_counter()
|
||||
log_system_snapshot("stream_new_chat_START")
|
||||
|
||||
session = async_session_maker()
|
||||
try:
|
||||
# Mark AI as responding to this user for live collaboration
|
||||
if user_id:
|
||||
|
|
@ -1286,6 +1287,12 @@ async def stream_new_chat(
|
|||
# short-lived transactions (or use isolated sessions).
|
||||
await session.commit()
|
||||
|
||||
# Detach heavy ORM objects (documents with chunks, reports, etc.)
|
||||
# from the session identity map now that we've extracted the data
|
||||
# we need. This prevents them from accumulating in memory for the
|
||||
# entire duration of LLM streaming (which can be several minutes).
|
||||
session.expunge_all()
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_total,
|
||||
|
|
@ -1353,6 +1360,12 @@ async def stream_new_chat(
|
|||
items=initial_items,
|
||||
)
|
||||
|
||||
# These ORM objects (with eagerly-loaded chunks) can be very large.
|
||||
# They're only needed to build context strings already copied into
|
||||
# final_query / langchain_messages — release them before streaming.
|
||||
del mentioned_documents, mentioned_surfsense_docs, recent_reports
|
||||
del langchain_messages, final_query
|
||||
|
||||
_t_stream_start = time.perf_counter()
|
||||
_first_event_logged = False
|
||||
async for sse in _stream_agent_events(
|
||||
|
|
@ -1382,6 +1395,7 @@ async def stream_new_chat(
|
|||
time.perf_counter() - _t_stream_start,
|
||||
chat_id,
|
||||
)
|
||||
log_system_snapshot("stream_new_chat_END")
|
||||
|
||||
if stream_result.is_interrupted:
|
||||
yield streaming_service.format_finish_step()
|
||||
|
|
@ -1461,30 +1475,57 @@ async def stream_new_chat(
|
|||
yield streaming_service.format_done()
|
||||
|
||||
finally:
|
||||
# Clear AI responding state for live collaboration.
|
||||
# The original session may be broken (client disconnect / CancelledError
|
||||
# can corrupt the underlying DB connection), so we try a rollback first
|
||||
# and fall back to a fresh session if the original is unusable.
|
||||
try:
|
||||
await session.rollback()
|
||||
await clear_ai_responding(session, chat_id)
|
||||
except Exception:
|
||||
# Shield the ENTIRE async cleanup from anyio cancel-scope
|
||||
# cancellation. Starlette's BaseHTTPMiddleware uses anyio task
|
||||
# groups; on client disconnect, it cancels the scope with
|
||||
# level-triggered cancellation — every unshielded `await` inside
|
||||
# the cancelled scope raises CancelledError immediately. Without
|
||||
# this shield the very first `await` (session.rollback) would
|
||||
# raise CancelledError, `except Exception` wouldn't catch it
|
||||
# (CancelledError is a BaseException), and the rest of the
|
||||
# finally block — including session.close() — would never run.
|
||||
with anyio.CancelScope(shield=True):
|
||||
try:
|
||||
async with async_session_maker() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
await session.rollback()
|
||||
await clear_ai_responding(session, chat_id)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
try:
|
||||
async with shielded_async_session() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
|
||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
session.expunge_all()
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
await session.close()
|
||||
|
||||
# Break circular refs held by the agent graph, tools, and LLM
|
||||
# wrappers so the GC can reclaim them in a single pass.
|
||||
agent = llm = connector_service = sandbox_backend = None
|
||||
input_state = stream_result = None
|
||||
session = None
|
||||
|
||||
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
|
||||
if collected:
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] gc.collect() reclaimed %d objects (chat_id=%s)",
|
||||
collected,
|
||||
chat_id,
|
||||
)
|
||||
trim_native_heap()
|
||||
log_system_snapshot("stream_new_chat_END")
|
||||
|
||||
|
||||
async def stream_resume_chat(
|
||||
chat_id: int,
|
||||
search_space_id: int,
|
||||
decisions: list[dict],
|
||||
session: AsyncSession,
|
||||
user_id: str | None = None,
|
||||
llm_config_id: int = -1,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
|
|
@ -1493,6 +1534,7 @@ async def stream_resume_chat(
|
|||
stream_result = StreamResult()
|
||||
_t_total = time.perf_counter()
|
||||
|
||||
session = async_session_maker()
|
||||
try:
|
||||
if user_id:
|
||||
await set_ai_responding(session, chat_id, UUID(user_id))
|
||||
|
|
@ -1589,6 +1631,7 @@ async def stream_resume_chat(
|
|||
|
||||
# Release the transaction before streaming (same rationale as stream_new_chat).
|
||||
await session.commit()
|
||||
session.expunge_all()
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
|
||||
|
|
@ -1652,16 +1695,37 @@ async def stream_resume_chat(
|
|||
yield streaming_service.format_done()
|
||||
|
||||
finally:
|
||||
try:
|
||||
await session.rollback()
|
||||
await clear_ai_responding(session, chat_id)
|
||||
except Exception:
|
||||
with anyio.CancelScope(shield=True):
|
||||
try:
|
||||
async with async_session_maker() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
await session.rollback()
|
||||
await clear_ai_responding(session, chat_id)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
try:
|
||||
async with shielded_async_session() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
|
||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
session.expunge_all()
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
await session.close()
|
||||
|
||||
agent = llm = connector_service = sandbox_backend = None
|
||||
stream_result = None
|
||||
session = None
|
||||
|
||||
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
|
||||
if collected:
|
||||
_perf_log.info(
|
||||
"[stream_resume] gc.collect() reclaimed %d objects (chat_id=%s)",
|
||||
collected,
|
||||
chat_id,
|
||||
)
|
||||
trim_native_heap()
|
||||
log_system_snapshot("stream_resume_chat_END")
|
||||
|
|
|
|||
174
surfsense_backend/app/utils/perf.py
Normal file
174
surfsense_backend/app/utils/perf.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
"""
|
||||
Centralized performance monitoring for SurfSense backend.
|
||||
|
||||
Provides:
|
||||
- A shared [PERF] logger used across all modules
|
||||
- perf_timer context manager for timing code blocks
|
||||
- perf_async_timer for async code blocks
|
||||
- system_snapshot() for CPU/memory profiling
|
||||
- RequestPerfMiddleware for per-request timing
|
||||
"""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from typing import Any
|
||||
|
||||
_perf_log: logging.Logger | None = None
|
||||
_last_rss_mb: float = 0.0
|
||||
|
||||
|
||||
def get_perf_logger() -> logging.Logger:
|
||||
"""Return the singleton [PERF] logger, creating it once on first call."""
|
||||
global _perf_log
|
||||
if _perf_log is None:
|
||||
_perf_log = logging.getLogger("surfsense.perf")
|
||||
_perf_log.setLevel(logging.DEBUG)
|
||||
if not _perf_log.handlers:
|
||||
h = logging.StreamHandler()
|
||||
h.setFormatter(logging.Formatter("%(asctime)s [PERF] %(message)s"))
|
||||
_perf_log.addHandler(h)
|
||||
_perf_log.propagate = False
|
||||
return _perf_log
|
||||
|
||||
|
||||
@contextmanager
|
||||
def perf_timer(label: str, *, extra: dict[str, Any] | None = None):
|
||||
"""Synchronous context manager that logs elapsed wall-clock time.
|
||||
|
||||
Usage:
|
||||
with perf_timer("[my_func] heavy computation"):
|
||||
...
|
||||
"""
|
||||
log = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
yield
|
||||
elapsed = time.perf_counter() - t0
|
||||
suffix = ""
|
||||
if extra:
|
||||
suffix = " " + " ".join(f"{k}={v}" for k, v in extra.items())
|
||||
log.info("%s in %.3fs%s", label, elapsed, suffix)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def perf_async_timer(label: str, *, extra: dict[str, Any] | None = None):
|
||||
"""Async context manager that logs elapsed wall-clock time.
|
||||
|
||||
Usage:
|
||||
async with perf_async_timer("[search] vector search"):
|
||||
...
|
||||
"""
|
||||
log = get_perf_logger()
|
||||
t0 = time.perf_counter()
|
||||
yield
|
||||
elapsed = time.perf_counter() - t0
|
||||
suffix = ""
|
||||
if extra:
|
||||
suffix = " " + " ".join(f"{k}={v}" for k, v in extra.items())
|
||||
log.info("%s in %.3fs%s", label, elapsed, suffix)
|
||||
|
||||
|
||||
def system_snapshot() -> dict[str, Any]:
|
||||
"""Capture a lightweight CPU + memory snapshot of the current process.
|
||||
|
||||
Returns a dict with:
|
||||
- rss_mb: Resident Set Size in MB
|
||||
- rss_delta_mb: Change in RSS since the last snapshot
|
||||
- cpu_percent: CPU usage % since last call (per-process)
|
||||
- threads: number of active threads
|
||||
- open_fds: number of open file descriptors (Linux only)
|
||||
- asyncio_tasks: number of asyncio tasks currently alive
|
||||
- gc_counts: tuple of object counts per gc generation
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
global _last_rss_mb
|
||||
|
||||
snapshot: dict[str, Any] = {}
|
||||
try:
|
||||
import psutil
|
||||
|
||||
proc = psutil.Process(os.getpid())
|
||||
mem = proc.memory_info()
|
||||
rss_mb = round(mem.rss / 1024 / 1024, 1)
|
||||
snapshot["rss_mb"] = rss_mb
|
||||
snapshot["rss_delta_mb"] = (
|
||||
round(rss_mb - _last_rss_mb, 1) if _last_rss_mb else 0.0
|
||||
)
|
||||
_last_rss_mb = rss_mb
|
||||
snapshot["cpu_percent"] = proc.cpu_percent(interval=None)
|
||||
snapshot["threads"] = proc.num_threads()
|
||||
try:
|
||||
snapshot["open_fds"] = proc.num_fds()
|
||||
except AttributeError:
|
||||
snapshot["open_fds"] = -1
|
||||
except ImportError:
|
||||
snapshot["rss_mb"] = -1
|
||||
snapshot["rss_delta_mb"] = 0.0
|
||||
snapshot["cpu_percent"] = -1
|
||||
snapshot["threads"] = -1
|
||||
snapshot["open_fds"] = -1
|
||||
|
||||
try:
|
||||
all_tasks = asyncio.all_tasks()
|
||||
snapshot["asyncio_tasks"] = len(all_tasks)
|
||||
except RuntimeError:
|
||||
snapshot["asyncio_tasks"] = -1
|
||||
|
||||
snapshot["gc_counts"] = gc.get_count()
|
||||
|
||||
return snapshot
|
||||
|
||||
|
||||
def log_system_snapshot(label: str = "system_snapshot") -> None:
|
||||
"""Capture and log a system snapshot with memory delta tracking."""
|
||||
snap = system_snapshot()
|
||||
delta_str = ""
|
||||
if snap["rss_delta_mb"]:
|
||||
sign = "+" if snap["rss_delta_mb"] > 0 else ""
|
||||
delta_str = f" delta={sign}{snap['rss_delta_mb']}MB"
|
||||
get_perf_logger().info(
|
||||
"[%s] rss=%.1fMB%s cpu=%.1f%% threads=%d fds=%d asyncio_tasks=%d gc=%s",
|
||||
label,
|
||||
snap["rss_mb"],
|
||||
delta_str,
|
||||
snap["cpu_percent"],
|
||||
snap["threads"],
|
||||
snap["open_fds"],
|
||||
snap["asyncio_tasks"],
|
||||
snap["gc_counts"],
|
||||
)
|
||||
|
||||
if snap["rss_mb"] > 0 and snap["rss_delta_mb"] > 500:
|
||||
get_perf_logger().warning(
|
||||
"[MEMORY_SPIKE] %s: RSS jumped by %.1fMB (now %.1fMB). "
|
||||
"Possible leak — check recent operations.",
|
||||
label,
|
||||
snap["rss_delta_mb"],
|
||||
snap["rss_mb"],
|
||||
)
|
||||
|
||||
|
||||
def trim_native_heap() -> bool:
|
||||
"""Ask glibc to return free heap pages to the OS via ``malloc_trim(0)``.
|
||||
|
||||
On Linux (glibc), ``free()`` does not release memory back to the OS if
|
||||
it sits below the brk watermark. ``malloc_trim`` forces the allocator
|
||||
to give back as many freed pages as possible.
|
||||
|
||||
Returns True if trimming was performed, False otherwise (non-Linux or
|
||||
libc unavailable).
|
||||
"""
|
||||
import ctypes
|
||||
import sys
|
||||
|
||||
if sys.platform != "linux":
|
||||
return False
|
||||
try:
|
||||
libc = ctypes.CDLL("libc.so.6")
|
||||
libc.malloc_trim(0)
|
||||
return True
|
||||
except (OSError, AttributeError):
|
||||
return False
|
||||
|
|
@ -181,7 +181,7 @@ export default function MorePagesPage() {
|
|||
</DialogHeader>
|
||||
<div className="flex flex-col items-center gap-4 py-4">
|
||||
<Link
|
||||
href="https://calendly.com/eric-surfsense/surfsense-meeting"
|
||||
href="https://cal.com/mod-rohan"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="flex w-full items-center justify-center gap-2 rounded-lg bg-primary px-4 py-2.5 text-sm font-medium text-primary-foreground transition hover:bg-primary/90"
|
||||
|
|
@ -195,11 +195,11 @@ export default function MorePagesPage() {
|
|||
<span className="h-px w-8 bg-border" />
|
||||
</div>
|
||||
<Link
|
||||
href="mailto:eric@surfsense.com"
|
||||
href="mailto:rohan@surfsense.com"
|
||||
className="flex items-center gap-2 text-sm text-muted-foreground transition hover:text-foreground"
|
||||
>
|
||||
<IconMailFilled className="h-4 w-4" />
|
||||
eric@surfsense.com
|
||||
rohan@surfsense.com
|
||||
</Link>
|
||||
</div>
|
||||
</DialogContent>
|
||||
|
|
|
|||
|
|
@ -235,4 +235,3 @@ button {
|
|||
@source '../node_modules/streamdown/dist/*.js';
|
||||
@source '../node_modules/@streamdown/code/dist/*.js';
|
||||
@source '../node_modules/@streamdown/math/dist/*.js';
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import { AlertTriangle, Cable, Settings } from "lucide-react";
|
|||
import Link from "next/link";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import type { FC } from "react";
|
||||
import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms";
|
||||
import {
|
||||
globalNewLLMConfigsAtom,
|
||||
llmPreferencesAtom,
|
||||
|
|
@ -19,7 +20,6 @@ import { Spinner } from "@/components/ui/spinner";
|
|||
import { Tabs, TabsContent } from "@/components/ui/tabs";
|
||||
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||
import { useConnectorsElectric } from "@/hooks/use-connectors-electric";
|
||||
import { useDocuments } from "@/hooks/use-documents";
|
||||
import { useInbox } from "@/hooks/use-inbox";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ConnectorDialogHeader } from "./connector-popup/components/connector-dialog-header";
|
||||
|
|
@ -62,10 +62,9 @@ export const ConnectorIndicator: FC<{ hideTrigger?: boolean }> = ({ hideTrigger
|
|||
|
||||
const llmConfigLoading = preferencesLoading || globalConfigsLoading;
|
||||
|
||||
// Fetch document type counts using Electric SQL + PGlite for real-time updates
|
||||
const { typeCounts: documentTypeCounts, loading: documentTypesLoading } = useDocuments(
|
||||
searchSpaceId ? Number(searchSpaceId) : null
|
||||
);
|
||||
// Fetch document type counts via the lightweight /type-counts endpoint (cached 10 min)
|
||||
const { data: documentTypeCounts, isFetching: documentTypesLoading } =
|
||||
useAtomValue(documentTypeCountsAtom);
|
||||
|
||||
// Fetch notifications to detect indexing failures
|
||||
const { inboxItems = [] } = useInbox(
|
||||
|
|
|
|||
|
|
@ -2,27 +2,28 @@ import { EnumConnectorName } from "@/contracts/enums/connector";
|
|||
|
||||
// OAuth Connectors (Quick Connect)
|
||||
export const OAUTH_CONNECTORS = [
|
||||
{
|
||||
id: "google-drive-connector",
|
||||
title: "Google Drive",
|
||||
description: "Search your Drive files",
|
||||
connectorType: EnumConnectorName.GOOGLE_DRIVE_CONNECTOR,
|
||||
authEndpoint: "/api/v1/auth/google/drive/connector/add/",
|
||||
},
|
||||
{
|
||||
id: "google-gmail-connector",
|
||||
title: "Gmail",
|
||||
description: "Search through your emails",
|
||||
connectorType: EnumConnectorName.GOOGLE_GMAIL_CONNECTOR,
|
||||
authEndpoint: "/api/v1/auth/google/gmail/connector/add/",
|
||||
},
|
||||
{
|
||||
id: "google-calendar-connector",
|
||||
title: "Google Calendar",
|
||||
description: "Search through your events",
|
||||
connectorType: EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR,
|
||||
authEndpoint: "/api/v1/auth/google/calendar/connector/add/",
|
||||
},
|
||||
// // Uncomment for managed Google Connections
|
||||
// {
|
||||
// id: "google-drive-connector",
|
||||
// title: "Google Drive",
|
||||
// description: "Search your Drive files",
|
||||
// connectorType: EnumConnectorName.GOOGLE_DRIVE_CONNECTOR,
|
||||
// authEndpoint: "/api/v1/auth/google/drive/connector/add/",
|
||||
// },
|
||||
// {
|
||||
// id: "google-gmail-connector",
|
||||
// title: "Gmail",
|
||||
// description: "Search through your emails",
|
||||
// connectorType: EnumConnectorName.GOOGLE_GMAIL_CONNECTOR,
|
||||
// authEndpoint: "/api/v1/auth/google/gmail/connector/add/",
|
||||
// },
|
||||
// {
|
||||
// id: "google-calendar-connector",
|
||||
// title: "Google Calendar",
|
||||
// description: "Search through your events",
|
||||
// connectorType: EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR,
|
||||
// authEndpoint: "/api/v1/auth/google/calendar/connector/add/",
|
||||
// },
|
||||
{
|
||||
id: "airtable-connector",
|
||||
title: "Airtable",
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking";
|
|||
import { Button } from "@/components/ui/button";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import type { Document } from "@/contracts/types/document.types";
|
||||
import { useBatchCommentsPreload } from "@/hooks/use-comments";
|
||||
import { useCommentsElectric } from "@/hooks/use-comments-electric";
|
||||
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
|
@ -309,6 +310,22 @@ const Composer: FC = () => {
|
|||
// Sync comments for the entire thread via Electric SQL (one subscription per thread)
|
||||
useCommentsElectric(threadId);
|
||||
|
||||
// Batch-prefetch comments for all assistant messages so individual useComments
|
||||
// hooks never fire their own network requests (eliminates N+1 API calls).
|
||||
// Return a primitive string from the selector so useSyncExternalStore can
|
||||
// compare snapshots by value and avoid infinite re-render loops.
|
||||
const assistantIdsKey = useAssistantState(({ thread }) =>
|
||||
thread.messages
|
||||
.filter((m) => m.role === "assistant" && m.id?.startsWith("msg-"))
|
||||
.map((m) => m.id!.replace("msg-", ""))
|
||||
.join(",")
|
||||
);
|
||||
const assistantDbMessageIds = useMemo(
|
||||
() => (assistantIdsKey ? assistantIdsKey.split(",").map(Number) : []),
|
||||
[assistantIdsKey]
|
||||
);
|
||||
useBatchCommentsPreload(assistantDbMessageIds);
|
||||
|
||||
// Auto-focus editor on new chat page after mount
|
||||
useEffect(() => {
|
||||
if (isThreadEmpty && !hasAutoFocusedRef.current && editorRef.current) {
|
||||
|
|
|
|||
|
|
@ -23,12 +23,12 @@ export function ContactFormGridWithDetails() {
|
|||
We'd love to hear from you!
|
||||
</p>
|
||||
<p className="mt-4 max-w-lg text-center text-base text-neutral-600 dark:text-neutral-400">
|
||||
Schedule a meeting with our Head of Product, Eric Lammertsma, or send us an email.
|
||||
Schedule a meeting with us, or send us an email.
|
||||
</p>
|
||||
|
||||
<div className="mt-10 flex flex-col items-center gap-6">
|
||||
<Link
|
||||
href="https://calendly.com/eric-surfsense/surfsense-meeting"
|
||||
href="https://cal.com/mod-rohan"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="flex items-center gap-3 rounded-xl bg-gradient-to-b from-blue-500 to-blue-600 px-6 py-3 text-base font-medium text-white shadow-lg transition duration-200 hover:from-blue-600 hover:to-blue-700"
|
||||
|
|
@ -44,11 +44,11 @@ export function ContactFormGridWithDetails() {
|
|||
</div>
|
||||
|
||||
<Link
|
||||
href="mailto:eric@surfsense.com"
|
||||
href="mailto:rohan@surfsense.com"
|
||||
className="flex items-center gap-2 text-base text-neutral-600 transition duration-200 hover:text-neutral-900 dark:text-neutral-400 dark:hover:text-neutral-200"
|
||||
>
|
||||
<IconMailFilled className="h-5 w-5" />
|
||||
eric@surfsense.com
|
||||
rohan@surfsense.com
|
||||
</Link>
|
||||
</div>
|
||||
|
||||
|
|
|
|||
|
|
@ -96,12 +96,9 @@ export function HeroSection() {
|
|||
</div>
|
||||
)}
|
||||
</h2>
|
||||
{/* // TODO:aCTUAL DESCRITION */}
|
||||
<p className="relative z-50 mx-auto mt-4 max-w-xl px-4 text-center text-base/6 text-gray-600 dark:text-gray-200">
|
||||
Connect any AI to your documents, Drive, Notion and more,
|
||||
</p>
|
||||
<p className="relative z-50 mx-auto mt-0 max-w-xl px-4 text-center text-base/6 text-gray-600 dark:text-gray-200">
|
||||
then chat with it, generate podcasts and reports, or even invite your team.
|
||||
<p className="relative z-50 mx-auto mt-4 max-w-lg px-6 text-center text-sm leading-relaxed text-gray-600 sm:text-base sm:leading-relaxed md:max-w-xl md:text-lg md:leading-relaxed dark:text-gray-200">
|
||||
Connect any LLM to your internal knowledge sources and chat with it in real time alongside
|
||||
your team.
|
||||
</p>
|
||||
<div className="mb-6 mt-6 flex w-full flex-col items-center justify-center gap-4 px-8 sm:flex-row md:mb-10">
|
||||
<GetStartedButton />
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import {
|
|||
IconBrandGithub,
|
||||
IconBrandReddit,
|
||||
IconMenu2,
|
||||
IconSpeakerphone,
|
||||
IconX,
|
||||
} from "@tabler/icons-react";
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
|
|
@ -13,7 +12,6 @@ import { useEffect, useState } from "react";
|
|||
import { SignInButton } from "@/components/auth/sign-in-button";
|
||||
import { Logo } from "@/components/Logo";
|
||||
import { ThemeTogglerComponent } from "@/components/theme/theme-toggle";
|
||||
import { useAnnouncements } from "@/hooks/use-announcements";
|
||||
import { useGithubStars } from "@/hooks/use-github-stars";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
|
|
@ -49,11 +47,7 @@ export const Navbar = () => {
|
|||
|
||||
const DesktopNav = ({ navItems, isScrolled }: any) => {
|
||||
const [hovered, setHovered] = useState<number | null>(null);
|
||||
const [mounted, setMounted] = useState(false);
|
||||
const { compactFormat: githubStars, loading: loadingGithubStars } = useGithubStars();
|
||||
const { unreadCount } = useAnnouncements();
|
||||
|
||||
useEffect(() => setMounted(true), []);
|
||||
return (
|
||||
<motion.div
|
||||
onMouseLeave={() => {
|
||||
|
|
@ -124,17 +118,6 @@ const DesktopNav = ({ navItems, isScrolled }: any) => {
|
|||
</span>
|
||||
)}
|
||||
</Link>
|
||||
<Link
|
||||
href="/announcements"
|
||||
className="relative hidden rounded-full p-2 hover:bg-gray-100 dark:hover:bg-neutral-800 transition-colors md:flex items-center justify-center"
|
||||
>
|
||||
<IconSpeakerphone className="h-5 w-5 text-neutral-600 dark:text-neutral-300" />
|
||||
{mounted && unreadCount > 0 && (
|
||||
<span className="absolute -top-0.5 -right-0.5 flex h-4 min-w-4 items-center justify-center rounded-full bg-red-500 px-1 text-[10px] font-bold text-white">
|
||||
{unreadCount > 99 ? "99+" : unreadCount}
|
||||
</span>
|
||||
)}
|
||||
</Link>
|
||||
<ThemeTogglerComponent />
|
||||
<SignInButton variant="desktop" />
|
||||
</div>
|
||||
|
|
@ -144,11 +127,7 @@ const DesktopNav = ({ navItems, isScrolled }: any) => {
|
|||
|
||||
const MobileNav = ({ navItems, isScrolled }: any) => {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [mounted, setMounted] = useState(false);
|
||||
const { compactFormat: githubStars, loading: loadingGithubStars } = useGithubStars();
|
||||
const { unreadCount } = useAnnouncements();
|
||||
|
||||
useEffect(() => setMounted(true), []);
|
||||
|
||||
return (
|
||||
<motion.div
|
||||
|
|
@ -233,17 +212,6 @@ const MobileNav = ({ navItems, isScrolled }: any) => {
|
|||
</span>
|
||||
)}
|
||||
</Link>
|
||||
<Link
|
||||
href="/announcements"
|
||||
className="relative flex items-center justify-center rounded-lg p-2 hover:bg-gray-100 dark:hover:bg-neutral-800 transition-colors touch-manipulation"
|
||||
>
|
||||
<IconSpeakerphone className="h-5 w-5 text-neutral-600 dark:text-neutral-300" />
|
||||
{mounted && unreadCount > 0 && (
|
||||
<span className="absolute -top-0.5 -right-0.5 flex h-4 min-w-4 items-center justify-center rounded-full bg-red-500 px-1 text-[10px] font-bold text-white">
|
||||
{unreadCount > 99 ? "99+" : unreadCount}
|
||||
</span>
|
||||
)}
|
||||
</Link>
|
||||
<ThemeTogglerComponent />
|
||||
</div>
|
||||
<SignInButton variant="mobile" />
|
||||
|
|
|
|||
|
|
@ -8,43 +8,34 @@ const demoPlans = [
|
|||
price: "0",
|
||||
yearlyPrice: "0",
|
||||
period: "",
|
||||
billingText: "Includes 30 day PRO trial",
|
||||
billingText: "",
|
||||
features: [
|
||||
"Open source on GitHub",
|
||||
"Self Hostable",
|
||||
"Upload and chat with 300+ pages of content",
|
||||
"Connects with 8 popular sources, like Drive and Notion",
|
||||
"Includes limited access to ChatGPT, Claude, and DeepSeek models",
|
||||
"Supports 100+ more LLMs, including Gemini, Llama and many more",
|
||||
"50+ File extensions supported",
|
||||
"Generate podcasts in seconds",
|
||||
"Cross-Browser Extension for dynamic webpages including authenticated content",
|
||||
"Includes access to ChatGPT text and audio models",
|
||||
"Realtime Collaborative Group Chats with teammates",
|
||||
"Community support on Discord",
|
||||
],
|
||||
description: "Powerful features with some limitations",
|
||||
description: "",
|
||||
buttonText: "Get Started",
|
||||
href: "/",
|
||||
href: "/login",
|
||||
isPopular: false,
|
||||
},
|
||||
{
|
||||
name: "PRO",
|
||||
price: "10",
|
||||
yearlyPrice: "10",
|
||||
period: "user / month",
|
||||
billingText: "billed annually",
|
||||
price: "0",
|
||||
yearlyPrice: "0",
|
||||
period: "",
|
||||
billingText: "Free during beta",
|
||||
features: [
|
||||
"Everything in Free",
|
||||
"Upload and chat with 5,000+ pages of content per user",
|
||||
"Connects with 15+ external sources, like Slack and Airtable",
|
||||
"Includes extended access to ChatGPT, Claude, and DeepSeek models",
|
||||
"Collaboration and commenting features",
|
||||
"Shared BYOK (Bring Your Own Key)",
|
||||
"Team and role management",
|
||||
"Planned: Centralized billing",
|
||||
"Priority support",
|
||||
"Includes 6000+ pages of content",
|
||||
"Access to more models and providers",
|
||||
"Priority support on Discord",
|
||||
],
|
||||
description: "The AI knowledge base for individuals and teams",
|
||||
buttonText: "Upgrade",
|
||||
href: "/contact",
|
||||
description: "",
|
||||
buttonText: "Get Started",
|
||||
href: "/login",
|
||||
isPopular: true,
|
||||
},
|
||||
{
|
||||
|
|
@ -55,12 +46,9 @@ const demoPlans = [
|
|||
billingText: "",
|
||||
features: [
|
||||
"Everything in Pro",
|
||||
"Connect and chat with virtually unlimited pages of content",
|
||||
"Limit models and/or providers",
|
||||
"On-prem or VPC deployment",
|
||||
"Planned: Audit logs and compliance",
|
||||
"Planned: SSO, OIDC & SAML",
|
||||
"Planned: Role-based access control (RBAC)",
|
||||
"Audit logs and compliance",
|
||||
"SSO, OIDC & SAML",
|
||||
"White-glove setup and deployment",
|
||||
"Monthly managed updates and maintenance",
|
||||
"SLA commitments",
|
||||
|
|
|
|||
|
|
@ -111,8 +111,8 @@ const FILE_TYPE_CONFIG: Record<string, Record<string, string[]>> = {
|
|||
|
||||
const cardClass = "border border-border bg-slate-400/5 dark:bg-white/5";
|
||||
|
||||
// Upload limits
|
||||
const MAX_FILES = 10;
|
||||
// Upload limits — files are sent in batches of 5 to avoid proxy timeouts
|
||||
const MAX_FILES = 50;
|
||||
const MAX_TOTAL_SIZE_MB = 200;
|
||||
const MAX_TOTAL_SIZE_BYTES = MAX_TOTAL_SIZE_MB * 1024 * 1024;
|
||||
|
||||
|
|
|
|||
|
|
@ -82,6 +82,22 @@ export const getCommentsResponse = z.object({
|
|||
total_count: z.number(),
|
||||
});
|
||||
|
||||
/**
|
||||
* Batch-fetch comments for multiple messages
|
||||
*/
|
||||
export const getBatchCommentsRequest = z.object({
|
||||
message_ids: z.array(z.number()).min(1).max(200),
|
||||
});
|
||||
|
||||
export const commentListResponse = z.object({
|
||||
comments: z.array(comment),
|
||||
total_count: z.number(),
|
||||
});
|
||||
|
||||
export const getBatchCommentsResponse = z.object({
|
||||
comments_by_message: z.record(z.string(), commentListResponse),
|
||||
});
|
||||
|
||||
/**
|
||||
* Create comment
|
||||
*/
|
||||
|
|
@ -145,6 +161,8 @@ export type MentionComment = z.infer<typeof mentionComment>;
|
|||
export type Mention = z.infer<typeof mention>;
|
||||
export type GetCommentsRequest = z.infer<typeof getCommentsRequest>;
|
||||
export type GetCommentsResponse = z.infer<typeof getCommentsResponse>;
|
||||
export type GetBatchCommentsRequest = z.infer<typeof getBatchCommentsRequest>;
|
||||
export type GetBatchCommentsResponse = z.infer<typeof getBatchCommentsResponse>;
|
||||
export type CreateCommentRequest = z.infer<typeof createCommentRequest>;
|
||||
export type CreateCommentResponse = z.infer<typeof createCommentResponse>;
|
||||
export type CreateReplyRequest = z.infer<typeof createReplyRequest>;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import { useQuery } from "@tanstack/react-query";
|
||||
import { useQuery, useQueryClient } from "@tanstack/react-query";
|
||||
import { useEffect, useRef } from "react";
|
||||
import { chatCommentsApiService } from "@/lib/apis/chat-comments-api.service";
|
||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||
|
||||
|
|
@ -7,12 +8,109 @@ interface UseCommentsOptions {
|
|||
enabled?: boolean;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Module-level coordination: when a batch request is in-flight, individual
|
||||
// useComments queryFns piggy-back on it instead of making their own requests.
|
||||
//
|
||||
// _batchReady is a promise that resolves once the batch useEffect has had a
|
||||
// chance to set _batchInflight. Individual queryFns await this gate before
|
||||
// deciding whether to piggy-back or fetch on their own, eliminating the
|
||||
// previous race where setTimeout(0) was not enough.
|
||||
// ---------------------------------------------------------------------------
|
||||
let _batchInflight: Promise<void> | null = null;
|
||||
let _batchTargetIds = new Set<number>();
|
||||
let _batchReady: Promise<void> | null = null;
|
||||
let _resolveBatchReady: (() => void) | null = null;
|
||||
|
||||
function resetBatchGate() {
|
||||
_batchReady = new Promise<void>((r) => {
|
||||
_resolveBatchReady = r;
|
||||
});
|
||||
}
|
||||
|
||||
// Open the initial gate immediately (no batch pending yet)
|
||||
resetBatchGate();
|
||||
_resolveBatchReady?.();
|
||||
|
||||
export function useComments({ messageId, enabled = true }: UseCommentsOptions) {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
return useQuery({
|
||||
queryKey: cacheKeys.comments.byMessage(messageId),
|
||||
queryFn: async () => {
|
||||
// Wait for the batch gate so the useEffect in useBatchCommentsPreload
|
||||
// has a chance to set _batchInflight before we decide.
|
||||
if (_batchReady) {
|
||||
await _batchReady;
|
||||
}
|
||||
|
||||
if (_batchInflight && _batchTargetIds.has(messageId)) {
|
||||
await _batchInflight;
|
||||
const cached = queryClient.getQueryData(cacheKeys.comments.byMessage(messageId));
|
||||
if (cached) return cached;
|
||||
}
|
||||
|
||||
return chatCommentsApiService.getComments({ message_id: messageId });
|
||||
},
|
||||
enabled: enabled && !!messageId,
|
||||
staleTime: 30_000,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Batch-fetch comments for all given message IDs in a single request, then
|
||||
* seed the per-message React Query cache so individual useComments hooks
|
||||
* resolve from cache instead of firing their own requests.
|
||||
*/
|
||||
export function useBatchCommentsPreload(messageIds: number[]) {
|
||||
const queryClient = useQueryClient();
|
||||
const prevKeyRef = useRef<string>("");
|
||||
|
||||
useEffect(() => {
|
||||
if (!messageIds.length) return;
|
||||
|
||||
const key = messageIds
|
||||
.slice()
|
||||
.sort((a, b) => a - b)
|
||||
.join(",");
|
||||
if (key === prevKeyRef.current) return;
|
||||
prevKeyRef.current = key;
|
||||
|
||||
// Open a new gate so individual queryFns wait for us
|
||||
resetBatchGate();
|
||||
|
||||
_batchTargetIds = new Set(messageIds);
|
||||
let cancelled = false;
|
||||
|
||||
const promise = chatCommentsApiService
|
||||
.getBatchComments({ message_ids: messageIds })
|
||||
.then((data) => {
|
||||
if (cancelled) return;
|
||||
for (const [msgIdStr, commentList] of Object.entries(data.comments_by_message)) {
|
||||
queryClient.setQueryData(cacheKeys.comments.byMessage(Number(msgIdStr)), commentList);
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
// Batch failed; individual queryFns will fall through to their own fetch
|
||||
})
|
||||
.finally(() => {
|
||||
if (_batchInflight === promise) {
|
||||
_batchInflight = null;
|
||||
_batchTargetIds = new Set();
|
||||
}
|
||||
});
|
||||
|
||||
_batchInflight = promise;
|
||||
|
||||
// Release the gate — individual queryFns can now check _batchInflight
|
||||
_resolveBatchReady?.();
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
if (_batchInflight === promise) {
|
||||
_batchInflight = null;
|
||||
_batchTargetIds = new Set();
|
||||
}
|
||||
};
|
||||
}, [messageIds, queryClient]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import type { DocumentTypeEnum } from "@/contracts/types/document.types";
|
||||
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
||||
|
|
@ -183,56 +184,47 @@ export function useDocuments(
|
|||
[]
|
||||
);
|
||||
|
||||
// EFFECT 1: Load ALL documents from API (PRIMARY source of truth)
|
||||
// No type filter — always fetches everything so typeCounts stay complete
|
||||
// STEP 1: Load ALL documents from API (PRIMARY source of truth).
|
||||
// Uses React Query for automatic deduplication, caching, and staleTime so
|
||||
// multiple components mounting useDocuments(sameId) share a single request.
|
||||
const {
|
||||
data: apiResponse,
|
||||
isLoading: apiLoading,
|
||||
error: apiError,
|
||||
} = useQuery({
|
||||
queryKey: ["documents", "all", searchSpaceId],
|
||||
queryFn: () =>
|
||||
documentsApiService.getDocuments({
|
||||
queryParams: {
|
||||
search_space_id: searchSpaceId!,
|
||||
page: 0,
|
||||
page_size: -1,
|
||||
},
|
||||
}),
|
||||
enabled: !!searchSpaceId,
|
||||
staleTime: 30_000,
|
||||
});
|
||||
|
||||
// Seed local state from API response (runs once per fresh fetch)
|
||||
useEffect(() => {
|
||||
if (!searchSpaceId) {
|
||||
setLoading(false);
|
||||
return;
|
||||
if (!apiResponse) return;
|
||||
populateUserCache(apiResponse.items);
|
||||
const docs = apiResponse.items.map(apiToDisplayDoc);
|
||||
setAllDocuments(docs);
|
||||
apiLoadedRef.current = true;
|
||||
setError(null);
|
||||
}, [apiResponse, populateUserCache, apiToDisplayDoc]);
|
||||
|
||||
// Propagate loading / error from React Query
|
||||
useEffect(() => {
|
||||
setLoading(apiLoading);
|
||||
}, [apiLoading]);
|
||||
|
||||
useEffect(() => {
|
||||
if (apiError) {
|
||||
setError(apiError instanceof Error ? apiError : new Error("Failed to load documents"));
|
||||
}
|
||||
|
||||
// Capture validated value for async closure
|
||||
const spaceId = searchSpaceId;
|
||||
|
||||
let mounted = true;
|
||||
apiLoadedRef.current = false;
|
||||
|
||||
async function loadFromApi() {
|
||||
try {
|
||||
setLoading(true);
|
||||
console.log("[useDocuments] Loading from API (source of truth):", spaceId);
|
||||
|
||||
const response = await documentsApiService.getDocuments({
|
||||
queryParams: {
|
||||
search_space_id: spaceId,
|
||||
page: 0,
|
||||
page_size: -1, // Fetch all documents (unfiltered)
|
||||
},
|
||||
});
|
||||
|
||||
if (!mounted) return;
|
||||
|
||||
populateUserCache(response.items);
|
||||
const docs = response.items.map(apiToDisplayDoc);
|
||||
setAllDocuments(docs);
|
||||
apiLoadedRef.current = true;
|
||||
setError(null);
|
||||
console.log("[useDocuments] API loaded", docs.length, "documents");
|
||||
} catch (err) {
|
||||
if (!mounted) return;
|
||||
console.error("[useDocuments] API load failed:", err);
|
||||
setError(err instanceof Error ? err : new Error("Failed to load documents"));
|
||||
} finally {
|
||||
if (mounted) setLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
loadFromApi();
|
||||
|
||||
return () => {
|
||||
mounted = false;
|
||||
};
|
||||
}, [searchSpaceId, populateUserCache, apiToDisplayDoc]);
|
||||
}, [apiError]);
|
||||
|
||||
// EFFECT 2: Start Electric sync + live query for real-time updates
|
||||
// No type filter — syncs and queries ALL documents; filtering is client-side
|
||||
|
|
|
|||
|
|
@ -8,8 +8,11 @@ import {
|
|||
type DeleteCommentRequest,
|
||||
deleteCommentRequest,
|
||||
deleteCommentResponse,
|
||||
type GetBatchCommentsRequest,
|
||||
type GetCommentsRequest,
|
||||
type GetMentionsRequest,
|
||||
getBatchCommentsRequest,
|
||||
getBatchCommentsResponse,
|
||||
getCommentsRequest,
|
||||
getCommentsResponse,
|
||||
getMentionsRequest,
|
||||
|
|
@ -22,6 +25,22 @@ import { ValidationError } from "@/lib/error";
|
|||
import { baseApiService } from "./base-api.service";
|
||||
|
||||
class ChatCommentsApiService {
|
||||
/**
|
||||
* Batch-fetch comments for multiple messages in one request
|
||||
*/
|
||||
getBatchComments = async (request: GetBatchCommentsRequest) => {
|
||||
const parsed = getBatchCommentsRequest.safeParse(request);
|
||||
|
||||
if (!parsed.success) {
|
||||
const errorMessage = parsed.error.issues.map((issue) => issue.message).join(", ");
|
||||
throw new ValidationError(`Invalid request: ${errorMessage}`);
|
||||
}
|
||||
|
||||
return baseApiService.post("/api/v1/messages/comments/batch", getBatchCommentsResponse, {
|
||||
body: { message_ids: parsed.data.message_ids },
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Get comments for a message
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -109,7 +109,9 @@ class DocumentsApiService {
|
|||
};
|
||||
|
||||
/**
|
||||
* Upload document files
|
||||
* Upload document files in batches to avoid proxy/LB timeouts.
|
||||
* Files are split into chunks of UPLOAD_BATCH_SIZE and sent as separate
|
||||
* requests. Results are aggregated into a single response.
|
||||
*/
|
||||
uploadDocument = async (request: UploadDocumentRequest) => {
|
||||
const parsedRequest = uploadDocumentRequest.safeParse(request);
|
||||
|
|
@ -121,17 +123,54 @@ class DocumentsApiService {
|
|||
throw new ValidationError(`Invalid request: ${errorMessage}`);
|
||||
}
|
||||
|
||||
// Create FormData for file upload
|
||||
const formData = new FormData();
|
||||
parsedRequest.data.files.forEach((file) => {
|
||||
formData.append("files", file);
|
||||
});
|
||||
formData.append("search_space_id", String(parsedRequest.data.search_space_id));
|
||||
formData.append("should_summarize", String(parsedRequest.data.should_summarize));
|
||||
const { files, search_space_id, should_summarize } = parsedRequest.data;
|
||||
const UPLOAD_BATCH_SIZE = 5;
|
||||
|
||||
return baseApiService.postFormData(`/api/v1/documents/fileupload`, uploadDocumentResponse, {
|
||||
body: formData,
|
||||
});
|
||||
const batches: File[][] = [];
|
||||
for (let i = 0; i < files.length; i += UPLOAD_BATCH_SIZE) {
|
||||
batches.push(files.slice(i, i + UPLOAD_BATCH_SIZE));
|
||||
}
|
||||
|
||||
const allDocumentIds: number[] = [];
|
||||
const allDuplicateIds: number[] = [];
|
||||
let totalFiles = 0;
|
||||
let pendingFiles = 0;
|
||||
let skippedDuplicates = 0;
|
||||
|
||||
for (const batch of batches) {
|
||||
const formData = new FormData();
|
||||
batch.forEach((file) => formData.append("files", file));
|
||||
formData.append("search_space_id", String(search_space_id));
|
||||
formData.append("should_summarize", String(should_summarize));
|
||||
|
||||
const controller = new AbortController();
|
||||
const timeoutId = setTimeout(() => controller.abort(), 120_000);
|
||||
|
||||
try {
|
||||
const result = await baseApiService.postFormData(
|
||||
`/api/v1/documents/fileupload`,
|
||||
uploadDocumentResponse,
|
||||
{ body: formData, signal: controller.signal }
|
||||
);
|
||||
|
||||
allDocumentIds.push(...(result.document_ids ?? []));
|
||||
allDuplicateIds.push(...(result.duplicate_document_ids ?? []));
|
||||
totalFiles += result.total_files ?? batch.length;
|
||||
pendingFiles += result.pending_files ?? 0;
|
||||
skippedDuplicates += result.skipped_duplicates ?? 0;
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
message: "Files uploaded for processing" as const,
|
||||
document_ids: allDocumentIds,
|
||||
duplicate_document_ids: allDuplicateIds,
|
||||
total_files: totalFiles,
|
||||
pending_files: pendingFiles,
|
||||
skipped_duplicates: skippedDuplicates,
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -1,3 +1,10 @@
|
|||
import { QueryClient } from "@tanstack/react-query";
|
||||
|
||||
export const queryClient = new QueryClient();
|
||||
export const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
staleTime: 30_000,
|
||||
refetchOnWindowFocus: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,13 +1,12 @@
|
|||
import { loader } from "fumadocs-core/source";
|
||||
import { docs } from "@/.source/server";
|
||||
import { icons } from "lucide-react";
|
||||
import { createElement } from "react";
|
||||
import { docs } from "@/.source/server";
|
||||
|
||||
export const source = loader({
|
||||
baseUrl: "/docs",
|
||||
source: docs.toFumadocsSource(),
|
||||
icon(icon) {
|
||||
if (icon && icon in icons)
|
||||
return createElement(icons[icon as keyof typeof icons]);
|
||||
if (icon && icon in icons) return createElement(icons[icon as keyof typeof icons]);
|
||||
},
|
||||
});
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue