This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-03-27 17:08:08 -07:00
commit 947def5c4a
65 changed files with 10278 additions and 7590 deletions

View file

@ -14,6 +14,20 @@ from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _is_date_only(value: str) -> bool:
"""Return True when *value* looks like a bare date (YYYY-MM-DD) with no time component."""
return len(value) <= 10 and "T" not in value
def _build_time_body(value: str, context: dict[str, Any] | Any) -> dict[str, str]:
"""Build a Google Calendar start/end body using ``date`` for all-day
events and ``dateTime`` for timed events."""
if _is_date_only(value):
return {"date": value}
tz = context.get("timezone", "UTC") if isinstance(context, dict) else "UTC"
return {"dateTime": value, "timeZone": tz}
def create_update_calendar_event_tool( def create_update_calendar_event_tool(
db_session: AsyncSession | None = None, db_session: AsyncSession | None = None,
search_space_id: int | None = None, search_space_id: int | None = None,
@ -255,25 +269,13 @@ def create_update_calendar_event_tool(
if final_new_summary is not None: if final_new_summary is not None:
update_body["summary"] = final_new_summary update_body["summary"] = final_new_summary
if final_new_start_datetime is not None: if final_new_start_datetime is not None:
tz = ( update_body["start"] = _build_time_body(
context.get("timezone", "UTC") final_new_start_datetime, context
if isinstance(context, dict)
else "UTC"
) )
update_body["start"] = {
"dateTime": final_new_start_datetime,
"timeZone": tz,
}
if final_new_end_datetime is not None: if final_new_end_datetime is not None:
tz = ( update_body["end"] = _build_time_body(
context.get("timezone", "UTC") final_new_end_datetime, context
if isinstance(context, dict)
else "UTC"
) )
update_body["end"] = {
"dateTime": final_new_end_datetime,
"timeZone": tz,
}
if final_new_description is not None: if final_new_description is not None:
update_body["description"] = final_new_description update_body["description"] = final_new_description
if final_new_location is not None: if final_new_location is not None:

View file

@ -2,13 +2,14 @@
from .change_tracker import categorize_change, fetch_all_changes, get_start_page_token from .change_tracker import categorize_change, fetch_all_changes, get_start_page_token
from .client import GoogleDriveClient from .client import GoogleDriveClient
from .content_extractor import download_and_process_file from .content_extractor import download_and_extract_content, download_and_process_file
from .credentials import get_valid_credentials, validate_credentials from .credentials import get_valid_credentials, validate_credentials
from .folder_manager import get_file_by_id, get_files_in_folder, list_folder_contents from .folder_manager import get_file_by_id, get_files_in_folder, list_folder_contents
__all__ = [ __all__ = [
"GoogleDriveClient", "GoogleDriveClient",
"categorize_change", "categorize_change",
"download_and_extract_content",
"download_and_process_file", "download_and_process_file",
"fetch_all_changes", "fetch_all_changes",
"get_file_by_id", "get_file_by_id",

View file

@ -84,22 +84,50 @@ async def get_changes(
return [], None, f"Error getting changes: {e!s}" return [], None, f"Error getting changes: {e!s}"
async def _is_descendant_of(
client: GoogleDriveClient,
parent_ids: list[str],
target_folder_id: str,
max_depth: int = 20,
) -> bool:
"""Walk up the parent chain to check if any ancestor is *target_folder_id*."""
visited: set[str] = set()
to_check = list(parent_ids)
for _ in range(max_depth):
if not to_check:
return False
current = to_check.pop(0)
if current in visited:
continue
visited.add(current)
if current == target_folder_id:
return True
try:
service = await client.get_service()
meta = (
service.files()
.get(fileId=current, fields="parents", supportsAllDrives=True)
.execute()
)
grandparents = meta.get("parents", [])
to_check.extend(grandparents)
except Exception:
continue
return False
async def _filter_changes_by_folder( async def _filter_changes_by_folder(
client: GoogleDriveClient, client: GoogleDriveClient,
changes: list[dict[str, Any]], changes: list[dict[str, Any]],
folder_id: str, folder_id: str,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """Filter changes to only include files within the specified folder
Filter changes to only include files within the specified folder. (direct children or nested descendants)."""
Args:
client: GoogleDriveClient instance
changes: List of changes from API
folder_id: Folder ID to filter by
Returns:
Filtered list of changes
"""
filtered = [] filtered = []
for change in changes: for change in changes:
@ -108,14 +136,10 @@ async def _filter_changes_by_folder(
filtered.append(change) filtered.append(change)
continue continue
# Check if file is in the folder (or subfolder)
parents = file.get("parents", []) parents = file.get("parents", [])
if folder_id in parents: if folder_id in parents:
filtered.append(change) filtered.append(change)
else: elif await _is_descendant_of(client, parents, folder_id):
# Check if any parent is a descendant of folder_id
# This is a simplified check - full implementation would traverse hierarchy
# For now, we'll include it and let indexer validate
filtered.append(change) filtered.append(change)
return filtered return filtered

View file

@ -1,9 +1,15 @@
"""Google Drive API client.""" """Google Drive API client."""
import asyncio
import io import io
import logging
import threading
import time
from typing import Any from typing import Any
import httplib2
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
from google_auth_httplib2 import AuthorizedHttp
from googleapiclient.discovery import build from googleapiclient.discovery import build
from googleapiclient.errors import HttpError from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseUpload from googleapiclient.http import MediaIoBaseUpload
@ -12,6 +18,14 @@ from sqlalchemy.ext.asyncio import AsyncSession
from .credentials import get_valid_credentials from .credentials import get_valid_credentials
from .file_types import GOOGLE_DOC, GOOGLE_SHEET from .file_types import GOOGLE_DOC, GOOGLE_SHEET
logger = logging.getLogger(__name__)
def _build_thread_http(credentials: Credentials) -> AuthorizedHttp:
"""Create a per-thread HTTP transport so concurrent downloads don't share
the same ``httplib2.Http`` (which is not thread-safe)."""
return AuthorizedHttp(credentials, http=httplib2.Http())
class GoogleDriveClient: class GoogleDriveClient:
"""Client for Google Drive API operations.""" """Client for Google Drive API operations."""
@ -34,7 +48,9 @@ class GoogleDriveClient:
self.session = session self.session = session
self.connector_id = connector_id self.connector_id = connector_id
self._credentials = credentials self._credentials = credentials
self._resolved_credentials: Credentials | None = None
self.service = None self.service = None
self._service_lock = asyncio.Lock()
async def get_service(self): async def get_service(self):
""" """
@ -49,6 +65,10 @@ class GoogleDriveClient:
if self.service: if self.service:
return self.service return self.service
async with self._service_lock:
if self.service:
return self.service
try: try:
if self._credentials: if self._credentials:
credentials = self._credentials credentials = self._credentials
@ -56,6 +76,7 @@ class GoogleDriveClient:
credentials = await get_valid_credentials( credentials = await get_valid_credentials(
self.session, self.connector_id self.session, self.connector_id
) )
self._resolved_credentials = credentials
self.service = build("drive", "v3", credentials=credentials) self.service = build("drive", "v3", credentials=credentials)
return self.service return self.service
except Exception as e: except Exception as e:
@ -134,6 +155,33 @@ class GoogleDriveClient:
except Exception as e: except Exception as e:
return None, f"Error getting file metadata: {e!s}" return None, f"Error getting file metadata: {e!s}"
@staticmethod
def _sync_download_file(
service, file_id: str, credentials: Credentials,
) -> tuple[bytes | None, str | None]:
"""Blocking download — runs on a worker thread via ``to_thread``."""
thread = threading.current_thread().name
t0 = time.monotonic()
logger.info(f"[download] START file={file_id} thread={thread}")
try:
from googleapiclient.http import MediaIoBaseDownload
http = _build_thread_http(credentials)
request = service.files().get_media(fileId=file_id)
request.http = http
fh = io.BytesIO()
downloader = MediaIoBaseDownload(fh, request)
done = False
while not done:
_, done = downloader.next_chunk()
return fh.getvalue(), None
except HttpError as e:
return None, f"HTTP error downloading file: {e.resp.status}"
except Exception as e:
return None, f"Error downloading file: {e!s}"
finally:
logger.info(f"[download] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]: async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
""" """
Download binary file content. Download binary file content.
@ -144,27 +192,76 @@ class GoogleDriveClient:
Returns: Returns:
Tuple of (file content bytes, error message) Tuple of (file content bytes, error message)
""" """
try:
service = await self.get_service() service = await self.get_service()
request = service.files().get_media(fileId=file_id) return await asyncio.to_thread(
self._sync_download_file, service, file_id, self._resolved_credentials,
)
import io @staticmethod
def _sync_download_file_to_disk(
fh = io.BytesIO() service, file_id: str, dest_path: str, chunksize: int,
credentials: Credentials,
) -> str | None:
"""Blocking download-to-disk — runs on a worker thread via ``to_thread``."""
thread = threading.current_thread().name
t0 = time.monotonic()
logger.info(f"[download-to-disk] START file={file_id} thread={thread}")
try:
from googleapiclient.http import MediaIoBaseDownload from googleapiclient.http import MediaIoBaseDownload
downloader = MediaIoBaseDownload(fh, request) http = _build_thread_http(credentials)
request = service.files().get_media(fileId=file_id)
request.http = http
with open(dest_path, "wb") as fh:
downloader = MediaIoBaseDownload(fh, request, chunksize=chunksize)
done = False done = False
while not done: while not done:
_, done = downloader.next_chunk() _, done = downloader.next_chunk()
return None
return fh.getvalue(), None
except HttpError as e: except HttpError as e:
return None, f"HTTP error downloading file: {e.resp.status}" return f"HTTP error downloading file: {e.resp.status}"
except Exception as e: except Exception as e:
return None, f"Error downloading file: {e!s}" return f"Error downloading file: {e!s}"
finally:
logger.info(f"[download-to-disk] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
async def download_file_to_disk(
self, file_id: str, dest_path: str, chunksize: int = 5 * 1024 * 1024,
) -> str | None:
"""Stream file directly to disk in chunks, avoiding full in-memory buffering.
Returns error message on failure, None on success.
"""
service = await self.get_service()
return await asyncio.to_thread(
self._sync_download_file_to_disk,
service, file_id, dest_path, chunksize, self._resolved_credentials,
)
@staticmethod
def _sync_export_google_file(
service, file_id: str, mime_type: str, credentials: Credentials,
) -> tuple[bytes | None, str | None]:
"""Blocking export — runs on a worker thread via ``to_thread``."""
thread = threading.current_thread().name
t0 = time.monotonic()
logger.info(f"[export] START file={file_id} thread={thread}")
try:
http = _build_thread_http(credentials)
content = (
service.files()
.export(fileId=file_id, mimeType=mime_type)
.execute(http=http)
)
if not isinstance(content, bytes):
content = content.encode("utf-8")
return content, None
except HttpError as e:
return None, f"HTTP error exporting file: {e.resp.status}"
except Exception as e:
return None, f"Error exporting file: {e!s}"
finally:
logger.info(f"[export] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
async def export_google_file( async def export_google_file(
self, file_id: str, mime_type: str self, file_id: str, mime_type: str
@ -179,24 +276,12 @@ class GoogleDriveClient:
Returns: Returns:
Tuple of (exported content as bytes, error message) Tuple of (exported content as bytes, error message)
""" """
try:
service = await self.get_service() service = await self.get_service()
content = ( return await asyncio.to_thread(
service.files().export(fileId=file_id, mimeType=mime_type).execute() self._sync_export_google_file, service, file_id, mime_type,
self._resolved_credentials,
) )
# Content is already bytes from the API
# Keep as bytes to support both text and binary formats (like PDF)
if not isinstance(content, bytes):
content = content.encode("utf-8")
return content, None
except HttpError as e:
return None, f"HTTP error exporting file: {e.resp.status}"
except Exception as e:
return None, f"Error exporting file: {e!s}"
async def create_file( async def create_file(
self, self,
name: str, name: str,

View file

@ -1,8 +1,11 @@
"""Content extraction for Google Drive files.""" """Content extraction for Google Drive files."""
import asyncio
import logging import logging
import os import os
import tempfile import tempfile
import threading
import time
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -12,11 +15,182 @@ from app.db import Log
from app.services.task_logging_service import TaskLoggingService from app.services.task_logging_service import TaskLoggingService
from .client import GoogleDriveClient from .client import GoogleDriveClient
from .file_types import get_export_mime_type, is_google_workspace_file, should_skip_file from .file_types import (
get_export_mime_type,
get_extension_from_mime,
is_google_workspace_file,
should_skip_file,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def download_and_extract_content(
client: GoogleDriveClient,
file: dict[str, Any],
) -> tuple[str | None, dict[str, Any], str | None]:
"""Download a Google Drive file and extract its content as markdown.
ETL only -- no DB writes, no indexing, no summarization.
Returns:
(markdown_content, drive_metadata, error_message)
On success error_message is None.
"""
file_id = file.get("id")
file_name = file.get("name", "Unknown")
mime_type = file.get("mimeType", "")
if should_skip_file(mime_type):
return None, {}, f"Skipping {mime_type}"
logger.info(f"Downloading file for content extraction: {file_name} ({mime_type})")
drive_metadata: dict[str, Any] = {
"google_drive_file_id": file_id,
"google_drive_file_name": file_name,
"google_drive_mime_type": mime_type,
"source_connector": "google_drive",
}
if "modifiedTime" in file:
drive_metadata["modified_time"] = file["modifiedTime"]
if "createdTime" in file:
drive_metadata["created_time"] = file["createdTime"]
if "size" in file:
drive_metadata["file_size"] = file["size"]
if "webViewLink" in file:
drive_metadata["web_view_link"] = file["webViewLink"]
if "md5Checksum" in file:
drive_metadata["md5_checksum"] = file["md5Checksum"]
if is_google_workspace_file(mime_type):
export_ext = get_extension_from_mime(get_export_mime_type(mime_type) or "")
drive_metadata["exported_as"] = export_ext.lstrip(".") if export_ext else "pdf"
drive_metadata["original_workspace_type"] = mime_type.split(".")[-1]
temp_file_path = None
try:
if is_google_workspace_file(mime_type):
export_mime = get_export_mime_type(mime_type)
if not export_mime:
return None, drive_metadata, f"Cannot export Google Workspace type: {mime_type}"
content_bytes, error = await client.export_google_file(file_id, export_mime)
if error:
return None, drive_metadata, error
extension = get_extension_from_mime(export_mime) or ".pdf"
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp:
tmp.write(content_bytes)
temp_file_path = tmp.name
else:
extension = (
Path(file_name).suffix
or get_extension_from_mime(mime_type)
or ".bin"
)
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp:
temp_file_path = tmp.name
error = await client.download_file_to_disk(file_id, temp_file_path)
if error:
return None, drive_metadata, error
markdown = await _parse_file_to_markdown(temp_file_path, file_name)
return markdown, drive_metadata, None
except Exception as e:
logger.warning(f"Failed to extract content from {file_name}: {e!s}")
return None, drive_metadata, str(e)
finally:
if temp_file_path and os.path.exists(temp_file_path):
try:
os.unlink(temp_file_path)
except Exception:
pass
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
"""Parse a local file to markdown using the configured ETL service."""
lower = filename.lower()
if lower.endswith((".md", ".markdown", ".txt")):
with open(file_path, encoding="utf-8") as f:
return f.read()
if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
from app.config import config as app_config
from litellm import atranscription
stt_service_type = (
"local"
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
else "external"
)
if stt_service_type == "local":
from app.services.stt_service import stt_service
t0 = time.monotonic()
logger.info(f"[local-stt] START file={filename} thread={threading.current_thread().name}")
result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
logger.info(f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
text = result.get("text", "")
else:
with open(file_path, "rb") as audio_file:
kwargs: dict[str, Any] = {
"model": app_config.STT_SERVICE,
"file": audio_file,
"api_key": app_config.STT_SERVICE_API_KEY,
}
if app_config.STT_SERVICE_API_BASE:
kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
resp = await atranscription(**kwargs)
text = resp.get("text", "")
if not text:
raise ValueError("Transcription returned empty text")
return f"# Transcription of {filename}\n\n{text}"
# Document files -- use configured ETL service
from app.config import config as app_config
if app_config.ETL_SERVICE == "UNSTRUCTURED":
from langchain_unstructured import UnstructuredLoader
from app.utils.document_converters import convert_document_to_markdown
loader = UnstructuredLoader(
file_path,
mode="elements",
post_processors=[],
languages=["eng"],
include_orig_elements=False,
include_metadata=False,
strategy="auto",
)
docs = await loader.aload()
return await convert_document_to_markdown(docs)
if app_config.ETL_SERVICE == "LLAMACLOUD":
from app.tasks.document_processors.file_processors import (
parse_with_llamacloud_retry,
)
result = await parse_with_llamacloud_retry(file_path=file_path, estimated_pages=50)
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
if not markdown_documents:
raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
return markdown_documents[0].text
if app_config.ETL_SERVICE == "DOCLING":
from docling.document_converter import DocumentConverter
converter = DocumentConverter()
t0 = time.monotonic()
logger.info(f"[docling] START file={filename} thread={threading.current_thread().name}")
result = await asyncio.to_thread(converter.convert, file_path)
logger.info(f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
return result.document.export_to_markdown()
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
async def download_and_process_file( async def download_and_process_file(
client: GoogleDriveClient, client: GoogleDriveClient,
file: dict[str, Any], file: dict[str, Any],
@ -68,14 +242,17 @@ async def download_and_process_file(
if error: if error:
return None, error return None, error
extension = ".pdf" if export_mime == "application/pdf" else ".txt" extension = get_extension_from_mime(export_mime) or ".pdf"
else: else:
content_bytes, error = await client.download_file(file_id) content_bytes, error = await client.download_file(file_id)
if error: if error:
return None, error return None, error
# Preserve original file extension extension = (
extension = Path(file_name).suffix or ".bin" Path(file_name).suffix
or get_extension_from_mime(mime_type)
or ".bin"
)
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp_file: with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp_file:
tmp_file.write(content_bytes) tmp_file.write(content_bytes)
@ -113,7 +290,12 @@ async def download_and_process_file(
connector_info["metadata"]["md5_checksum"] = file["md5Checksum"] connector_info["metadata"]["md5_checksum"] = file["md5Checksum"]
if is_google_workspace_file(mime_type): if is_google_workspace_file(mime_type):
connector_info["metadata"]["exported_as"] = "pdf" export_ext = get_extension_from_mime(
get_export_mime_type(mime_type) or ""
)
connector_info["metadata"]["exported_as"] = (
export_ext.lstrip(".") if export_ext else "pdf"
)
connector_info["metadata"]["original_workspace_type"] = mime_type.split( connector_info["metadata"]["original_workspace_type"] = mime_type.split(
"." "."
)[-1] )[-1]

View file

@ -7,11 +7,34 @@ GOOGLE_FOLDER = "application/vnd.google-apps.folder"
GOOGLE_SHORTCUT = "application/vnd.google-apps.shortcut" GOOGLE_SHORTCUT = "application/vnd.google-apps.shortcut"
EXPORT_FORMATS = { EXPORT_FORMATS = {
GOOGLE_DOC: "application/pdf", GOOGLE_DOC: "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
GOOGLE_SHEET: "application/pdf", GOOGLE_SHEET: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
GOOGLE_SLIDE: "application/pdf", GOOGLE_SLIDE: "application/pdf",
} }
MIME_TO_EXTENSION: dict[str, str] = {
"application/pdf": ".pdf",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
"application/vnd.ms-excel": ".xls",
"application/msword": ".doc",
"application/vnd.ms-powerpoint": ".ppt",
"text/plain": ".txt",
"text/csv": ".csv",
"text/html": ".html",
"text/markdown": ".md",
"application/json": ".json",
"application/xml": ".xml",
"image/png": ".png",
"image/jpeg": ".jpg",
}
def get_extension_from_mime(mime_type: str) -> str | None:
"""Return a file extension (with leading dot) for a MIME type, or None."""
return MIME_TO_EXTENSION.get(mime_type)
def is_google_workspace_file(mime_type: str) -> bool: def is_google_workspace_file(mime_type: str) -> bool:
"""Check if file is a Google Workspace file that needs export.""" """Check if file is a Google Workspace file that needs export."""

View file

@ -3,10 +3,17 @@ import hashlib
from app.indexing_pipeline.connector_document import ConnectorDocument from app.indexing_pipeline.connector_document import ConnectorDocument
def compute_identifier_hash(
document_type_value: str, unique_id: str, search_space_id: int
) -> str:
"""Return a stable SHA-256 hash from raw identity components."""
combined = f"{document_type_value}:{unique_id}:{search_space_id}"
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
def compute_unique_identifier_hash(doc: ConnectorDocument) -> str: def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
"""Return a stable SHA-256 hash identifying a document by its source identity.""" """Return a stable SHA-256 hash identifying a document by its source identity."""
combined = f"{doc.document_type.value}:{doc.unique_id}:{doc.search_space_id}" return compute_identifier_hash(doc.document_type.value, doc.unique_id, doc.search_space_id)
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
def compute_content_hash(doc: ConnectorDocument) -> str: def compute_content_hash(doc: ConnectorDocument) -> str:

View file

@ -1,17 +1,21 @@
import asyncio
import contextlib import contextlib
import logging
import time import time
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime from datetime import UTC, datetime
from sqlalchemy import delete, select from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Chunk, Document, DocumentStatus from app.db import NATIVE_TO_LEGACY_DOCTYPE, Chunk, Document, DocumentStatus
from app.indexing_pipeline.connector_document import ConnectorDocument from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_chunker import chunk_text from app.indexing_pipeline.document_chunker import chunk_text
from app.indexing_pipeline.document_embedder import embed_texts from app.indexing_pipeline.document_embedder import embed_texts
from app.indexing_pipeline.document_hashing import ( from app.indexing_pipeline.document_hashing import (
compute_content_hash, compute_content_hash,
compute_identifier_hash,
compute_unique_identifier_hash, compute_unique_identifier_hash,
) )
from app.indexing_pipeline.document_persistence import ( from app.indexing_pipeline.document_persistence import (
@ -54,6 +58,62 @@ class IndexingPipelineService:
def __init__(self, session: AsyncSession) -> None: def __init__(self, session: AsyncSession) -> None:
self.session = session self.session = session
async def migrate_legacy_docs(
self, connector_docs: list[ConnectorDocument]
) -> None:
"""Migrate legacy Composio documents to their native Google type.
For each ConnectorDocument whose document_type has a Composio equivalent
in NATIVE_TO_LEGACY_DOCTYPE, look up the old document by legacy hash and
update its unique_identifier_hash and document_type so that
prepare_for_indexing() can find it under the native hash.
"""
for doc in connector_docs:
legacy_type = NATIVE_TO_LEGACY_DOCTYPE.get(doc.document_type.value)
if not legacy_type:
continue
legacy_hash = compute_identifier_hash(
legacy_type, doc.unique_id, doc.search_space_id
)
result = await self.session.execute(
select(Document).filter(
Document.unique_identifier_hash == legacy_hash
)
)
existing = result.scalars().first()
if existing is None:
continue
native_hash = compute_identifier_hash(
doc.document_type.value, doc.unique_id, doc.search_space_id
)
existing.unique_identifier_hash = native_hash
existing.document_type = doc.document_type
await self.session.commit()
async def index_batch(
self, connector_docs: list[ConnectorDocument], llm
) -> list[Document]:
"""Convenience method: prepare_for_indexing then index each document.
Indexers that need heartbeat callbacks or custom per-document logic
should call prepare_for_indexing() + index() directly instead.
"""
doc_map = {
compute_unique_identifier_hash(cd): cd for cd in connector_docs
}
documents = await self.prepare_for_indexing(connector_docs)
results: list[Document] = []
for document in documents:
connector_doc = doc_map.get(document.unique_identifier_hash)
if connector_doc is None:
continue
result = await self.index(document, connector_doc, llm)
results.append(result)
return results
async def prepare_for_indexing( async def prepare_for_indexing(
self, connector_docs: list[ConnectorDocument] self, connector_docs: list[ConnectorDocument]
) -> list[Document]: ) -> list[Document]:
@ -200,13 +260,14 @@ class IndexingPipelineService:
) )
t_step = time.perf_counter() t_step = time.perf_counter()
chunk_texts = chunk_text( chunk_texts = await asyncio.to_thread(
chunk_text,
connector_doc.source_markdown, connector_doc.source_markdown,
use_code_chunker=connector_doc.should_use_code_chunker, use_code_chunker=connector_doc.should_use_code_chunker,
) )
texts_to_embed = [content, *chunk_texts] texts_to_embed = [content, *chunk_texts]
embeddings = embed_texts(texts_to_embed) embeddings = await asyncio.to_thread(embed_texts, texts_to_embed)
summary_embedding, *chunk_embeddings = embeddings summary_embedding, *chunk_embeddings = embeddings
chunks = [ chunks = [
@ -268,3 +329,126 @@ class IndexingPipelineService:
await self.session.refresh(document) await self.session.refresh(document)
return document return document
async def index_batch_parallel(
self,
connector_docs: list[ConnectorDocument],
get_llm: Callable[[AsyncSession], Awaitable],
*,
max_concurrency: int = 4,
on_heartbeat: Callable[[int], Awaitable[None]] | None = None,
heartbeat_interval: float = 30.0,
) -> tuple[list[Document], int, int]:
"""Index documents in parallel with bounded concurrency.
Phase 1 (serial): prepare_for_indexing using self.session.
Phase 2 (parallel): index each document in an isolated session,
bounded by a semaphore to avoid overwhelming APIs/DB.
"""
logger = logging.getLogger(__name__)
perf = get_perf_logger()
t_total = time.perf_counter()
doc_map = {
compute_unique_identifier_hash(cd): cd for cd in connector_docs
}
documents = await self.prepare_for_indexing(connector_docs)
if not documents:
return [], 0, 0
from app.tasks.celery_tasks import get_celery_session_maker
sem = asyncio.Semaphore(max_concurrency)
lock = asyncio.Lock()
indexed_count = 0
failed_count = 0
results: list[Document] = []
last_heartbeat = time.time()
async def _index_one(document: Document) -> Document | Exception:
nonlocal indexed_count, failed_count, last_heartbeat
connector_doc = doc_map.get(document.unique_identifier_hash)
if connector_doc is None:
logger.warning(
"No matching ConnectorDocument for document %s, skipping",
document.id,
)
async with lock:
failed_count += 1
return document
async with sem:
session_maker = get_celery_session_maker()
async with session_maker() as isolated_session:
try:
refetched = await isolated_session.get(
Document, document.id
)
if refetched is None:
async with lock:
failed_count += 1
return document
llm = await get_llm(isolated_session)
iso_pipeline = IndexingPipelineService(isolated_session)
result = await iso_pipeline.index(
refetched, connector_doc, llm
)
async with lock:
if DocumentStatus.is_state(
result.status, DocumentStatus.READY
):
indexed_count += 1
else:
failed_count += 1
if on_heartbeat:
now = time.time()
if now - last_heartbeat >= heartbeat_interval:
await on_heartbeat(indexed_count)
last_heartbeat = now
return result
except Exception as exc:
logger.error(
"Parallel index failed for doc %s: %s",
document.id,
exc,
exc_info=True,
)
async with lock:
failed_count += 1
return exc
tasks = [_index_one(doc) for doc in documents]
t_parallel = time.perf_counter()
outcomes = await asyncio.gather(*tasks, return_exceptions=True)
perf.info(
"[indexing] index_batch_parallel gather docs=%d concurrency=%d "
"indexed=%d failed=%d in %.3fs",
len(documents),
max_concurrency,
indexed_count,
failed_count,
time.perf_counter() - t_parallel,
)
for outcome in outcomes:
if isinstance(outcome, Document):
results.append(outcome)
elif isinstance(outcome, Exception):
pass
perf.info(
"[indexing] index_batch_parallel TOTAL input=%d prepared=%d "
"indexed=%d failed=%d in %.3fs",
len(connector_docs),
len(documents),
indexed_count,
failed_count,
time.perf_counter() - t_total,
)
return results, indexed_count, failed_count

View file

@ -1,19 +1,43 @@
""" """
Editor routes for document editing with markdown (Plate.js frontend). Editor routes for document editing with markdown (Plate.js frontend).
Includes multi-format export (PDF, DOCX, HTML, LaTeX, EPUB, ODT, plain text).
""" """
import asyncio
import io
import logging
import os
import re
import tempfile
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any
from fastapi import APIRouter, Depends, HTTPException import pypandoc
import typst
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from app.db import Document, DocumentType, Permission, User, get_async_session from app.db import Document, DocumentType, Permission, User, get_async_session
from app.routes.reports_routes import (
ExportFormat,
_FILE_EXTENSIONS,
_MEDIA_TYPES,
_normalize_latex_delimiters,
_strip_wrapping_code_fences,
)
from app.templates.export_helpers import (
get_html_css_path,
get_reference_docx_path,
get_typst_template_path,
)
from app.users import current_active_user from app.users import current_active_user
from app.utils.rbac import check_permission from app.utils.rbac import check_permission
logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@ -212,3 +236,153 @@ async def save_document(
"message": "Document saved and will be reindexed in the background", "message": "Document saved and will be reindexed in the background",
"updated_at": document.updated_at.isoformat(), "updated_at": document.updated_at.isoformat(),
} }
@router.get(
"/search-spaces/{search_space_id}/documents/{document_id}/export"
)
async def export_document(
search_space_id: int,
document_id: int,
format: ExportFormat = Query(
ExportFormat.PDF,
description="Export format: pdf, docx, html, latex, epub, odt, or plain",
),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Export a document in the requested format (reuses the report export pipeline)."""
await check_permission(
session,
user,
search_space_id,
Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space",
)
result = await session.execute(
select(Document)
.options(selectinload(Document.chunks))
.filter(
Document.id == document_id,
Document.search_space_id == search_space_id,
)
)
document = result.scalars().first()
if not document:
raise HTTPException(status_code=404, detail="Document not found")
# Resolve markdown content (same priority as editor-content endpoint)
markdown_content: str | None = document.source_markdown
if markdown_content is None and document.blocknote_document:
from app.utils.blocknote_to_markdown import blocknote_to_markdown
markdown_content = blocknote_to_markdown(document.blocknote_document)
if markdown_content is None:
chunks = sorted(document.chunks, key=lambda c: c.id)
if chunks:
markdown_content = "\n\n".join(chunk.content for chunk in chunks)
if not markdown_content or not markdown_content.strip():
raise HTTPException(
status_code=400, detail="Document has no content to export"
)
markdown_content = _strip_wrapping_code_fences(markdown_content)
markdown_content = _normalize_latex_delimiters(markdown_content)
doc_title = document.title or "Document"
formatted_date = (
document.created_at.strftime("%B %d, %Y") if document.created_at else ""
)
input_fmt = "gfm+tex_math_dollars"
meta_args = ["-M", f"title:{doc_title}", "-M", f"date:{formatted_date}"]
def _convert_and_read() -> bytes:
if format == ExportFormat.PDF:
typst_template = str(get_typst_template_path())
typst_markup: str = pypandoc.convert_text(
markdown_content,
"typst",
format=input_fmt,
extra_args=[
"--standalone",
f"--template={typst_template}",
"-V", "mainfont:Libertinus Serif",
"-V", "codefont:DejaVu Sans Mono",
*meta_args,
],
)
return typst.compile(typst_markup.encode("utf-8"))
if format == ExportFormat.DOCX:
return _pandoc_to_tempfile(
format.value,
["--standalone", f"--reference-doc={get_reference_docx_path()}", *meta_args],
)
if format == ExportFormat.HTML:
html_str: str = pypandoc.convert_text(
markdown_content,
"html5",
format=input_fmt,
extra_args=[
"--standalone", "--embed-resources",
f"--css={get_html_css_path()}",
"--syntax-highlighting=pygments",
*meta_args,
],
)
return html_str.encode("utf-8")
if format == ExportFormat.EPUB:
return _pandoc_to_tempfile("epub3", ["--standalone", *meta_args])
if format == ExportFormat.ODT:
return _pandoc_to_tempfile("odt", ["--standalone", *meta_args])
if format == ExportFormat.LATEX:
tex_str: str = pypandoc.convert_text(
markdown_content, "latex", format=input_fmt,
extra_args=["--standalone", *meta_args],
)
return tex_str.encode("utf-8")
plain_str: str = pypandoc.convert_text(
markdown_content, "plain", format=input_fmt,
extra_args=["--wrap=auto", "--columns=80"],
)
return plain_str.encode("utf-8")
def _pandoc_to_tempfile(output_format: str, extra_args: list[str]) -> bytes:
fd, tmp_path = tempfile.mkstemp(suffix=f".{output_format}")
os.close(fd)
try:
pypandoc.convert_text(
markdown_content, output_format, format=input_fmt,
extra_args=extra_args, outputfile=tmp_path,
)
with open(tmp_path, "rb") as f:
return f.read()
finally:
os.unlink(tmp_path)
try:
loop = asyncio.get_running_loop()
output = await loop.run_in_executor(None, _convert_and_read)
except Exception as e:
logger.exception("Document export failed")
raise HTTPException(status_code=500, detail=f"Export failed: {e!s}") from e
safe_title = (
"".join(c if c.isalnum() or c in " -_" else "_" for c in doc_title)
.strip()[:80]
or "document"
)
ext = _FILE_EXTENSIONS[format]
return StreamingResponse(
io.BytesIO(output),
media_type=_MEDIA_TYPES[format],
headers={"Content-Disposition": f'attachment; filename="{safe_title}.{ext}"'},
)

View file

@ -2329,7 +2329,7 @@ async def run_google_drive_indexing(
try: try:
from app.tasks.connector_indexers.google_drive_indexer import ( from app.tasks.connector_indexers.google_drive_indexer import (
index_google_drive_files, index_google_drive_files,
index_google_drive_single_file, index_google_drive_selected_files,
) )
# Parse the structured data # Parse the structured data
@ -2402,25 +2402,23 @@ async def run_google_drive_indexing(
exc_info=True, exc_info=True,
) )
# Index each individual file # Index all selected files together via the parallel pipeline
for file in items.files: if items.files:
try: try:
indexed_count, error_message = await index_google_drive_single_file( file_tuples = [(f.id, f.name) for f in items.files]
indexed_count, _skipped, file_errors = await index_google_drive_selected_files(
session, session,
connector_id, connector_id,
search_space_id, search_space_id,
user_id, user_id,
file_id=file.id, files=file_tuples,
file_name=file.name,
) )
if error_message:
errors.append(f"File '{file.name}': {error_message}")
else:
total_indexed += indexed_count total_indexed += indexed_count
errors.extend(file_errors)
except Exception as e: except Exception as e:
errors.append(f"File '{file.name}': {e!s}") errors.append(f"File batch indexing: {e!s}")
logger.error( logger.error(
f"Error indexing file {file.name} ({file.id}): {e}", f"Error batch indexing files: {e}",
exc_info=True, exc_info=True,
) )

View file

@ -209,8 +209,8 @@ class GoogleCalendarKBSyncService:
) )
calendar_id = (document.document_metadata or {}).get( calendar_id = (document.document_metadata or {}).get(
"calendar_id", "primary" "calendar_id"
) ) or "primary"
live_event = await loop.run_in_executor( live_event = await loop.run_in_executor(
None, None,
lambda: ( lambda: (

View file

@ -1,49 +1,74 @@
""" """Confluence connector indexer using the unified parallel indexing pipeline."""
Confluence connector indexer.
Provides real-time document status updates during indexing using a two-phase approach:
- Phase 1: Create all documents with PENDING status (visible in UI immediately)
- Phase 2: Process each document one by one (PENDING PROCESSING READY/FAILED)
"""
import contextlib import contextlib
import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.confluence_history import ConfluenceHistoryConnector from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.services.llm_service import get_user_long_context_llm from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from .base import ( from .base import (
calculate_date_range, calculate_date_range,
check_document_by_unique_identifier,
check_duplicate_document_by_hash, check_duplicate_document_by_hash,
get_connector_by_id, get_connector_by_id,
get_current_timestamp,
logger, logger,
safe_set_chunks,
update_connector_last_indexed, update_connector_last_indexed,
) )
# Type hint for heartbeat callback
HeartbeatCallbackType = Callable[[int], Awaitable[None]] HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds
HEARTBEAT_INTERVAL_SECONDS = 30 HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
page: dict,
full_content: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Confluence page dict to a ConnectorDocument."""
page_id = page.get("id", "")
page_title = page.get("title", "")
space_id = page.get("spaceId", "")
comment_count = len(page.get("comments", []))
metadata = {
"page_id": page_id,
"page_title": page_title,
"space_id": space_id,
"comment_count": comment_count,
"connector_id": connector_id,
"document_type": "Confluence Page",
"connector_type": "Confluence",
}
fallback_summary = (
f"Confluence Page: {page_title}\n\nSpace ID: {space_id}\n\n{full_content}"
)
return ConnectorDocument(
title=page_title,
source_markdown=full_content,
unique_id=page_id,
document_type=DocumentType.CONFLUENCE_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_confluence_pages( async def index_confluence_pages(
session: AsyncSession, session: AsyncSession,
connector_id: int, connector_id: int,
@ -53,26 +78,9 @@ async def index_confluence_pages(
end_date: str | None = None, end_date: str | None = None,
update_last_indexed: bool = True, update_last_indexed: bool = True,
on_heartbeat_callback: HeartbeatCallbackType | None = None, on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str | None]: ) -> tuple[int, int, str | None]:
""" """Index Confluence pages and comments."""
Index Confluence pages and comments.
Args:
session: Database session
connector_id: ID of the Confluence connector
search_space_id: ID of the search space to store documents in
user_id: User ID
start_date: Start date for indexing (YYYY-MM-DD format)
end_date: End date for indexing (YYYY-MM-DD format)
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Returns:
Tuple containing (number of documents indexed, error message or None)
"""
task_logger = TaskLoggingService(session, search_space_id) task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start( log_entry = await task_logger.log_task_start(
task_name="confluence_pages_indexing", task_name="confluence_pages_indexing",
source="connector_indexing_task", source="connector_indexing_task",
@ -86,7 +94,6 @@ async def index_confluence_pages(
) )
try: try:
# Get the connector from the database
connector = await get_connector_by_id( connector = await get_connector_by_id(
session, connector_id, SearchSourceConnectorType.CONFLUENCE_CONNECTOR session, connector_id, SearchSourceConnectorType.CONFLUENCE_CONNECTOR
) )
@ -98,9 +105,8 @@ async def index_confluence_pages(
"Connector not found", "Connector not found",
{"error_type": "ConnectorNotFound"}, {"error_type": "ConnectorNotFound"},
) )
return 0, f"Connector with ID {connector_id} not found" return 0, 0, f"Connector with ID {connector_id} not found"
# Initialize Confluence OAuth client
await task_logger.log_task_progress( await task_logger.log_task_progress(
log_entry, log_entry,
f"Initializing Confluence OAuth client for connector {connector_id}", f"Initializing Confluence OAuth client for connector {connector_id}",
@ -114,7 +120,6 @@ async def index_confluence_pages(
) )
) )
# Calculate date range
start_date_str, end_date_str = calculate_date_range( start_date_str, end_date_str = calculate_date_range(
connector, start_date, end_date, default_days_back=365 connector, start_date, end_date, default_days_back=365
) )
@ -129,19 +134,14 @@ async def index_confluence_pages(
}, },
) )
# Get pages within date range
try: try:
pages, error = await confluence_client.get_pages_by_date_range( pages, error = await confluence_client.get_pages_by_date_range(
start_date=start_date_str, end_date=end_date_str, include_comments=True start_date=start_date_str, end_date=end_date_str, include_comments=True
) )
if error: if error:
# Don't treat "No pages found" as an error that should stop indexing
if "No pages found" in error: if "No pages found" in error:
logger.info(f"No Confluence pages found: {error}") logger.info(f"No Confluence pages found: {error}")
logger.info(
"No pages found is not a critical error, continuing with update"
)
if update_last_indexed: if update_last_indexed:
await update_connector_last_indexed( await update_connector_last_indexed(
session, connector, update_last_indexed session, connector, update_last_indexed
@ -156,11 +156,10 @@ async def index_confluence_pages(
f"No Confluence pages found in date range {start_date_str} to {end_date_str}", f"No Confluence pages found in date range {start_date_str} to {end_date_str}",
{"pages_found": 0}, {"pages_found": 0},
) )
# Close client before returning
if confluence_client: if confluence_client:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
await confluence_client.close() await confluence_client.close()
return 0, None return 0, 0, None
else: else:
logger.error(f"Failed to get Confluence pages: {error}") logger.error(f"Failed to get Confluence pages: {error}")
await task_logger.log_task_failure( await task_logger.log_task_failure(
@ -169,36 +168,35 @@ async def index_confluence_pages(
"API Error", "API Error",
{"error_type": "APIError"}, {"error_type": "APIError"},
) )
# Close client on error
if confluence_client: if confluence_client:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
await confluence_client.close() await confluence_client.close()
return 0, f"Failed to get Confluence pages: {error}" return 0, 0, f"Failed to get Confluence pages: {error}"
logger.info(f"Retrieved {len(pages)} pages from Confluence API") logger.info(f"Retrieved {len(pages)} pages from Confluence API")
except Exception as e: except Exception as e:
logger.error(f"Error fetching Confluence pages: {e!s}", exc_info=True) logger.error(f"Error fetching Confluence pages: {e!s}", exc_info=True)
# Close client on error
if confluence_client: if confluence_client:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
await confluence_client.close() await confluence_client.close()
return 0, f"Error fetching Confluence pages: {e!s}" return 0, 0, f"Error fetching Confluence pages: {e!s}"
if not pages:
logger.info("No Confluence pages found for the specified date range")
if update_last_indexed:
await update_connector_last_indexed(
session, connector, update_last_indexed
)
await session.commit()
if confluence_client:
with contextlib.suppress(Exception):
await confluence_client.close()
return 0, 0, None
# =======================================================================
# PHASE 1: Analyze all pages, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
documents_indexed = 0
documents_skipped = 0 documents_skipped = 0
documents_failed = 0
duplicate_content_count = 0 duplicate_content_count = 0
connector_docs: list[ConnectorDocument] = []
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
pages_to_process = [] # List of dicts with document and page data
new_documents_created = False
for page in pages: for page in pages:
try: try:
@ -213,12 +211,10 @@ async def index_confluence_pages(
documents_skipped += 1 documents_skipped += 1
continue continue
# Extract page content
page_content = "" page_content = ""
if page.get("body") and page["body"].get("storage"): if page.get("body") and page["body"].get("storage"):
page_content = page["body"]["storage"].get("value", "") page_content = page["body"]["storage"].get("value", "")
# Add comments to content
comments = page.get("comments", []) comments = page.get("comments", [])
comments_content = "" comments_content = ""
if comments: if comments:
@ -235,61 +231,25 @@ async def index_confluence_pages(
comments_content += f"**Comment by {comment_author}** ({comment_date}):\n{comment_body}\n\n" comments_content += f"**Comment by {comment_author}** ({comment_date}):\n{comment_body}\n\n"
# Combine page content with comments
full_content = f"# {page_title}\n\n{page_content}{comments_content}" full_content = f"# {page_title}\n\n{page_content}{comments_content}"
if not full_content.strip(): if not page_content.strip() and not comments:
logger.warning(f"Skipping page with no content: {page_title}") logger.warning(f"Skipping page with no content: {page_title}")
documents_skipped += 1 documents_skipped += 1
continue continue
# Generate unique identifier hash for this Confluence page doc = _build_connector_doc(
unique_identifier_hash = generate_unique_identifier_hash( page,
DocumentType.CONFLUENCE_CONNECTOR, page_id, search_space_id full_content,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector.enable_summary,
) )
# Generate content hash
content_hash = generate_content_hash(full_content, search_space_id)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
comment_count = len(comments)
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
pages_to_process.append(
{
"document": existing_document,
"is_new": False,
"full_content": full_content,
"page_content": page_content,
"content_hash": content_hash,
"page_id": page_id,
"page_title": page_title,
"space_id": space_id,
"comment_count": comment_count,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush: with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash( duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash session, compute_content_hash(doc)
) )
if duplicate_by_content: if duplicate_by_content:
@ -302,151 +262,29 @@ async def index_confluence_pages(
documents_skipped += 1 documents_skipped += 1
continue continue
# Create new document with PENDING status (visible in UI immediately) connector_docs.append(doc)
document = Document(
search_space_id=search_space_id,
title=page_title,
document_type=DocumentType.CONFLUENCE_CONNECTOR,
document_metadata={
"page_id": page_id,
"page_title": page_title,
"space_id": space_id,
"comment_count": comment_count,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
pages_to_process.append(
{
"document": document,
"is_new": True,
"full_content": full_content,
"page_content": page_content,
"content_hash": content_hash,
"page_id": page_id,
"page_title": page_title,
"space_id": space_id,
"comment_count": comment_count,
}
)
except Exception as e: except Exception as e:
logger.error(f"Error in Phase 1 for page: {e!s}", exc_info=True) logger.error(f"Error building ConnectorDocument for page: {e!s}", exc_info=True)
documents_failed += 1 documents_skipped += 1
continue continue
# Commit all pending documents - they all appear in UI now pipeline = IndexingPipelineService(session)
if new_documents_created: await pipeline.migrate_legacy_docs(connector_docs)
logger.info(
f"Phase 1: Committing {len([p for p in pages_to_process if p['is_new']])} pending documents"
)
await session.commit()
# ======================================================================= async def _get_llm(s: AsyncSession):
# PHASE 2: Process each document one by one return await get_user_long_context_llm(s, user_id, search_space_id)
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(pages_to_process)} documents")
for item in pages_to_process: _, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
# Send heartbeat periodically connector_docs,
if on_heartbeat_callback: _get_llm,
current_time = time.time() max_concurrency=3,
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS: on_heartbeat=on_heartbeat_callback,
await on_heartbeat_callback(documents_indexed) heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
) )
if user_llm and connector.enable_summary:
document_metadata = {
"page_title": item["page_title"],
"page_id": item["page_id"],
"space_id": item["space_id"],
"comment_count": item["comment_count"],
"document_type": "Confluence Page",
"connector_type": "Confluence",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["full_content"], user_llm, document_metadata
)
else:
summary_content = f"Confluence Page: {item['page_title']}\n\nSpace ID: {item['space_id']}\n\n{item['full_content']}"
summary_embedding = embed_text(summary_content)
# Process chunks - using the full page content with comments
chunks = await create_document_chunks(item["full_content"])
# Update document to READY with actual content
document.title = item["page_title"]
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"page_id": item["page_id"],
"page_title": item["page_title"],
"space_id": item["space_id"],
"comment_count": item["comment_count"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Confluence pages processed so far"
)
await session.commit()
except Exception as e:
logger.error(
f"Error processing page {item.get('page_title', 'Unknown')}: {e!s}",
exc_info=True,
)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
documents_failed += 1
continue # Skip this page and continue with others
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
# This ensures the UI shows "Last indexed" instead of "Never indexed"
await update_connector_last_indexed(session, connector, update_last_indexed) await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit to ensure all documents are persisted (safety net)
logger.info( logger.info(
f"Final commit: Total {documents_indexed} Confluence pages processed" f"Final commit: Total {documents_indexed} Confluence pages processed"
) )
@ -456,7 +294,6 @@ async def index_confluence_pages(
"Successfully committed all Confluence document changes to database" "Successfully committed all Confluence document changes to database"
) )
except Exception as e: except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if ( if (
"duplicate key value violates unique constraint" in str(e).lower() "duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower() or "uniqueviolationerror" in str(e).lower()
@ -467,11 +304,9 @@ async def index_confluence_pages(
f"Rolling back and continuing. Error: {e!s}" f"Rolling back and continuing. Error: {e!s}"
) )
await session.rollback() await session.rollback()
# Don't fail the entire task - some documents may have been successfully indexed
else: else:
raise raise
# Build warning message if there were issues
warning_parts = [] warning_parts = []
if duplicate_content_count > 0: if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate") warning_parts.append(f"{duplicate_content_count} duplicate")
@ -479,7 +314,6 @@ async def index_confluence_pages(
warning_parts.append(f"{documents_failed} failed") warning_parts.append(f"{documents_failed} failed")
warning_message = ", ".join(warning_parts) if warning_parts else None warning_message = ", ".join(warning_parts) if warning_parts else None
# Log success
await task_logger.log_task_success( await task_logger.log_task_success(
log_entry, log_entry,
f"Successfully completed Confluence indexing for connector {connector_id}", f"Successfully completed Confluence indexing for connector {connector_id}",
@ -490,22 +324,19 @@ async def index_confluence_pages(
"duplicate_content_count": duplicate_content_count, "duplicate_content_count": duplicate_content_count,
}, },
) )
logger.info( logger.info(
f"Confluence indexing completed: {documents_indexed} ready, " f"Confluence indexing completed: {documents_indexed} ready, "
f"{documents_skipped} skipped, {documents_failed} failed " f"{documents_skipped} skipped, {documents_failed} failed "
f"({duplicate_content_count} duplicate content)" f"({duplicate_content_count} duplicate content)"
) )
# Close the client connection
if confluence_client: if confluence_client:
await confluence_client.close() await confluence_client.close()
return documents_indexed, warning_message return documents_indexed, documents_skipped, warning_message
except SQLAlchemyError as db_error: except SQLAlchemyError as db_error:
await session.rollback() await session.rollback()
# Close client if it exists
if confluence_client: if confluence_client:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
await confluence_client.close() await confluence_client.close()
@ -516,10 +347,9 @@ async def index_confluence_pages(
{"error_type": "SQLAlchemyError"}, {"error_type": "SQLAlchemyError"},
) )
logger.error(f"Database error: {db_error!s}", exc_info=True) logger.error(f"Database error: {db_error!s}", exc_info=True)
return 0, f"Database error: {db_error!s}" return 0, 0, f"Database error: {db_error!s}"
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
# Close client if it exists
if confluence_client: if confluence_client:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
await confluence_client.close() await confluence_client.close()
@ -530,4 +360,4 @@ async def index_confluence_pages(
{"error_type": type(e).__name__}, {"error_type": type(e).__name__},
) )
logger.error(f"Failed to index Confluence pages: {e!s}", exc_info=True) logger.error(f"Failed to index Confluence pages: {e!s}", exc_info=True)
return 0, f"Failed to index Confluence pages: {e!s}" return 0, 0, f"Failed to index Confluence pages: {e!s}"

View file

@ -1,12 +1,10 @@
""" """
Google Calendar connector indexer. Google Calendar connector indexer.
Implements 2-phase document status updates for real-time UI feedback: Uses the shared IndexingPipelineService for document deduplication,
- Phase 1: Create all documents with 'pending' status (visible in UI immediately) summarization, chunking, and embedding.
- Phase 2: Process each document: pending processing ready/failed
""" """
import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -15,29 +13,22 @@ from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.google_calendar_connector import GoogleCalendarConnector from app.connectors.google_calendar_connector import GoogleCalendarConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.services.llm_service import get_user_long_context_llm from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from app.utils.google_credentials import ( from app.utils.google_credentials import (
COMPOSIO_GOOGLE_CONNECTOR_TYPES, COMPOSIO_GOOGLE_CONNECTOR_TYPES,
build_composio_credentials, build_composio_credentials,
) )
from .base import ( from .base import (
check_document_by_unique_identifier,
check_duplicate_document_by_hash, check_duplicate_document_by_hash,
get_connector_by_id, get_connector_by_id,
get_current_timestamp,
logger, logger,
parse_date_flexible, parse_date_flexible,
safe_set_chunks,
update_connector_last_indexed, update_connector_last_indexed,
) )
@ -46,13 +37,60 @@ ACCEPTED_CALENDAR_CONNECTOR_TYPES = {
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
} }
# Type hint for heartbeat callback
HeartbeatCallbackType = Callable[[int], Awaitable[None]] HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds
HEARTBEAT_INTERVAL_SECONDS = 30 HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
event: dict,
event_markdown: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Google Calendar API event dict to a ConnectorDocument."""
event_id = event.get("id", "")
event_summary = event.get("summary", "No Title")
calendar_id = event.get("calendarId", "")
start = event.get("start", {})
end = event.get("end", {})
start_time = start.get("dateTime") or start.get("date", "")
end_time = end.get("dateTime") or end.get("date", "")
location = event.get("location", "")
metadata = {
"event_id": event_id,
"event_summary": event_summary,
"calendar_id": calendar_id,
"start_time": start_time,
"end_time": end_time,
"location": location,
"connector_id": connector_id,
"document_type": "Google Calendar Event",
"connector_type": "Google Calendar",
}
fallback_summary = (
f"Google Calendar Event: {event_summary}\n\n{event_markdown}"
)
return ConnectorDocument(
title=event_summary,
source_markdown=event_markdown,
unique_id=event_id,
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_google_calendar_events( async def index_google_calendar_events(
session: AsyncSession, session: AsyncSession,
connector_id: int, connector_id: int,
@ -82,7 +120,6 @@ async def index_google_calendar_events(
""" """
task_logger = TaskLoggingService(session, search_space_id) task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start( log_entry = await task_logger.log_task_start(
task_name="google_calendar_events_indexing", task_name="google_calendar_events_indexing",
source="connector_indexing_task", source="connector_indexing_task",
@ -96,7 +133,7 @@ async def index_google_calendar_events(
) )
try: try:
# Accept both native and Composio Calendar connectors # ── Connector lookup ──────────────────────────────────────────
connector = None connector = None
for ct in ACCEPTED_CALENDAR_CONNECTOR_TYPES: for ct in ACCEPTED_CALENDAR_CONNECTOR_TYPES:
connector = await get_connector_by_id(session, connector_id, ct) connector = await get_connector_by_id(session, connector_id, ct)
@ -112,7 +149,7 @@ async def index_google_calendar_events(
) )
return 0, 0, f"Connector with ID {connector_id} not found" return 0, 0, f"Connector with ID {connector_id} not found"
# Build credentials based on connector type # ── Credential building ───────────────────────────────────────
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id") connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id: if not connected_account_id:
@ -184,6 +221,7 @@ async def index_google_calendar_events(
) )
return 0, 0, "Google Calendar credentials not found in connector config" return 0, 0, "Google Calendar credentials not found in connector config"
# ── Calendar client init ──────────────────────────────────────
await task_logger.log_task_progress( await task_logger.log_task_progress(
log_entry, log_entry,
f"Initializing Google Calendar client for connector {connector_id}", f"Initializing Google Calendar client for connector {connector_id}",
@ -203,36 +241,26 @@ async def index_google_calendar_events(
if end_date == "undefined" or end_date == "": if end_date == "undefined" or end_date == "":
end_date = None end_date = None
# Calculate date range # ── Date range calculation ────────────────────────────────────
# For calendar connectors, allow future dates to index upcoming events
if start_date is None or end_date is None: if start_date is None or end_date is None:
# Fall back to calculating dates based on last_indexed_at
# Default to today (users can manually select future dates if needed)
calculated_end_date = datetime.now() calculated_end_date = datetime.now()
# Use last_indexed_at as start date if available, otherwise use 30 days ago
if connector.last_indexed_at: if connector.last_indexed_at:
# Convert dates to be comparable (both timezone-naive)
last_indexed_naive = ( last_indexed_naive = (
connector.last_indexed_at.replace(tzinfo=None) connector.last_indexed_at.replace(tzinfo=None)
if connector.last_indexed_at.tzinfo if connector.last_indexed_at.tzinfo
else connector.last_indexed_at else connector.last_indexed_at
) )
# Allow future dates - use last_indexed_at as start date
calculated_start_date = last_indexed_naive calculated_start_date = last_indexed_naive
logger.info( logger.info(
f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date" f"Using last_indexed_at ({calculated_start_date.strftime('%Y-%m-%d')}) as start date"
) )
else: else:
calculated_start_date = datetime.now() - timedelta( calculated_start_date = datetime.now() - timedelta(days=365)
days=365
) # Use 365 days as default for calendar events (matches frontend)
logger.info( logger.info(
f"No last_indexed_at found, using {calculated_start_date.strftime('%Y-%m-%d')} (365 days ago) as start date" f"No last_indexed_at found, using {calculated_start_date.strftime('%Y-%m-%d')} (365 days ago) as start date"
) )
# Use calculated dates if not provided
start_date_str = ( start_date_str = (
start_date if start_date else calculated_start_date.strftime("%Y-%m-%d") start_date if start_date else calculated_start_date.strftime("%Y-%m-%d")
) )
@ -240,19 +268,14 @@ async def index_google_calendar_events(
end_date if end_date else calculated_end_date.strftime("%Y-%m-%d") end_date if end_date else calculated_end_date.strftime("%Y-%m-%d")
) )
else: else:
# Use provided dates (including future dates)
start_date_str = start_date start_date_str = start_date
end_date_str = end_date end_date_str = end_date
# FIX: Ensure end_date is at least 1 day after start_date to avoid
# "start_date must be strictly before end_date" errors when dates are the same
# (e.g., when last_indexed_at is today)
if start_date_str == end_date_str: if start_date_str == end_date_str:
logger.info( logger.info(
f"Start date ({start_date_str}) equals end date ({end_date_str}), " f"Start date ({start_date_str}) equals end date ({end_date_str}), "
"adjusting end date to next day to ensure valid date range" "adjusting end date to next day to ensure valid date range"
) )
# Parse end_date and add 1 day
try: try:
end_dt = parse_date_flexible(end_date_str) end_dt = parse_date_flexible(end_date_str)
except ValueError: except ValueError:
@ -264,6 +287,7 @@ async def index_google_calendar_events(
end_date_str = end_dt.strftime("%Y-%m-%d") end_date_str = end_dt.strftime("%Y-%m-%d")
logger.info(f"Adjusted end date to {end_date_str}") logger.info(f"Adjusted end date to {end_date_str}")
# ── Fetch events ──────────────────────────────────────────────
await task_logger.log_task_progress( await task_logger.log_task_progress(
log_entry, log_entry,
f"Fetching Google Calendar events from {start_date_str} to {end_date_str}", f"Fetching Google Calendar events from {start_date_str} to {end_date_str}",
@ -274,27 +298,19 @@ async def index_google_calendar_events(
}, },
) )
# Get events within date range from primary calendar
try: try:
events, error = await calendar_client.get_all_primary_calendar_events( events, error = await calendar_client.get_all_primary_calendar_events(
start_date=start_date_str, end_date=end_date_str start_date=start_date_str, end_date=end_date_str
) )
if error: if error:
# Don't treat "No events found" as an error that should stop indexing
if "No events found" in error: if "No events found" in error:
logger.info(f"No Google Calendar events found: {error}") logger.info(f"No Google Calendar events found: {error}")
logger.info(
"No events found is not a critical error, continuing with update"
)
if update_last_indexed: if update_last_indexed:
await update_connector_last_indexed( await update_connector_last_indexed(
session, connector, update_last_indexed session, connector, update_last_indexed
) )
await session.commit() await session.commit()
logger.info(
f"Updated last_indexed_at to {connector.last_indexed_at} despite no events found"
)
await task_logger.log_task_success( await task_logger.log_task_success(
log_entry, log_entry,
@ -304,7 +320,6 @@ async def index_google_calendar_events(
return 0, 0, None return 0, 0, None
else: else:
logger.error(f"Failed to get Google Calendar events: {error}") logger.error(f"Failed to get Google Calendar events: {error}")
# Check if this is an authentication error that requires re-authentication
error_message = error error_message = error
error_type = "APIError" error_type = "APIError"
if ( if (
@ -329,28 +344,15 @@ async def index_google_calendar_events(
logger.error(f"Error fetching Google Calendar events: {e!s}", exc_info=True) logger.error(f"Error fetching Google Calendar events: {e!s}", exc_info=True)
return 0, 0, f"Error fetching Google Calendar events: {e!s}" return 0, 0, f"Error fetching Google Calendar events: {e!s}"
documents_indexed = 0 # ── Build ConnectorDocuments ──────────────────────────────────
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0 documents_skipped = 0
documents_failed = 0 # Track events that failed processing duplicate_content_count = 0
duplicate_content_count = (
0 # Track events skipped due to duplicate content_hash
)
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
# =======================================================================
# PHASE 1: Analyze all events, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
events_to_process = [] # List of dicts with document and event data
new_documents_created = False
for event in events: for event in events:
try: try:
event_id = event.get("id") event_id = event.get("id")
event_summary = event.get("summary", "No Title") event_summary = event.get("summary", "No Title")
calendar_id = event.get("calendarId", "")
if not event_id: if not event_id:
logger.warning(f"Skipping event with missing ID: {event_summary}") logger.warning(f"Skipping event with missing ID: {event_summary}")
@ -363,246 +365,55 @@ async def index_google_calendar_events(
documents_skipped += 1 documents_skipped += 1
continue continue
start = event.get("start", {}) doc = _build_connector_doc(
end = event.get("end", {}) event,
start_time = start.get("dateTime") or start.get("date", "") event_markdown,
end_time = end.get("dateTime") or end.get("date", "") connector_id=connector_id,
location = event.get("location", "") search_space_id=search_space_id,
description = event.get("description", "") user_id=user_id,
enable_summary=connector.enable_summary,
# Generate unique identifier hash for this Google Calendar event
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.GOOGLE_CALENDAR_CONNECTOR, event_id, search_space_id
) )
# Generate content hash
content_hash = generate_content_hash(event_markdown, search_space_id)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
# Fallback: legacy Composio hash
if not existing_document:
legacy_hash = generate_unique_identifier_hash(
DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
event_id,
search_space_id,
)
existing_document = await check_document_by_unique_identifier(
session, legacy_hash
)
if existing_document:
existing_document.unique_identifier_hash = (
unique_identifier_hash
)
if (
existing_document.document_type
== DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
):
existing_document.document_type = (
DocumentType.GOOGLE_CALENDAR_CONNECTOR
)
logger.info(
f"Migrated legacy Composio Calendar document: {event_id}"
)
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
events_to_process.append(
{
"document": existing_document,
"is_new": False,
"event_markdown": event_markdown,
"content_hash": content_hash,
"event_id": event_id,
"event_summary": event_summary,
"calendar_id": calendar_id,
"start_time": start_time,
"end_time": end_time,
"location": location,
"description": description,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush: with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash( duplicate = await check_duplicate_document_by_hash(
session, content_hash session, compute_content_hash(doc)
) )
if duplicate:
if duplicate_by_content:
# A document with the same content already exists (likely from Composio connector)
logger.info( logger.info(
f"Event {event_summary} already indexed by another connector " f"Event {doc.title} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, " f"(existing document ID: {duplicate.id}, "
f"type: {duplicate_by_content.document_type}). Skipping to avoid duplicate content." f"type: {duplicate.document_type}). Skipping."
) )
duplicate_content_count += 1 duplicate_content_count += 1
documents_skipped += 1 documents_skipped += 1
continue continue
# Create new document with PENDING status (visible in UI immediately) connector_docs.append(doc)
document = Document(
search_space_id=search_space_id,
title=event_summary,
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
document_metadata={
"event_id": event_id,
"event_summary": event_summary,
"calendar_id": calendar_id,
"start_time": start_time,
"end_time": end_time,
"location": location,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
events_to_process.append(
{
"document": document,
"is_new": True,
"event_markdown": event_markdown,
"content_hash": content_hash,
"event_id": event_id,
"event_summary": event_summary,
"calendar_id": calendar_id,
"start_time": start_time,
"end_time": end_time,
"location": location,
"description": description,
}
)
except Exception as e: except Exception as e:
logger.error(f"Error in Phase 1 for event: {e!s}", exc_info=True) logger.error(f"Error building ConnectorDocument for event: {e!s}", exc_info=True)
documents_failed += 1 documents_skipped += 1
continue continue
# Commit all pending documents - they all appear in UI now # ── Pipeline: migrate legacy docs + parallel index ─────────────
if new_documents_created: pipeline = IndexingPipelineService(session)
logger.info(
f"Phase 1: Committing {len([e for e in events_to_process if e['is_new']])} pending documents"
)
await session.commit()
# ======================================================================= await pipeline.migrate_legacy_docs(connector_docs)
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(events_to_process)} documents")
for item in events_to_process: async def _get_llm(s):
# Send heartbeat periodically return await get_user_long_context_llm(s, user_id, search_space_id)
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"] _, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
try: connector_docs,
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only _get_llm,
document.status = DocumentStatus.processing() max_concurrency=3,
await session.commit() on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
) )
if user_llm and connector.enable_summary: # ── Finalize ──────────────────────────────────────────────────
document_metadata_for_summary = {
"event_id": item["event_id"],
"event_summary": item["event_summary"],
"calendar_id": item["calendar_id"],
"start_time": item["start_time"],
"end_time": item["end_time"],
"location": item["location"] or "No location",
"document_type": "Google Calendar Event",
"connector_type": "Google Calendar",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["event_markdown"], user_llm, document_metadata_for_summary
)
else:
summary_content = f"Google Calendar Event: {item['event_summary']}\n\n{item['event_markdown']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["event_markdown"])
# Update document to READY with actual content
document.title = item["event_summary"]
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"event_id": item["event_id"],
"event_summary": item["event_summary"],
"calendar_id": item["calendar_id"],
"start_time": item["start_time"],
"end_time": item["end_time"],
"location": item["location"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Google Calendar events processed so far"
)
await session.commit()
except Exception as e:
logger.error(f"Error processing Calendar event: {e!s}", exc_info=True)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
documents_failed += 1
continue
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
await update_connector_last_indexed(session, connector, update_last_indexed) await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit for any remaining documents not yet committed in batches
logger.info( logger.info(
f"Final commit: Total {documents_indexed} Google Calendar events processed" f"Final commit: Total {documents_indexed} Google Calendar events processed"
) )
@ -612,22 +423,18 @@ async def index_google_calendar_events(
"Successfully committed all Google Calendar document changes to database" "Successfully committed all Google Calendar document changes to database"
) )
except Exception as e: except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if ( if (
"duplicate key value violates unique constraint" in str(e).lower() "duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower() or "uniqueviolationerror" in str(e).lower()
): ):
logger.warning( logger.warning(
f"Duplicate content_hash detected during final commit. " f"Duplicate content_hash detected during final commit. "
f"This may occur if the same event was indexed by multiple connectors. "
f"Rolling back and continuing. Error: {e!s}" f"Rolling back and continuing. Error: {e!s}"
) )
await session.rollback() await session.rollback()
# Don't fail the entire task - some documents may have been successfully indexed
else: else:
raise raise
# Build warning message if there were issues
warning_parts = [] warning_parts = []
if duplicate_content_count > 0: if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate") warning_parts.append(f"{duplicate_content_count} duplicate")

View file

@ -1,12 +1,10 @@
""" """
Google Gmail connector indexer. Google Gmail connector indexer.
Implements 2-phase document status updates for real-time UI feedback: Uses the shared IndexingPipelineService for document deduplication,
- Phase 1: Create all documents with 'pending' status (visible in UI immediately) summarization, chunking, and embedding.
- Phase 2: Process each document: pending processing ready/failed
""" """
import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from datetime import datetime from datetime import datetime
@ -15,21 +13,12 @@ from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.google_gmail_connector import GoogleGmailConnector from app.connectors.google_gmail_connector import GoogleGmailConnector
from app.db import ( from app.db import DocumentType, SearchSourceConnectorType
Document, from app.indexing_pipeline.connector_document import ConnectorDocument
DocumentStatus, from app.indexing_pipeline.document_hashing import compute_content_hash
DocumentType, from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
SearchSourceConnectorType,
)
from app.services.llm_service import get_user_long_context_llm from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from app.utils.google_credentials import ( from app.utils.google_credentials import (
COMPOSIO_GOOGLE_CONNECTOR_TYPES, COMPOSIO_GOOGLE_CONNECTOR_TYPES,
build_composio_credentials, build_composio_credentials,
@ -37,12 +26,9 @@ from app.utils.google_credentials import (
from .base import ( from .base import (
calculate_date_range, calculate_date_range,
check_document_by_unique_identifier,
check_duplicate_document_by_hash, check_duplicate_document_by_hash,
get_connector_by_id, get_connector_by_id,
get_current_timestamp,
logger, logger,
safe_set_chunks,
update_connector_last_indexed, update_connector_last_indexed,
) )
@ -51,13 +37,70 @@ ACCEPTED_GMAIL_CONNECTOR_TYPES = {
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
} }
# Type hint for heartbeat callback
HeartbeatCallbackType = Callable[[int], Awaitable[None]] HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds
HEARTBEAT_INTERVAL_SECONDS = 30 HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
message: dict,
markdown_content: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Gmail API message dict to a ConnectorDocument."""
message_id = message.get("id", "")
thread_id = message.get("threadId", "")
payload = message.get("payload", {})
headers = payload.get("headers", [])
subject = "No Subject"
sender = "Unknown Sender"
date_str = "Unknown Date"
for header in headers:
name = header.get("name", "").lower()
value = header.get("value", "")
if name == "subject":
subject = value
elif name == "from":
sender = value
elif name == "date":
date_str = value
metadata = {
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date": date_str,
"connector_id": connector_id,
"document_type": "Gmail Message",
"connector_type": "Google Gmail",
}
fallback_summary = (
f"Google Gmail Message: {subject}\n\n"
f"From: {sender}\nDate: {date_str}\n\n"
f"{markdown_content}"
)
return ConnectorDocument(
title=subject,
source_markdown=markdown_content,
unique_id=message_id,
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_google_gmail_messages( async def index_google_gmail_messages(
session: AsyncSession, session: AsyncSession,
connector_id: int, connector_id: int,
@ -80,7 +123,7 @@ async def index_google_gmail_messages(
start_date: Start date for filtering messages (YYYY-MM-DD format) start_date: Start date for filtering messages (YYYY-MM-DD format)
end_date: End date for filtering messages (YYYY-MM-DD format) end_date: End date for filtering messages (YYYY-MM-DD format)
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True) update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
max_messages: Maximum number of messages to fetch (default: 100) max_messages: Maximum number of messages to fetch (default: 1000)
on_heartbeat_callback: Optional callback to update notification during long-running indexing. on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Returns: Returns:
@ -88,7 +131,6 @@ async def index_google_gmail_messages(
""" """
task_logger = TaskLoggingService(session, search_space_id) task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start( log_entry = await task_logger.log_task_start(
task_name="google_gmail_messages_indexing", task_name="google_gmail_messages_indexing",
source="connector_indexing_task", source="connector_indexing_task",
@ -103,7 +145,7 @@ async def index_google_gmail_messages(
) )
try: try:
# Accept both native and Composio Gmail connectors # ── Connector lookup ──────────────────────────────────────────
connector = None connector = None
for ct in ACCEPTED_GMAIL_CONNECTOR_TYPES: for ct in ACCEPTED_GMAIL_CONNECTOR_TYPES:
connector = await get_connector_by_id(session, connector_id, ct) connector = await get_connector_by_id(session, connector_id, ct)
@ -117,7 +159,7 @@ async def index_google_gmail_messages(
) )
return 0, 0, error_msg return 0, 0, error_msg
# Build credentials based on connector type # ── Credential building ───────────────────────────────────────
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
connected_account_id = connector.config.get("composio_connected_account_id") connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id: if not connected_account_id:
@ -189,6 +231,7 @@ async def index_google_gmail_messages(
) )
return 0, 0, "Google gmail credentials not found in connector config" return 0, 0, "Google gmail credentials not found in connector config"
# ── Gmail client init ─────────────────────────────────────────
await task_logger.log_task_progress( await task_logger.log_task_progress(
log_entry, log_entry,
f"Initializing Google gmail client for connector {connector_id}", f"Initializing Google gmail client for connector {connector_id}",
@ -199,14 +242,11 @@ async def index_google_gmail_messages(
credentials, session, user_id, connector_id credentials, session, user_id, connector_id
) )
# Calculate date range using last_indexed_at if dates not provided
# This ensures Gmail uses the same date logic as other connectors
# (uses last_indexed_at → now, or 365 days back for first-time indexing)
calculated_start_date, calculated_end_date = calculate_date_range( calculated_start_date, calculated_end_date = calculate_date_range(
connector, start_date, end_date, default_days_back=365 connector, start_date, end_date, default_days_back=365
) )
# Fetch recent Google gmail messages # ── Fetch messages ────────────────────────────────────────────
logger.info( logger.info(
f"Fetching emails for connector {connector_id} " f"Fetching emails for connector {connector_id} "
f"from {calculated_start_date} to {calculated_end_date}" f"from {calculated_start_date} to {calculated_end_date}"
@ -218,7 +258,6 @@ async def index_google_gmail_messages(
) )
if error: if error:
# Check if this is an authentication error that requires re-authentication
error_message = error error_message = error
error_type = "APIError" error_type = "APIError"
if ( if (
@ -243,286 +282,74 @@ async def index_google_gmail_messages(
logger.info(f"Found {len(messages)} Google gmail messages to index") logger.info(f"Found {len(messages)} Google gmail messages to index")
documents_indexed = 0 # ── Build ConnectorDocuments ──────────────────────────────────
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0 documents_skipped = 0
documents_failed = 0 # Track messages that failed processing duplicate_content_count = 0
duplicate_content_count = (
0 # Track messages skipped due to duplicate content_hash
)
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
# =======================================================================
# PHASE 1: Analyze all messages, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
messages_to_process = [] # List of dicts with document and message data
new_documents_created = False
for message in messages: for message in messages:
try: try:
# Extract message information
message_id = message.get("id", "") message_id = message.get("id", "")
thread_id = message.get("threadId", "")
# Extract headers for subject and sender
payload = message.get("payload", {})
headers = payload.get("headers", [])
subject = "No Subject"
sender = "Unknown Sender"
date_str = "Unknown Date"
for header in headers:
name = header.get("name", "").lower()
value = header.get("value", "")
if name == "subject":
subject = value
elif name == "from":
sender = value
elif name == "date":
date_str = value
if not message_id: if not message_id:
logger.warning(f"Skipping message with missing ID: {subject}") logger.warning("Skipping message with missing ID")
documents_skipped += 1 documents_skipped += 1
continue continue
# Format message to markdown
markdown_content = gmail_connector.format_message_to_markdown(message) markdown_content = gmail_connector.format_message_to_markdown(message)
if not markdown_content.strip(): if not markdown_content.strip():
logger.warning(f"Skipping message with no content: {subject}") logger.warning(f"Skipping message with no content: {message_id}")
documents_skipped += 1 documents_skipped += 1
continue continue
# Generate unique identifier hash for this Gmail message doc = _build_connector_doc(
unique_identifier_hash = generate_unique_identifier_hash( message,
DocumentType.GOOGLE_GMAIL_CONNECTOR, message_id, search_space_id markdown_content,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector.enable_summary,
) )
# Generate content hash
content_hash = generate_content_hash(markdown_content, search_space_id)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
# Fallback: legacy Composio hash
if not existing_document:
legacy_hash = generate_unique_identifier_hash(
DocumentType.COMPOSIO_GMAIL_CONNECTOR,
message_id,
search_space_id,
)
existing_document = await check_document_by_unique_identifier(
session, legacy_hash
)
if existing_document:
existing_document.unique_identifier_hash = (
unique_identifier_hash
)
if (
existing_document.document_type
== DocumentType.COMPOSIO_GMAIL_CONNECTOR
):
existing_document.document_type = (
DocumentType.GOOGLE_GMAIL_CONNECTOR
)
logger.info(
f"Migrated legacy Composio Gmail document: {message_id}"
)
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
messages_to_process.append(
{
"document": existing_document,
"is_new": False,
"markdown_content": markdown_content,
"content_hash": content_hash,
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date_str": date_str,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush: with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash( duplicate = await check_duplicate_document_by_hash(
session, content_hash session, compute_content_hash(doc)
) )
if duplicate:
if duplicate_by_content:
logger.info( logger.info(
f"Gmail message {subject} already indexed by another connector " f"Gmail message {doc.title} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, " f"(existing document ID: {duplicate.id}, "
f"type: {duplicate_by_content.document_type}). Skipping." f"type: {duplicate.document_type}). Skipping."
) )
duplicate_content_count += 1 duplicate_content_count += 1
documents_skipped += 1 documents_skipped += 1
continue continue
# Create new document with PENDING status (visible in UI immediately) connector_docs.append(doc)
document = Document(
search_space_id=search_space_id,
title=subject,
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
document_metadata={
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date": date_str,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
messages_to_process.append(
{
"document": document,
"is_new": True,
"markdown_content": markdown_content,
"content_hash": content_hash,
"message_id": message_id,
"thread_id": thread_id,
"subject": subject,
"sender": sender,
"date_str": date_str,
}
)
except Exception as e: except Exception as e:
logger.error(f"Error in Phase 1 for message: {e!s}", exc_info=True) logger.error(f"Error building ConnectorDocument for message: {e!s}", exc_info=True)
documents_failed += 1 documents_skipped += 1
continue continue
# Commit all pending documents - they all appear in UI now # ── Pipeline: migrate legacy docs + parallel index ─────────────
if new_documents_created: pipeline = IndexingPipelineService(session)
logger.info(
f"Phase 1: Committing {len([m for m in messages_to_process if m['is_new']])} pending documents"
)
await session.commit()
# ======================================================================= await pipeline.migrate_legacy_docs(connector_docs)
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(messages_to_process)} documents")
for item in messages_to_process: async def _get_llm(s):
# Send heartbeat periodically return await get_user_long_context_llm(s, user_id, search_space_id)
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"] _, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
try: connector_docs,
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only _get_llm,
document.status = DocumentStatus.processing() max_concurrency=3,
await session.commit() on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
) )
if user_llm and connector.enable_summary: # ── Finalize ──────────────────────────────────────────────────
document_metadata_for_summary = {
"message_id": item["message_id"],
"thread_id": item["thread_id"],
"subject": item["subject"],
"sender": item["sender"],
"date": item["date_str"],
"document_type": "Gmail Message",
"connector_type": "Google Gmail",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["markdown_content"],
user_llm,
document_metadata_for_summary,
)
else:
summary_content = f"Google Gmail Message: {item['subject']}\n\nFrom: {item['sender']}\nDate: {item['date_str']}\n\n{item['markdown_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])
# Update document to READY with actual content
document.title = item["subject"]
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"message_id": item["message_id"],
"thread_id": item["thread_id"],
"subject": item["subject"],
"sender": item["sender"],
"date": item["date_str"],
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Gmail messages processed so far"
)
await session.commit()
except Exception as e:
logger.error(f"Error processing Gmail message: {e!s}", exc_info=True)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
documents_failed += 1
continue
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
await update_connector_last_indexed(session, connector, update_last_indexed) await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit for any remaining documents not yet committed in batches
logger.info(f"Final commit: Total {documents_indexed} Gmail messages processed") logger.info(f"Final commit: Total {documents_indexed} Gmail messages processed")
try: try:
await session.commit() await session.commit()
@ -530,22 +357,18 @@ async def index_google_gmail_messages(
"Successfully committed all Google Gmail document changes to database" "Successfully committed all Google Gmail document changes to database"
) )
except Exception as e: except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if ( if (
"duplicate key value violates unique constraint" in str(e).lower() "duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower() or "uniqueviolationerror" in str(e).lower()
): ):
logger.warning( logger.warning(
f"Duplicate content_hash detected during final commit. " f"Duplicate content_hash detected during final commit. "
f"This may occur if the same message was indexed by multiple connectors. "
f"Rolling back and continuing. Error: {e!s}" f"Rolling back and continuing. Error: {e!s}"
) )
await session.rollback() await session.rollback()
# Don't fail the entire task - some documents may have been successfully indexed
else: else:
raise raise
# Build warning message if there were issues
warning_parts = [] warning_parts = []
if duplicate_content_count > 0: if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate") warning_parts.append(f"{duplicate_content_count} duplicate")
@ -555,7 +378,6 @@ async def index_google_gmail_messages(
total_processed = documents_indexed total_processed = documents_indexed
# Log success
await task_logger.log_task_success( await task_logger.log_task_success(
log_entry, log_entry,
f"Successfully completed Google Gmail indexing for connector {connector_id}", f"Successfully completed Google Gmail indexing for connector {connector_id}",

View file

@ -1,49 +1,80 @@
""" """Jira connector indexer using the unified parallel indexing pipeline."""
Jira connector indexer.
Provides real-time document status updates during indexing using a two-phase approach:
- Phase 1: Create all documents with PENDING status (visible in UI immediately)
- Phase 2: Process each document one by one (PENDING PROCESSING READY/FAILED)
"""
import contextlib import contextlib
import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.jira_history import JiraHistoryConnector from app.connectors.jira_history import JiraHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.services.llm_service import get_user_long_context_llm from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from .base import ( from .base import (
calculate_date_range, calculate_date_range,
check_document_by_unique_identifier,
check_duplicate_document_by_hash, check_duplicate_document_by_hash,
get_connector_by_id, get_connector_by_id,
get_current_timestamp,
logger, logger,
safe_set_chunks,
update_connector_last_indexed, update_connector_last_indexed,
) )
# Type hint for heartbeat callback
HeartbeatCallbackType = Callable[[int], Awaitable[None]] HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds - update notification every 30 seconds
HEARTBEAT_INTERVAL_SECONDS = 30 HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
issue: dict,
formatted_issue: dict,
issue_content: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Jira issue dict to a ConnectorDocument."""
issue_id = issue.get("key", "")
issue_identifier = issue.get("key", "")
issue_title = issue.get("id", "")
state = formatted_issue.get("status", "Unknown")
priority = formatted_issue.get("priority", "Unknown")
comment_count = len(formatted_issue.get("comments", []))
metadata = {
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"priority": priority,
"comment_count": comment_count,
"connector_id": connector_id,
"document_type": "Jira Issue",
"connector_type": "Jira",
}
fallback_summary = (
f"Jira Issue {issue_identifier}: {issue_title}\n\n"
f"Status: {state}\n\n{issue_content}"
)
return ConnectorDocument(
title=f"{issue_identifier}: {issue_title}",
source_markdown=issue_content,
unique_id=issue_id,
document_type=DocumentType.JIRA_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_jira_issues( async def index_jira_issues(
session: AsyncSession, session: AsyncSession,
connector_id: int, connector_id: int,
@ -53,26 +84,9 @@ async def index_jira_issues(
end_date: str | None = None, end_date: str | None = None,
update_last_indexed: bool = True, update_last_indexed: bool = True,
on_heartbeat_callback: HeartbeatCallbackType | None = None, on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str | None]: ) -> tuple[int, int, str | None]:
""" """Index Jira issues and comments."""
Index Jira issues and comments.
Args:
session: Database session
connector_id: ID of the Jira connector
search_space_id: ID of the search space to store documents in
user_id: User ID
start_date: Start date for indexing (YYYY-MM-DD format)
end_date: End date for indexing (YYYY-MM-DD format)
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Returns:
Tuple containing (number of documents indexed, error message or None)
"""
task_logger = TaskLoggingService(session, search_space_id) task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start( log_entry = await task_logger.log_task_start(
task_name="jira_issues_indexing", task_name="jira_issues_indexing",
source="connector_indexing_task", source="connector_indexing_task",
@ -86,7 +100,6 @@ async def index_jira_issues(
) )
try: try:
# Get the connector from the database
connector = await get_connector_by_id( connector = await get_connector_by_id(
session, connector_id, SearchSourceConnectorType.JIRA_CONNECTOR session, connector_id, SearchSourceConnectorType.JIRA_CONNECTOR
) )
@ -98,24 +111,15 @@ async def index_jira_issues(
"Connector not found", "Connector not found",
{"error_type": "ConnectorNotFound"}, {"error_type": "ConnectorNotFound"},
) )
return 0, f"Connector with ID {connector_id} not found" return 0, 0, f"Connector with ID {connector_id} not found"
# Initialize Jira client with internal refresh capability
# Token refresh will happen automatically when needed
await task_logger.log_task_progress( await task_logger.log_task_progress(
log_entry, log_entry,
f"Initializing Jira client for connector {connector_id}", f"Initializing Jira client for connector {connector_id}",
{"stage": "client_initialization"}, {"stage": "client_initialization"},
) )
logger.info(f"Initializing Jira client for connector {connector_id}")
# Create connector with session and connector_id for internal refresh
# Token refresh will happen automatically when needed
jira_client = JiraHistoryConnector(session=session, connector_id=connector_id) jira_client = JiraHistoryConnector(session=session, connector_id=connector_id)
# Calculate date range
# Handle "undefined" strings from frontend
if start_date == "undefined" or start_date == "": if start_date == "undefined" or start_date == "":
start_date = None start_date = None
if end_date == "undefined" or end_date == "": if end_date == "undefined" or end_date == "":
@ -135,19 +139,14 @@ async def index_jira_issues(
}, },
) )
# Get issues within date range
try: try:
issues, error = await jira_client.get_issues_by_date_range( issues, error = await jira_client.get_issues_by_date_range(
start_date=start_date_str, end_date=end_date_str, include_comments=True start_date=start_date_str, end_date=end_date_str, include_comments=True
) )
if error: if error:
# Don't treat "No issues found" as an error that should stop indexing
if "No issues found" in error: if "No issues found" in error:
logger.info(f"No Jira issues found: {error}") logger.info(f"No Jira issues found: {error}")
logger.info(
"No issues found is not a critical error, continuing with update"
)
if update_last_indexed: if update_last_indexed:
await update_connector_last_indexed( await update_connector_last_indexed(
session, connector, update_last_indexed session, connector, update_last_indexed
@ -162,7 +161,8 @@ async def index_jira_issues(
f"No Jira issues found in date range {start_date_str} to {end_date_str}", f"No Jira issues found in date range {start_date_str} to {end_date_str}",
{"issues_found": 0}, {"issues_found": 0},
) )
return 0, None await jira_client.close()
return 0, 0, None
else: else:
logger.error(f"Failed to get Jira issues: {error}") logger.error(f"Failed to get Jira issues: {error}")
await task_logger.log_task_failure( await task_logger.log_task_failure(
@ -171,29 +171,30 @@ async def index_jira_issues(
"API Error", "API Error",
{"error_type": "APIError"}, {"error_type": "APIError"},
) )
return 0, f"Failed to get Jira issues: {error}" await jira_client.close()
return 0, 0, f"Failed to get Jira issues: {error}"
logger.info(f"Retrieved {len(issues)} issues from Jira API") logger.info(f"Retrieved {len(issues)} issues from Jira API")
except Exception as e: except Exception as e:
logger.error(f"Error fetching Jira issues: {e!s}", exc_info=True) logger.error(f"Error fetching Jira issues: {e!s}", exc_info=True)
return 0, f"Error fetching Jira issues: {e!s}" await jira_client.close()
return 0, 0, f"Error fetching Jira issues: {e!s}"
# ======================================================================= if not issues:
# PHASE 1: Analyze all issues, create pending documents logger.info("No Jira issues found for the specified date range")
# This makes ALL documents visible in the UI immediately with pending status if update_last_indexed:
# ======================================================================= await update_connector_last_indexed(
documents_indexed = 0 session, connector, update_last_indexed
)
await session.commit()
await jira_client.close()
return 0, 0, None
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0 documents_skipped = 0
documents_failed = 0
duplicate_content_count = 0 duplicate_content_count = 0
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
issues_to_process = [] # List of dicts with document and issue data
new_documents_created = False
for issue in issues: for issue in issues:
try: try:
issue_id = issue.get("key") issue_id = issue.get("key")
@ -207,10 +208,7 @@ async def index_jira_issues(
documents_skipped += 1 documents_skipped += 1
continue continue
# Format the issue for better readability
formatted_issue = jira_client.format_issue(issue) formatted_issue = jira_client.format_issue(issue)
# Convert to markdown
issue_content = jira_client.format_issue_to_markdown(formatted_issue) issue_content = jira_client.format_issue_to_markdown(formatted_issue)
if not issue_content: if not issue_content:
@ -220,53 +218,19 @@ async def index_jira_issues(
documents_skipped += 1 documents_skipped += 1
continue continue
# Generate unique identifier hash for this Jira issue doc = _build_connector_doc(
unique_identifier_hash = generate_unique_identifier_hash( issue,
DocumentType.JIRA_CONNECTOR, issue_id, search_space_id formatted_issue,
issue_content,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector.enable_summary,
) )
# Generate content hash
content_hash = generate_content_hash(issue_content, search_space_id)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
comment_count = len(formatted_issue.get("comments", []))
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
issues_to_process.append(
{
"document": existing_document,
"is_new": False,
"issue_content": issue_content,
"content_hash": content_hash,
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"formatted_issue": formatted_issue,
"comment_count": comment_count,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush: with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash( duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash session, compute_content_hash(doc)
) )
if duplicate_by_content: if duplicate_by_content:
@ -279,160 +243,37 @@ async def index_jira_issues(
documents_skipped += 1 documents_skipped += 1
continue continue
# Create new document with PENDING status (visible in UI immediately) connector_docs.append(doc)
document = Document(
search_space_id=search_space_id,
title=f"{issue_identifier}: {issue_title}",
document_type=DocumentType.JIRA_CONNECTOR,
document_metadata={
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": formatted_issue.get("status", "Unknown"),
"comment_count": comment_count,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
issues_to_process.append(
{
"document": document,
"is_new": True,
"issue_content": issue_content,
"content_hash": content_hash,
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"formatted_issue": formatted_issue,
"comment_count": comment_count,
}
)
except Exception as e:
logger.error(f"Error in Phase 1 for issue: {e!s}", exc_info=True)
documents_failed += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([i for i in issues_to_process if i['is_new']])} pending documents"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(issues_to_process)} documents")
for item in issues_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm and connector.enable_summary:
document_metadata = {
"issue_key": item["issue_identifier"],
"issue_title": item["issue_title"],
"status": item["formatted_issue"].get("status", "Unknown"),
"priority": item["formatted_issue"].get("priority", "Unknown"),
"comment_count": item["comment_count"],
"document_type": "Jira Issue",
"connector_type": "Jira",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["issue_content"], user_llm, document_metadata
)
else:
summary_content = f"Jira Issue {item['issue_identifier']}: {item['issue_title']}\n\n{item['issue_content']}"
summary_embedding = embed_text(summary_content)
# Process chunks - using the full issue content with comments
chunks = await create_document_chunks(item["issue_content"])
# Update document to READY with actual content
document.title = f"{item['issue_identifier']}: {item['issue_title']}"
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"issue_id": item["issue_id"],
"issue_identifier": item["issue_identifier"],
"issue_title": item["issue_title"],
"state": item["formatted_issue"].get("status", "Unknown"),
"comment_count": item["comment_count"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Jira issues processed so far"
)
await session.commit()
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error processing issue {item.get('issue_identifier', 'Unknown')}: {e!s}", f"Error building ConnectorDocument for issue {issue_identifier}: {e!s}",
exc_info=True, exc_info=True,
) )
# Mark document as failed with reason (visible in UI) documents_skipped += 1
try: continue
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp() pipeline = IndexingPipelineService(session)
except Exception as status_error: await pipeline.migrate_legacy_docs(connector_docs)
logger.error(
f"Failed to update document status to failed: {status_error}" async def _get_llm(s: AsyncSession):
) return await get_user_long_context_llm(s, user_id, search_space_id)
documents_failed += 1
continue # Skip this issue and continue with others _, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
)
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
# This ensures the UI shows "Last indexed" instead of "Never indexed"
await update_connector_last_indexed(session, connector, update_last_indexed) await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit to ensure all documents are persisted (safety net)
logger.info(f"Final commit: Total {documents_indexed} Jira issues processed") logger.info(f"Final commit: Total {documents_indexed} Jira issues processed")
try: try:
await session.commit() await session.commit()
logger.info("Successfully committed all JIRA document changes to database") logger.info("Successfully committed all JIRA document changes to database")
except Exception as e: except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if ( if (
"duplicate key value violates unique constraint" in str(e).lower() "duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower() or "uniqueviolationerror" in str(e).lower()
@ -447,7 +288,6 @@ async def index_jira_issues(
else: else:
raise raise
# Build warning message if there were issues
warning_parts = [] warning_parts = []
if duplicate_content_count > 0: if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate") warning_parts.append(f"{duplicate_content_count} duplicate")
@ -455,7 +295,6 @@ async def index_jira_issues(
warning_parts.append(f"{documents_failed} failed") warning_parts.append(f"{documents_failed} failed")
warning_message = ", ".join(warning_parts) if warning_parts else None warning_message = ", ".join(warning_parts) if warning_parts else None
# Log success
await task_logger.log_task_success( await task_logger.log_task_success(
log_entry, log_entry,
f"Successfully completed JIRA indexing for connector {connector_id}", f"Successfully completed JIRA indexing for connector {connector_id}",
@ -466,17 +305,13 @@ async def index_jira_issues(
"duplicate_content_count": duplicate_content_count, "duplicate_content_count": duplicate_content_count,
}, },
) )
logger.info( logger.info(
f"JIRA indexing completed: {documents_indexed} ready, " f"JIRA indexing completed: {documents_indexed} ready, "
f"{documents_skipped} skipped, {documents_failed} failed " f"{documents_skipped} skipped, {documents_failed} failed "
f"({duplicate_content_count} duplicate content)" f"({duplicate_content_count} duplicate content)"
) )
# Clean up the connector
await jira_client.close() await jira_client.close()
return documents_indexed, documents_skipped, warning_message
return documents_indexed, warning_message
except SQLAlchemyError as db_error: except SQLAlchemyError as db_error:
await session.rollback() await session.rollback()
@ -487,11 +322,10 @@ async def index_jira_issues(
{"error_type": "SQLAlchemyError"}, {"error_type": "SQLAlchemyError"},
) )
logger.error(f"Database error: {db_error!s}", exc_info=True) logger.error(f"Database error: {db_error!s}", exc_info=True)
# Clean up the connector in case of error
if "jira_client" in locals(): if "jira_client" in locals():
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
await jira_client.close() await jira_client.close()
return 0, f"Database error: {db_error!s}" return 0, 0, f"Database error: {db_error!s}"
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
await task_logger.log_task_failure( await task_logger.log_task_failure(
@ -501,8 +335,7 @@ async def index_jira_issues(
{"error_type": type(e).__name__}, {"error_type": type(e).__name__},
) )
logger.error(f"Failed to index JIRA issues: {e!s}", exc_info=True) logger.error(f"Failed to index JIRA issues: {e!s}", exc_info=True)
# Clean up the connector in case of error
if "jira_client" in locals(): if "jira_client" in locals():
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
await jira_client.close() await jira_client.close()
return 0, f"Failed to index JIRA issues: {e!s}" return 0, 0, f"Failed to index JIRA issues: {e!s}"

View file

@ -1,48 +1,84 @@
""" """
Linear connector indexer. Linear connector indexer.
Implements 2-phase document status updates for real-time UI feedback: Uses the shared IndexingPipelineService for document deduplication,
- Phase 1: Create all documents with 'pending' status (visible in UI immediately) summarization, chunking, and embedding with bounded parallel indexing.
- Phase 2: Process each document: pending processing ready/failed
""" """
import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.linear_connector import LinearConnector from app.connectors.linear_connector import LinearConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.services.llm_service import get_user_long_context_llm from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from .base import ( from .base import (
calculate_date_range, calculate_date_range,
check_document_by_unique_identifier,
check_duplicate_document_by_hash, check_duplicate_document_by_hash,
get_connector_by_id, get_connector_by_id,
get_current_timestamp,
logger, logger,
safe_set_chunks,
update_connector_last_indexed, update_connector_last_indexed,
) )
# Type hint for heartbeat callback
HeartbeatCallbackType = Callable[[int], Awaitable[None]] HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds - update notification every 30 seconds
HEARTBEAT_INTERVAL_SECONDS = 30 HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
issue: dict,
formatted_issue: dict,
issue_content: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Linear issue dict to a ConnectorDocument."""
issue_id = issue.get("id", "")
issue_identifier = issue.get("identifier", "")
issue_title = issue.get("title", "")
state = formatted_issue.get("state", "Unknown")
priority = formatted_issue.get("priority", "Unknown")
comment_count = len(formatted_issue.get("comments", []))
metadata = {
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"priority": priority,
"comment_count": comment_count,
"connector_id": connector_id,
"document_type": "Linear Issue",
"connector_type": "Linear",
}
fallback_summary = (
f"Linear Issue {issue_identifier}: {issue_title}\n\n"
f"Status: {state}\n\n{issue_content}"
)
return ConnectorDocument(
title=f"{issue_identifier}: {issue_title}",
source_markdown=issue_content,
unique_id=issue_id,
document_type=DocumentType.LINEAR_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_linear_issues( async def index_linear_issues(
session: AsyncSession, session: AsyncSession,
connector_id: int, connector_id: int,
@ -52,26 +88,15 @@ async def index_linear_issues(
end_date: str | None = None, end_date: str | None = None,
update_last_indexed: bool = True, update_last_indexed: bool = True,
on_heartbeat_callback: HeartbeatCallbackType | None = None, on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str | None]: ) -> tuple[int, int, str | None]:
""" """
Index Linear issues and comments. Index Linear issues and comments.
Args:
session: Database session
connector_id: ID of the Linear connector
search_space_id: ID of the search space to store documents in
user_id: ID of the user
start_date: Start date for indexing (YYYY-MM-DD format)
end_date: End date for indexing (YYYY-MM-DD format)
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Returns: Returns:
Tuple containing (number of documents indexed, error message or None) Tuple of (indexed_count, skipped_count, warning_or_error_message)
""" """
task_logger = TaskLoggingService(session, search_space_id) task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start( log_entry = await task_logger.log_task_start(
task_name="linear_issues_indexing", task_name="linear_issues_indexing",
source="connector_indexing_task", source="connector_indexing_task",
@ -85,7 +110,7 @@ async def index_linear_issues(
) )
try: try:
# Get the connector # ── Connector lookup ──────────────────────────────────────────
await task_logger.log_task_progress( await task_logger.log_task_progress(
log_entry, log_entry,
f"Retrieving Linear connector {connector_id} from database", f"Retrieving Linear connector {connector_id} from database",
@ -104,11 +129,11 @@ async def index_linear_issues(
{"error_type": "ConnectorNotFound"}, {"error_type": "ConnectorNotFound"},
) )
return ( return (
0,
0, 0,
f"Connector with ID {connector_id} not found or is not a Linear connector", f"Connector with ID {connector_id} not found or is not a Linear connector",
) )
# Check if access_token exists (support both new OAuth format and old API key format)
if not connector.config.get("access_token") and not connector.config.get( if not connector.config.get("access_token") and not connector.config.get(
"LINEAR_API_KEY" "LINEAR_API_KEY"
): ):
@ -118,26 +143,22 @@ async def index_linear_issues(
"Missing Linear access token", "Missing Linear access token",
{"error_type": "MissingToken"}, {"error_type": "MissingToken"},
) )
return 0, "Linear access token not found in connector config" return 0, 0, "Linear access token not found in connector config"
# Initialize Linear client with internal refresh capability # ── Client init ───────────────────────────────────────────────
await task_logger.log_task_progress( await task_logger.log_task_progress(
log_entry, log_entry,
f"Initializing Linear client for connector {connector_id}", f"Initializing Linear client for connector {connector_id}",
{"stage": "client_initialization"}, {"stage": "client_initialization"},
) )
# Create connector with session and connector_id for internal refresh
# Token refresh will happen automatically when needed
linear_client = LinearConnector(session=session, connector_id=connector_id) linear_client = LinearConnector(session=session, connector_id=connector_id)
# Handle 'undefined' string from frontend (treat as None)
if start_date == "undefined" or start_date == "": if start_date == "undefined" or start_date == "":
start_date = None start_date = None
if end_date == "undefined" or end_date == "": if end_date == "undefined" or end_date == "":
end_date = None end_date = None
# Calculate date range
start_date_str, end_date_str = calculate_date_range( start_date_str, end_date_str = calculate_date_range(
connector, start_date, end_date, default_days_back=365 connector, start_date, end_date, default_days_back=365
) )
@ -154,37 +175,34 @@ async def index_linear_issues(
}, },
) )
# Get issues within date range # ── Fetch issues ──────────────────────────────────────────────
try: try:
issues, error = await linear_client.get_issues_by_date_range( issues, error = await linear_client.get_issues_by_date_range(
start_date=start_date_str, end_date=end_date_str, include_comments=True start_date=start_date_str,
end_date=end_date_str,
include_comments=True,
) )
if error: if error:
# Don't treat "No issues found" as an error that should stop indexing
if "No issues found" in error: if "No issues found" in error:
logger.info(f"No Linear issues found: {error}") logger.info(f"No Linear issues found: {error}")
logger.info(
"No issues found is not a critical error, continuing with update"
)
if update_last_indexed: if update_last_indexed:
await update_connector_last_indexed( await update_connector_last_indexed(
session, connector, update_last_indexed session, connector, update_last_indexed
) )
await session.commit() await session.commit()
logger.info( return 0, 0, None
f"Updated last_indexed_at to {connector.last_indexed_at} despite no issues found"
)
return 0, None
else: else:
logger.error(f"Failed to get Linear issues: {error}") logger.error(f"Failed to get Linear issues: {error}")
return 0, f"Failed to get Linear issues: {error}" return 0, 0, f"Failed to get Linear issues: {error}"
logger.info(f"Retrieved {len(issues)} issues from Linear API") logger.info(f"Retrieved {len(issues)} issues from Linear API")
except Exception as e: except Exception as e:
logger.error(f"Exception when calling Linear API: {e!s}", exc_info=True) logger.error(
return 0, f"Failed to get Linear issues: {e!s}" f"Exception when calling Linear API: {e!s}", exc_info=True
)
return 0, 0, f"Failed to get Linear issues: {e!s}"
if not issues: if not issues:
logger.info("No Linear issues found for the specified date range") logger.info("No Linear issues found for the specified date range")
@ -193,19 +211,12 @@ async def index_linear_issues(
session, connector, update_last_indexed session, connector, update_last_indexed
) )
await session.commit() await session.commit()
logger.info( return 0, 0, None
f"Updated last_indexed_at to {connector.last_indexed_at} despite no issues found"
)
return 0, None # Return None instead of error message when no issues found
# Track the number of documents indexed # ── Build ConnectorDocuments ──────────────────────────────────
documents_indexed = 0 connector_docs: list[ConnectorDocument] = []
documents_skipped = 0 documents_skipped = 0
documents_failed = 0 # Track issues that failed processing duplicate_content_count = 0
skipped_issues = []
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
await task_logger.log_task_progress( await task_logger.log_task_progress(
log_entry, log_entry,
@ -213,13 +224,6 @@ async def index_linear_issues(
{"stage": "process_issues", "total_issues": len(issues)}, {"stage": "process_issues", "total_issues": len(issues)},
) )
# =======================================================================
# PHASE 1: Analyze all issues, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
issues_to_process = [] # List of dicts with document and issue data
new_documents_created = False
for issue in issues: for issue in issues:
try: try:
issue_id = issue.get("id", "") issue_id = issue.get("id", "")
@ -230,271 +234,102 @@ async def index_linear_issues(
logger.warning( logger.warning(
f"Skipping issue with missing ID or title: {issue_id or 'Unknown'}" f"Skipping issue with missing ID or title: {issue_id or 'Unknown'}"
) )
skipped_issues.append(
f"{issue_identifier or 'Unknown'} (missing data)"
)
documents_skipped += 1 documents_skipped += 1
continue continue
# Format the issue first to get well-structured data
formatted_issue = linear_client.format_issue(issue) formatted_issue = linear_client.format_issue(issue)
issue_content = linear_client.format_issue_to_markdown(
# Convert issue to markdown format formatted_issue
issue_content = linear_client.format_issue_to_markdown(formatted_issue) )
if not issue_content: if not issue_content:
logger.warning( logger.warning(
f"Skipping issue with no content: {issue_identifier} - {issue_title}" f"Skipping issue with no content: {issue_identifier} - {issue_title}"
) )
skipped_issues.append(f"{issue_identifier} (no content)")
documents_skipped += 1 documents_skipped += 1
continue continue
# Generate unique identifier hash for this Linear issue doc = _build_connector_doc(
unique_identifier_hash = generate_unique_identifier_hash( issue,
DocumentType.LINEAR_CONNECTOR, issue_id, search_space_id formatted_issue,
) issue_content,
# Generate content hash
content_hash = generate_content_hash(issue_content, search_space_id)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
state = formatted_issue.get("state", "Unknown")
description = formatted_issue.get("description", "")
comment_count = len(formatted_issue.get("comments", []))
priority = formatted_issue.get("priority", "Unknown")
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
logger.info(
f"Document for Linear issue {issue_identifier} unchanged. Skipping."
)
documents_skipped += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
issues_to_process.append(
{
"document": existing_document,
"is_new": False,
"issue_content": issue_content,
"content_hash": content_hash,
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"description": description,
"comment_count": comment_count,
"priority": priority,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
)
if duplicate_by_content:
logger.info(
f"Linear issue {issue_identifier} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, "
f"type: {duplicate_by_content.document_type}). Skipping."
)
documents_skipped += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=f"{issue_identifier}: {issue_title}",
document_type=DocumentType.LINEAR_CONNECTOR,
document_metadata={
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"comment_count": comment_count,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id, connector_id=connector_id,
) search_space_id=search_space_id,
session.add(document) user_id=user_id,
new_documents_created = True enable_summary=connector.enable_summary,
issues_to_process.append(
{
"document": document,
"is_new": True,
"issue_content": issue_content,
"content_hash": content_hash,
"issue_id": issue_id,
"issue_identifier": issue_identifier,
"issue_title": issue_title,
"state": state,
"description": description,
"comment_count": comment_count,
"priority": priority,
}
) )
except Exception as e: with session.no_autoflush:
logger.error(f"Error in Phase 1 for issue: {e!s}", exc_info=True) duplicate = await check_duplicate_document_by_hash(
documents_failed += 1 session, compute_content_hash(doc)
)
if duplicate:
logger.info(
f"Linear issue {doc.title} already indexed by another connector "
f"(existing document ID: {duplicate.id}, "
f"type: {duplicate.document_type}). Skipping."
)
duplicate_content_count += 1
documents_skipped += 1
continue continue
# Commit all pending documents - they all appear in UI now connector_docs.append(doc)
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([i for i in issues_to_process if i['is_new']])} pending documents"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(issues_to_process)} documents")
for item in issues_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"issue_id": item["issue_identifier"],
"issue_title": item["issue_title"],
"state": item["state"],
"priority": item["priority"],
"comment_count": item["comment_count"],
"document_type": "Linear Issue",
"connector_type": "Linear",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["issue_content"], user_llm, document_metadata_for_summary
)
else:
summary_content = f"Linear Issue {item['issue_identifier']}: {item['issue_title']}\n\nStatus: {item['state']}\n\n{item['issue_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["issue_content"])
# Update document to READY with actual content
document.title = f"{item['issue_identifier']}: {item['issue_title']}"
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"issue_id": item["issue_id"],
"issue_identifier": item["issue_identifier"],
"issue_title": item["issue_title"],
"state": item["state"],
"comment_count": item["comment_count"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Linear issues processed so far"
)
await session.commit()
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error processing issue {item.get('issue_identifier', 'Unknown')}: {e!s}", f"Error building ConnectorDocument for issue: {e!s}",
exc_info=True, exc_info=True,
) )
# Mark document as failed with reason (visible in UI) documents_skipped += 1
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
skipped_issues.append(
f"{item.get('issue_identifier', 'Unknown')} (processing error)"
)
documents_failed += 1
continue continue
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs # ── Pipeline: migrate legacy docs + parallel index ────────────
pipeline = IndexingPipelineService(session)
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s):
return await get_user_long_context_llm(s, user_id, search_space_id)
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
)
# ── Finalize ──────────────────────────────────────────────────
await update_connector_last_indexed(session, connector, update_last_indexed) await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit for any remaining documents not yet committed in batches logger.info(
logger.info(f"Final commit: Total {documents_indexed} Linear issues processed") f"Final commit: Total {documents_indexed} Linear issues processed"
)
try: try:
await session.commit() await session.commit()
logger.info( logger.info(
"Successfully committed all Linear document changes to database" "Successfully committed all Linear document changes to database"
) )
except Exception as e: except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if ( if (
"duplicate key value violates unique constraint" in str(e).lower() "duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower() or "uniqueviolationerror" in str(e).lower()
): ):
logger.warning( logger.warning(
f"Duplicate content_hash detected during final commit. " f"Duplicate content_hash detected during final commit. "
f"This may occur if the same issue was indexed by multiple connectors. "
f"Rolling back and continuing. Error: {e!s}" f"Rolling back and continuing. Error: {e!s}"
) )
await session.rollback() await session.rollback()
else: else:
raise raise
# Build warning message if there were issues warning_parts: list[str] = []
warning_parts = [] if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")
if documents_failed > 0: if documents_failed > 0:
warning_parts.append(f"{documents_failed} failed") warning_parts.append(f"{documents_failed} failed")
warning_message = ", ".join(warning_parts) if warning_parts else None warning_message = ", ".join(warning_parts) if warning_parts else None
# Log success
await task_logger.log_task_success( await task_logger.log_task_success(
log_entry, log_entry,
f"Successfully completed Linear indexing for connector {connector_id}", f"Successfully completed Linear indexing for connector {connector_id}",
@ -503,7 +338,7 @@ async def index_linear_issues(
"documents_indexed": documents_indexed, "documents_indexed": documents_indexed,
"documents_skipped": documents_skipped, "documents_skipped": documents_skipped,
"documents_failed": documents_failed, "documents_failed": documents_failed,
"skipped_issues_count": len(skipped_issues), "duplicate_content_count": duplicate_content_count,
}, },
) )
@ -511,7 +346,7 @@ async def index_linear_issues(
f"Linear indexing completed: {documents_indexed} ready, " f"Linear indexing completed: {documents_indexed} ready, "
f"{documents_skipped} skipped, {documents_failed} failed" f"{documents_skipped} skipped, {documents_failed} failed"
) )
return documents_indexed, warning_message return documents_indexed, documents_skipped, warning_message
except SQLAlchemyError as db_error: except SQLAlchemyError as db_error:
await session.rollback() await session.rollback()
@ -522,7 +357,7 @@ async def index_linear_issues(
{"error_type": "SQLAlchemyError"}, {"error_type": "SQLAlchemyError"},
) )
logger.error(f"Database error: {db_error!s}", exc_info=True) logger.error(f"Database error: {db_error!s}", exc_info=True)
return 0, f"Database error: {db_error!s}" return 0, 0, f"Database error: {db_error!s}"
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
await task_logger.log_task_failure( await task_logger.log_task_failure(
@ -532,4 +367,4 @@ async def index_linear_issues(
{"error_type": type(e).__name__}, {"error_type": type(e).__name__},
) )
logger.error(f"Failed to index Linear issues: {e!s}", exc_info=True) logger.error(f"Failed to index Linear issues: {e!s}", exc_info=True)
return 0, f"Failed to index Linear issues: {e!s}" return 0, 0, f"Failed to index Linear issues: {e!s}"

View file

@ -1,12 +1,10 @@
""" """
Notion connector indexer. Notion connector indexer.
Implements real-time document status updates using a two-phase approach: Uses the shared IndexingPipelineService for document deduplication,
- Phase 1: Create all documents with PENDING status (visible in UI immediately) summarization, chunking, and embedding with bounded parallel indexing.
- Phase 2: Process each document one by one (pending processing ready/failed)
""" """
import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from datetime import datetime from datetime import datetime
@ -14,42 +12,64 @@ from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.connectors.notion_history import NotionHistoryConnector from app.connectors.notion_history import NotionHistoryConnector
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.services.llm_service import get_user_long_context_llm from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from app.utils.notion_utils import process_blocks from app.utils.notion_utils import process_blocks
from .base import ( from .base import (
build_document_metadata_string,
calculate_date_range, calculate_date_range,
check_document_by_unique_identifier,
check_duplicate_document_by_hash, check_duplicate_document_by_hash,
get_connector_by_id, get_connector_by_id,
get_current_timestamp,
logger, logger,
safe_set_chunks,
update_connector_last_indexed, update_connector_last_indexed,
) )
# Type alias for retry callback
# Signature: async callback(retry_reason, attempt, max_attempts, wait_seconds) -> None
RetryCallbackType = Callable[[str, int, int, float], Awaitable[None]] RetryCallbackType = Callable[[str, int, int, float], Awaitable[None]]
# Type alias for heartbeat callback
# Signature: async callback(indexed_count) -> None
HeartbeatCallbackType = Callable[[int], Awaitable[None]] HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Heartbeat interval in seconds - update notification every 30 seconds
HEARTBEAT_INTERVAL_SECONDS = 30 HEARTBEAT_INTERVAL_SECONDS = 30
def _build_connector_doc(
page: dict,
markdown_content: str,
*,
connector_id: int,
search_space_id: int,
user_id: str,
enable_summary: bool,
) -> ConnectorDocument:
"""Map a raw Notion page dict to a ConnectorDocument."""
page_id = page.get("page_id", "")
page_title = page.get("title", f"Untitled page ({page_id})")
metadata = {
"page_title": page_title,
"page_id": page_id,
"connector_id": connector_id,
"document_type": "Notion Page",
"connector_type": "Notion",
}
fallback_summary = f"Notion Page: {page_title}\n\n{markdown_content}"
return ConnectorDocument(
title=page_title,
source_markdown=markdown_content,
unique_id=page_id,
document_type=DocumentType.NOTION_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=enable_summary,
fallback_summary=fallback_summary,
metadata=metadata,
)
async def index_notion_pages( async def index_notion_pages(
session: AsyncSession, session: AsyncSession,
connector_id: int, connector_id: int,
@ -60,30 +80,15 @@ async def index_notion_pages(
update_last_indexed: bool = True, update_last_indexed: bool = True,
on_retry_callback: RetryCallbackType | None = None, on_retry_callback: RetryCallbackType | None = None,
on_heartbeat_callback: HeartbeatCallbackType | None = None, on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str | None]: ) -> tuple[int, int, str | None]:
""" """
Index Notion pages from all accessible pages. Index Notion pages from all accessible pages.
Args:
session: Database session
connector_id: ID of the Notion connector
search_space_id: ID of the search space to store documents in
user_id: ID of the user
start_date: Start date for indexing (YYYY-MM-DD format)
end_date: End date for indexing (YYYY-MM-DD format)
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
on_retry_callback: Optional callback for retry progress notifications.
Signature: async callback(retry_reason, attempt, max_attempts, wait_seconds)
retry_reason is one of: 'rate_limit', 'server_error', 'timeout'
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Called periodically with (indexed_count) to prevent task appearing stuck.
Returns: Returns:
Tuple containing (number of documents indexed, error message or None) Tuple of (indexed_count, skipped_count, warning_or_error_message)
""" """
task_logger = TaskLoggingService(session, search_space_id) task_logger = TaskLoggingService(session, search_space_id)
# Log task start
log_entry = await task_logger.log_task_start( log_entry = await task_logger.log_task_start(
task_name="notion_pages_indexing", task_name="notion_pages_indexing",
source="connector_indexing_task", source="connector_indexing_task",
@ -97,7 +102,7 @@ async def index_notion_pages(
) )
try: try:
# Get the connector # ── Connector lookup ──────────────────────────────────────────
await task_logger.log_task_progress( await task_logger.log_task_progress(
log_entry, log_entry,
f"Retrieving Notion connector {connector_id} from database", f"Retrieving Notion connector {connector_id} from database",
@ -116,11 +121,11 @@ async def index_notion_pages(
{"error_type": "ConnectorNotFound"}, {"error_type": "ConnectorNotFound"},
) )
return ( return (
0,
0, 0,
f"Connector with ID {connector_id} not found or is not a Notion connector", f"Connector with ID {connector_id} not found or is not a Notion connector",
) )
# Check if access_token exists (support both new OAuth format and old integration token format)
if not connector.config.get("access_token") and not connector.config.get( if not connector.config.get("access_token") and not connector.config.get(
"NOTION_INTEGRATION_TOKEN" "NOTION_INTEGRATION_TOKEN"
): ):
@ -130,9 +135,9 @@ async def index_notion_pages(
"Missing Notion access token", "Missing Notion access token",
{"error_type": "MissingToken"}, {"error_type": "MissingToken"},
) )
return 0, "Notion access token not found in connector config" return 0, 0, "Notion access token not found in connector config"
# Initialize Notion client with internal refresh capability # ── Client init ───────────────────────────────────────────────
await task_logger.log_task_progress( await task_logger.log_task_progress(
log_entry, log_entry,
f"Initializing Notion client for connector {connector_id}", f"Initializing Notion client for connector {connector_id}",
@ -141,18 +146,15 @@ async def index_notion_pages(
logger.info(f"Initializing Notion client for connector {connector_id}") logger.info(f"Initializing Notion client for connector {connector_id}")
# Handle 'undefined' string from frontend (treat as None)
if start_date == "undefined" or start_date == "": if start_date == "undefined" or start_date == "":
start_date = None start_date = None
if end_date == "undefined" or end_date == "": if end_date == "undefined" or end_date == "":
end_date = None end_date = None
# Calculate date range using the shared utility function
start_date_str, end_date_str = calculate_date_range( start_date_str, end_date_str = calculate_date_range(
connector, start_date, end_date, default_days_back=365 connector, start_date, end_date, default_days_back=365
) )
# Convert YYYY-MM-DD to ISO format for Notion API
start_date_iso = datetime.strptime(start_date_str, "%Y-%m-%d").strftime( start_date_iso = datetime.strptime(start_date_str, "%Y-%m-%d").strftime(
"%Y-%m-%dT%H:%M:%SZ" "%Y-%m-%dT%H:%M:%SZ"
) )
@ -160,13 +162,10 @@ async def index_notion_pages(
"%Y-%m-%dT%H:%M:%SZ" "%Y-%m-%dT%H:%M:%SZ"
) )
# Create connector with session and connector_id for internal refresh
# Token refresh will happen automatically when needed
notion_client = NotionHistoryConnector( notion_client = NotionHistoryConnector(
session=session, connector_id=connector_id session=session, connector_id=connector_id
) )
# Set retry callback if provided (for user notifications during rate limits)
if on_retry_callback: if on_retry_callback:
notion_client.set_retry_callback(on_retry_callback) notion_client.set_retry_callback(on_retry_callback)
@ -182,21 +181,19 @@ async def index_notion_pages(
}, },
) )
# Get all pages # ── Fetch pages ───────────────────────────────────────────────
try: try:
pages = await notion_client.get_all_pages( pages = await notion_client.get_all_pages(
start_date=start_date_iso, end_date=end_date_iso start_date=start_date_iso, end_date=end_date_iso
) )
logger.info(f"Found {len(pages)} Notion pages") logger.info(f"Found {len(pages)} Notion pages")
# Get count of pages that had unsupported content skipped
pages_with_skipped_content = notion_client.get_skipped_content_count() pages_with_skipped_content = notion_client.get_skipped_content_count()
if pages_with_skipped_content > 0: if pages_with_skipped_content > 0:
logger.info( logger.info(
f"{pages_with_skipped_content} pages had Notion AI content skipped (not available via API)" f"{pages_with_skipped_content} pages had Notion AI content skipped (not available via API)"
) )
# Check if using legacy integration token and log warning
if notion_client.is_using_legacy_token(): if notion_client.is_using_legacy_token():
logger.warning( logger.warning(
f"Connector {connector_id} is using legacy integration token. " f"Connector {connector_id} is using legacy integration token. "
@ -204,8 +201,6 @@ async def index_notion_pages(
) )
except Exception as e: except Exception as e:
error_str = str(e) error_str = str(e)
# Check if this is an unsupported block type error (transcription, ai_block, etc.)
# These are known Notion API limitations and should be logged as warnings, not errors
unsupported_block_errors = [ unsupported_block_errors = [
"transcription is not supported", "transcription is not supported",
"ai_block is not supported", "ai_block is not supported",
@ -216,7 +211,6 @@ async def index_notion_pages(
) )
if is_unsupported_block_error: if is_unsupported_block_error:
# Log as warning since this is a known Notion API limitation
logger.warning( logger.warning(
f"Notion API limitation for connector {connector_id}: {error_str}. " f"Notion API limitation for connector {connector_id}: {error_str}. "
"This is a known issue with Notion AI blocks (transcription, ai_block) " "This is a known issue with Notion AI blocks (transcription, ai_block) "
@ -229,7 +223,6 @@ async def index_notion_pages(
{"error_type": "UnsupportedBlockType", "is_known_limitation": True}, {"error_type": "UnsupportedBlockType", "is_known_limitation": True},
) )
else: else:
# Log as error for other failures
logger.error( logger.error(
f"Error fetching Notion pages for connector {connector_id}: {error_str}", f"Error fetching Notion pages for connector {connector_id}: {error_str}",
exc_info=True, exc_info=True,
@ -242,7 +235,7 @@ async def index_notion_pages(
) )
await notion_client.close() await notion_client.close()
return 0, f"Failed to get Notion pages: {e!s}" return 0, 0, f"Failed to get Notion pages: {e!s}"
if not pages: if not pages:
await task_logger.log_task_success( await task_logger.log_task_success(
@ -252,21 +245,17 @@ async def index_notion_pages(
{"pages_found": 0}, {"pages_found": 0},
) )
logger.info("No Notion pages found to index") logger.info("No Notion pages found to index")
# CRITICAL: Update timestamp even when no pages found so Zero syncs await update_connector_last_indexed(
await update_connector_last_indexed(session, connector, update_last_indexed) session, connector, update_last_indexed
)
await session.commit() await session.commit()
await notion_client.close() await notion_client.close()
return 0, None # Success with 0 pages, not an error return 0, 0, None
# Track the number of documents indexed # ── Build ConnectorDocuments ──────────────────────────────────
documents_indexed = 0 connector_docs: list[ConnectorDocument] = []
documents_skipped = 0 documents_skipped = 0
documents_failed = 0
duplicate_content_count = 0 duplicate_content_count = 0
skipped_pages = []
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
await task_logger.log_task_progress( await task_logger.log_task_progress(
log_entry, log_entry,
@ -274,13 +263,6 @@ async def index_notion_pages(
{"stage": "process_pages", "total_pages": len(pages)}, {"stage": "process_pages", "total_pages": len(pages)},
) )
# =======================================================================
# PHASE 1: Analyze all pages, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
pages_to_process = [] # List of dicts with document and page data
new_documents_created = False
for page in pages: for page in pages:
try: try:
page_id = page.get("page_id") page_id = page.get("page_id")
@ -293,225 +275,71 @@ async def index_notion_pages(
if not page_content: if not page_content:
logger.info(f"No content found in page {page_title}. Skipping.") logger.info(f"No content found in page {page_title}. Skipping.")
skipped_pages.append(f"{page_title} (no content)")
documents_skipped += 1 documents_skipped += 1
continue continue
# Convert page content to markdown format
markdown_content = f"# Notion Page: {page_title}\n\n" markdown_content = f"# Notion Page: {page_title}\n\n"
markdown_content += process_blocks(page_content) markdown_content += process_blocks(page_content)
# Format document metadata if not markdown_content.strip():
metadata_sections = [ logger.warning(
("METADATA", [f"PAGE_TITLE: {page_title}", f"PAGE_ID: {page_id}"]), f"Skipping page with empty markdown: {page_title}"
(
"CONTENT",
[
"FORMAT: markdown",
"TEXT_START",
markdown_content,
"TEXT_END",
],
),
]
# Build the document string
combined_document_string = build_document_metadata_string(
metadata_sections
) )
# Generate unique identifier hash for this Notion page
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.NOTION_CONNECTOR, page_id, search_space_id
)
# Generate content hash
content_hash = generate_content_hash(
combined_document_string, search_space_id
)
# Check if document with this unique identifier already exists
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
documents_skipped += 1 documents_skipped += 1
continue continue
# Queue existing document for update (will be set to processing in Phase 2) doc = _build_connector_doc(
pages_to_process.append( page,
{ markdown_content,
"document": existing_document, connector_id=connector_id,
"is_new": False, search_space_id=search_space_id,
"markdown_content": markdown_content, user_id=user_id,
"content_hash": content_hash, enable_summary=connector.enable_summary,
"page_id": page_id,
"page_title": page_title,
}
) )
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush: with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash( duplicate = await check_duplicate_document_by_hash(
session, content_hash session, compute_content_hash(doc)
) )
if duplicate:
if duplicate_by_content:
logger.info( logger.info(
f"Notion page {page_title} already indexed by another connector " f"Notion page {doc.title} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, " f"(existing document ID: {duplicate.id}, "
f"type: {duplicate_by_content.document_type}). Skipping." f"type: {duplicate.document_type}). Skipping."
) )
duplicate_content_count += 1 duplicate_content_count += 1
documents_skipped += 1 documents_skipped += 1
continue continue
# Create new document with PENDING status (visible in UI immediately) connector_docs.append(doc)
document = Document(
search_space_id=search_space_id,
title=page_title,
document_type=DocumentType.NOTION_CONNECTOR,
document_metadata={
"page_title": page_title,
"page_id": page_id,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
pages_to_process.append(
{
"document": document,
"is_new": True,
"markdown_content": markdown_content,
"content_hash": content_hash,
"page_id": page_id,
"page_title": page_title,
}
)
except Exception as e: except Exception as e:
logger.error(f"Error in Phase 1 for page: {e!s}", exc_info=True)
documents_failed += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([p for p in pages_to_process if p['is_new']])} pending documents"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(pages_to_process)} documents")
for item in pages_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(documents_indexed)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Heavy processing (LLM, embeddings, chunks)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm and connector.enable_summary:
document_metadata_for_summary = {
"page_title": item["page_title"],
"page_id": item["page_id"],
"document_type": "Notion Page",
"connector_type": "Notion",
}
(
summary_content,
summary_embedding,
) = await generate_document_summary(
item["markdown_content"],
user_llm,
document_metadata_for_summary,
)
else:
summary_content = f"Notion Page: {item['page_title']}\n\n{item['markdown_content']}"
summary_embedding = embed_text(summary_content)
chunks = await create_document_chunks(item["markdown_content"])
# Update document to READY with actual content
document.title = item["page_title"]
document.content = summary_content
document.content_hash = item["content_hash"]
document.embedding = summary_embedding
document.document_metadata = {
"page_title": item["page_title"],
"page_id": item["page_id"],
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"connector_id": connector_id,
}
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
documents_indexed += 1
# Batch commit every 10 documents (for ready status updates)
if documents_indexed % 10 == 0:
logger.info(
f"Committing batch: {documents_indexed} Notion pages processed so far"
)
await session.commit()
except Exception as e:
logger.error(f"Error processing Notion page: {e!s}", exc_info=True)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error( logger.error(
f"Failed to update document status to failed: {status_error}" f"Error building ConnectorDocument for page: {e!s}",
exc_info=True,
) )
skipped_pages.append(f"{item['page_title']} (processing error)") documents_skipped += 1
documents_failed += 1
continue continue
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs # ── Pipeline: migrate legacy docs + parallel index ────────────
pipeline = IndexingPipelineService(session)
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s):
return await get_user_long_context_llm(s, user_id, search_space_id)
_, documents_indexed, documents_failed = await pipeline.index_batch_parallel(
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat_callback,
heartbeat_interval=HEARTBEAT_INTERVAL_SECONDS,
)
# ── Finalize ──────────────────────────────────────────────────
await update_connector_last_indexed(session, connector, update_last_indexed) await update_connector_last_indexed(session, connector, update_last_indexed)
total_processed = documents_indexed
# Final commit to ensure all documents are persisted (safety net)
logger.info(f"Final commit: Total {documents_indexed} documents processed") logger.info(f"Final commit: Total {documents_indexed} documents processed")
try: try:
await session.commit() await session.commit()
@ -519,59 +347,53 @@ async def index_notion_pages(
"Successfully committed all Notion document changes to database" "Successfully committed all Notion document changes to database"
) )
except Exception as e: except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if ( if (
"duplicate key value violates unique constraint" in str(e).lower() "duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower() or "uniqueviolationerror" in str(e).lower()
): ):
logger.warning( logger.warning(
f"Duplicate content_hash detected during final commit. " f"Duplicate content_hash detected during final commit. "
f"This may occur if the same page was indexed by multiple connectors. "
f"Rolling back and continuing. Error: {e!s}" f"Rolling back and continuing. Error: {e!s}"
) )
await session.rollback() await session.rollback()
# Don't fail the entire task - some documents may have been successfully indexed
else: else:
raise raise
# Get final count of pages with skipped Notion AI content # ── Build warning / notification message ──────────────────────
pages_with_skipped_ai_content = notion_client.get_skipped_content_count() pages_with_skipped_ai_content = notion_client.get_skipped_content_count()
# Build warning message if there were issues warning_parts: list[str] = []
warning_parts = []
if duplicate_content_count > 0: if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate") warning_parts.append(f"{duplicate_content_count} duplicate")
if documents_failed > 0: if documents_failed > 0:
warning_parts.append(f"{documents_failed} failed") warning_parts.append(f"{documents_failed} failed")
warning_message = ", ".join(warning_parts) if warning_parts else None
# Prepare result message with user-friendly notification about skipped content notification_parts: list[str] = []
result_message = None
if skipped_pages:
result_message = f"Processed {total_processed} pages. Skipped {len(skipped_pages)} pages: {', '.join(skipped_pages)}"
else:
result_message = f"Processed {total_processed} pages."
# Add user-friendly message about skipped Notion AI content
if pages_with_skipped_ai_content > 0: if pages_with_skipped_ai_content > 0:
result_message += ( notification_parts.append(
" Audio transcriptions and AI summaries from Notion aren't accessible " "Some Notion AI content couldn't be synced (API limitation)"
"via their API - all other content was saved." )
if notion_client.is_using_legacy_token():
notification_parts.append(
"Using legacy token. Reconnect with OAuth for better reliability."
)
if warning_parts:
notification_parts.append(", ".join(warning_parts))
user_notification_message = (
" ".join(notification_parts) if notification_parts else None
) )
# Log success
await task_logger.log_task_success( await task_logger.log_task_success(
log_entry, log_entry,
f"Successfully completed Notion indexing for connector {connector_id}", f"Successfully completed Notion indexing for connector {connector_id}",
{ {
"pages_processed": total_processed, "pages_processed": documents_indexed,
"documents_indexed": documents_indexed, "documents_indexed": documents_indexed,
"documents_skipped": documents_skipped, "documents_skipped": documents_skipped,
"documents_failed": documents_failed, "documents_failed": documents_failed,
"duplicate_content_count": duplicate_content_count, "duplicate_content_count": duplicate_content_count,
"skipped_pages_count": len(skipped_pages),
"pages_with_skipped_ai_content": pages_with_skipped_ai_content, "pages_with_skipped_ai_content": pages_with_skipped_ai_content,
"result_message": result_message,
}, },
) )
@ -581,35 +403,9 @@ async def index_notion_pages(
f"({duplicate_content_count} duplicate content)" f"({duplicate_content_count} duplicate content)"
) )
# Clean up the async client
await notion_client.close() await notion_client.close()
# Build user-friendly notification messages return documents_indexed, documents_skipped, user_notification_message
# This will be shown in the notification to inform users
notification_parts = []
if pages_with_skipped_ai_content > 0:
notification_parts.append(
"Some Notion AI content couldn't be synced (API limitation)"
)
if notion_client.is_using_legacy_token():
notification_parts.append(
"Using legacy token. Reconnect with OAuth for better reliability."
)
# Include warning message if there were issues
if warning_message:
notification_parts.append(warning_message)
user_notification_message = (
" ".join(notification_parts) if notification_parts else None
)
return (
total_processed,
user_notification_message,
)
except SQLAlchemyError as db_error: except SQLAlchemyError as db_error:
await session.rollback() await session.rollback()
@ -622,10 +418,9 @@ async def index_notion_pages(
logger.error( logger.error(
f"Database error during Notion indexing: {db_error!s}", exc_info=True f"Database error during Notion indexing: {db_error!s}", exc_info=True
) )
# Clean up the async client in case of error
if "notion_client" in locals(): if "notion_client" in locals():
await notion_client.close() await notion_client.close()
return 0, f"Database error: {db_error!s}" return 0, 0, f"Database error: {db_error!s}"
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
await task_logger.log_task_failure( await task_logger.log_task_failure(
@ -635,7 +430,6 @@ async def index_notion_pages(
{"error_type": type(e).__name__}, {"error_type": type(e).__name__},
) )
logger.error(f"Failed to index Notion pages: {e!s}", exc_info=True) logger.error(f"Failed to index Notion pages: {e!s}", exc_info=True)
# Clean up the async client in case of error
if "notion_client" in locals(): if "notion_client" in locals():
await notion_client.close() await notion_client.close()
return 0, f"Failed to index Notion pages: {e!s}" return 0, 0, f"Failed to index Notion pages: {e!s}"

View file

@ -1,5 +1,6 @@
import hashlib import hashlib
import logging import logging
import threading
import warnings import warnings
import numpy as np import numpy as np
@ -11,6 +12,12 @@ from app.prompts import SUMMARY_PROMPT_TEMPLATE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# HuggingFace fast tokenizers (Rust-backed) are not thread-safe — concurrent
# access from multiple threads causes "RuntimeError: Already borrowed".
# This reentrant lock serialises tokenizer + embedding model access so that
# asyncio.to_thread calls from index_batch_parallel don't collide.
_embedding_lock = threading.RLock()
def _get_embedding_max_tokens() -> int: def _get_embedding_max_tokens() -> int:
"""Get the max token limit for the configured embedding model. """Get the max token limit for the configured embedding model.
@ -36,6 +43,7 @@ def truncate_for_embedding(text: str) -> str:
if len(text) // 3 <= max_tokens: if len(text) // 3 <= max_tokens:
return text return text
with _embedding_lock:
tokenizer = config.embedding_model_instance.get_tokenizer() tokenizer = config.embedding_model_instance.get_tokenizer()
tokens = tokenizer.encode(text) tokens = tokenizer.encode(text)
if len(tokens) <= max_tokens: if len(tokens) <= max_tokens:
@ -52,6 +60,7 @@ def embed_text(text: str) -> np.ndarray:
"""Truncate text to fit and embed it. Drop-in replacement for """Truncate text to fit and embed it. Drop-in replacement for
``config.embedding_model_instance.embed(text)`` that never exceeds the ``config.embedding_model_instance.embed(text)`` that never exceeds the
model's context window.""" model's context window."""
with _embedding_lock:
return config.embedding_model_instance.embed(truncate_for_embedding(text)) return config.embedding_model_instance.embed(truncate_for_embedding(text))
@ -66,6 +75,7 @@ def embed_texts(texts: list[str]) -> list[np.ndarray]:
""" """
if not texts: if not texts:
return [] return []
with _embedding_lock:
truncated = [truncate_for_embedding(t) for t in texts] truncated = [truncate_for_embedding(t) for t in texts]
if config.is_local_embedding_model: if config.is_local_embedding_model:
return [config.embedding_model_instance.embed(t) for t in truncated] return [config.embedding_model_instance.embed(t) for t in truncated]

View file

@ -0,0 +1,111 @@
"""Integration tests: Calendar indexer builds ConnectorDocuments that flow through the pipeline."""
import pytest
from sqlalchemy import select
from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
def _cal_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
return ConnectorDocument(
title=f"Event {unique_id}",
source_markdown=f"## Calendar Event\n\nDetails for {unique_id}",
unique_id=unique_id,
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=True,
fallback_summary=f"Calendar: Event {unique_id}",
metadata={
"event_id": unique_id,
"start_time": "2025-01-15T10:00:00",
"end_time": "2025-01-15T11:00:00",
"document_type": "Google Calendar Event",
},
)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_calendar_pipeline_creates_ready_document(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A Calendar ConnectorDocument flows through prepare + index to a READY document."""
space_id = db_search_space.id
doc = _cal_doc(
unique_id="evt-1",
search_space_id=space_id,
connector_id=db_connector.id,
user_id=str(db_user.id),
)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([doc])
assert len(prepared) == 1
await service.index(prepared[0], doc, llm=mocker.Mock())
result = await db_session.execute(
select(Document).filter(Document.search_space_id == space_id)
)
row = result.scalars().first()
assert row is not None
assert row.document_type == DocumentType.GOOGLE_CALENDAR_CONNECTOR
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_calendar_legacy_doc_migrated(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A legacy Composio Calendar doc is migrated and reused."""
space_id = db_search_space.id
user_id = str(db_user.id)
evt_id = "evt-legacy-cal"
legacy_hash = compute_identifier_hash(
DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR.value, evt_id, space_id
)
legacy_doc = Document(
title="Old Calendar Event",
document_type=DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
content="old summary",
content_hash=f"ch-{legacy_hash[:12]}",
unique_identifier_hash=legacy_hash,
source_markdown="## Old event",
search_space_id=space_id,
created_by_id=user_id,
embedding=[0.1] * _EMBEDDING_DIM,
status={"state": "ready"},
)
db_session.add(legacy_doc)
await db_session.flush()
original_id = legacy_doc.id
connector_doc = _cal_doc(
unique_id=evt_id,
search_space_id=space_id,
connector_id=db_connector.id,
user_id=user_id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(select(Document).filter(Document.id == original_id))
row = result.scalars().first()
assert row.document_type == DocumentType.GOOGLE_CALENDAR_CONNECTOR
native_hash = compute_identifier_hash(
DocumentType.GOOGLE_CALENDAR_CONNECTOR.value, evt_id, space_id
)
assert row.unique_identifier_hash == native_hash

View file

@ -0,0 +1,170 @@
"""Integration tests: Drive indexer builds ConnectorDocuments that flow through the pipeline."""
import pytest
from sqlalchemy import select
from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
def _drive_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
return ConnectorDocument(
title=f"File {unique_id}.pdf",
source_markdown=f"## Document Content\n\nText from file {unique_id}",
unique_id=unique_id,
document_type=DocumentType.GOOGLE_DRIVE_FILE,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=True,
fallback_summary=f"File: {unique_id}.pdf",
metadata={
"google_drive_file_id": unique_id,
"google_drive_file_name": f"{unique_id}.pdf",
"document_type": "Google Drive File",
},
)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_drive_pipeline_creates_ready_document(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A Drive ConnectorDocument flows through prepare + index to a READY document."""
space_id = db_search_space.id
doc = _drive_doc(
unique_id="file-abc",
search_space_id=space_id,
connector_id=db_connector.id,
user_id=str(db_user.id),
)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([doc])
assert len(prepared) == 1
await service.index(prepared[0], doc, llm=mocker.Mock())
result = await db_session.execute(
select(Document).filter(Document.search_space_id == space_id)
)
row = result.scalars().first()
assert row is not None
assert row.document_type == DocumentType.GOOGLE_DRIVE_FILE
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_drive_legacy_doc_migrated(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A legacy Composio Drive doc is migrated and reused."""
space_id = db_search_space.id
user_id = str(db_user.id)
file_id = "file-legacy-drive"
legacy_hash = compute_identifier_hash(
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR.value, file_id, space_id
)
legacy_doc = Document(
title="Old Drive File",
document_type=DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
content="old file summary",
content_hash=f"ch-{legacy_hash[:12]}",
unique_identifier_hash=legacy_hash,
source_markdown="## Old file content",
search_space_id=space_id,
created_by_id=user_id,
embedding=[0.1] * _EMBEDDING_DIM,
status={"state": "ready"},
)
db_session.add(legacy_doc)
await db_session.flush()
original_id = legacy_doc.id
connector_doc = _drive_doc(
unique_id=file_id,
search_space_id=space_id,
connector_id=db_connector.id,
user_id=user_id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(select(Document).filter(Document.id == original_id))
row = result.scalars().first()
assert row.document_type == DocumentType.GOOGLE_DRIVE_FILE
native_hash = compute_identifier_hash(
DocumentType.GOOGLE_DRIVE_FILE.value, file_id, space_id
)
assert row.unique_identifier_hash == native_hash
async def test_should_skip_file_skips_failed_document(
db_session, db_search_space, db_user,
):
"""A FAILED document with unchanged md5 must be skipped — user can manually retry via Quick Index."""
import importlib
import sys
import types
pkg = "app.tasks.connector_indexers"
stub = pkg not in sys.modules
if stub:
mod = types.ModuleType(pkg)
mod.__path__ = ["app/tasks/connector_indexers"]
mod.__package__ = pkg
sys.modules[pkg] = mod
try:
gdm = importlib.import_module(
"app.tasks.connector_indexers.google_drive_indexer"
)
_should_skip_file = gdm._should_skip_file
finally:
if stub:
sys.modules.pop(pkg, None)
space_id = db_search_space.id
file_id = "file-failed-drive"
md5 = "abc123deadbeef"
doc_hash = compute_identifier_hash(
DocumentType.GOOGLE_DRIVE_FILE.value, file_id, space_id
)
failed_doc = Document(
title="Failed File.pdf",
document_type=DocumentType.GOOGLE_DRIVE_FILE,
content="LLM rate limit exceeded",
content_hash=f"ch-{doc_hash[:12]}",
unique_identifier_hash=doc_hash,
source_markdown="## Real content",
search_space_id=space_id,
created_by_id=str(db_user.id),
embedding=[0.1] * _EMBEDDING_DIM,
status=DocumentStatus.failed("LLM rate limit exceeded"),
document_metadata={
"google_drive_file_id": file_id,
"google_drive_file_name": "Failed File.pdf",
"md5_checksum": md5,
},
)
db_session.add(failed_doc)
await db_session.flush()
incoming_file = {"id": file_id, "name": "Failed File.pdf", "mimeType": "application/pdf", "md5Checksum": md5}
should_skip, msg = await _should_skip_file(db_session, incoming_file, space_id)
assert should_skip, "FAILED documents must be skipped during automatic sync"
assert "failed" in msg.lower()

View file

@ -0,0 +1,116 @@
"""Integration tests: Gmail indexer builds ConnectorDocuments that flow through the pipeline."""
import pytest
from sqlalchemy import select
from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import (
compute_identifier_hash,
compute_unique_identifier_hash,
)
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
def _gmail_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
"""Build a Gmail-style ConnectorDocument like the real indexer does."""
return ConnectorDocument(
title=f"Subject for {unique_id}",
source_markdown=f"## Email\n\nBody of {unique_id}",
unique_id=unique_id,
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
should_summarize=True,
fallback_summary=f"Gmail: Subject for {unique_id}",
metadata={
"message_id": unique_id,
"from": "sender@example.com",
"document_type": "Gmail Message",
},
)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_gmail_pipeline_creates_ready_document(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A Gmail ConnectorDocument flows through prepare + index to a READY document."""
space_id = db_search_space.id
doc = _gmail_doc(
unique_id="msg-pipeline-1",
search_space_id=space_id,
connector_id=db_connector.id,
user_id=str(db_user.id),
)
service = IndexingPipelineService(session=db_session)
prepared = await service.prepare_for_indexing([doc])
assert len(prepared) == 1
await service.index(prepared[0], doc, llm=mocker.Mock())
result = await db_session.execute(
select(Document).filter(Document.search_space_id == space_id)
)
row = result.scalars().first()
assert row is not None
assert row.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
assert row.source_markdown == doc.source_markdown
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_gmail_legacy_doc_migrated_then_reused(
db_session, db_search_space, db_connector, db_user, mocker
):
"""A legacy Composio Gmail doc is migrated then reused by the pipeline."""
space_id = db_search_space.id
user_id = str(db_user.id)
msg_id = "msg-legacy-gmail"
legacy_hash = compute_identifier_hash(
DocumentType.COMPOSIO_GMAIL_CONNECTOR.value, msg_id, space_id
)
legacy_doc = Document(
title="Old Gmail",
document_type=DocumentType.COMPOSIO_GMAIL_CONNECTOR,
content="old summary",
content_hash=f"ch-{legacy_hash[:12]}",
unique_identifier_hash=legacy_hash,
source_markdown="## Old content",
search_space_id=space_id,
created_by_id=user_id,
embedding=[0.1] * _EMBEDDING_DIM,
status={"state": "ready"},
)
db_session.add(legacy_doc)
await db_session.flush()
original_id = legacy_doc.id
connector_doc = _gmail_doc(
unique_id=msg_id,
search_space_id=space_id,
connector_id=db_connector.id,
user_id=user_id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
prepared = await service.prepare_for_indexing([connector_doc])
assert len(prepared) == 1
assert prepared[0].id == original_id
assert prepared[0].document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR
native_hash = compute_identifier_hash(
DocumentType.GOOGLE_GMAIL_CONNECTOR.value, msg_id, space_id
)
assert prepared[0].unique_identifier_hash == native_hash

View file

@ -0,0 +1,55 @@
"""Integration tests for IndexingPipelineService.index_batch()."""
import pytest
from sqlalchemy import select
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.integration
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_index_batch_creates_ready_documents(
db_session, db_search_space, make_connector_document, mocker
):
"""index_batch prepares and indexes a batch, resulting in READY documents."""
space_id = db_search_space.id
docs = [
make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-batch-1",
search_space_id=space_id,
source_markdown="## Email 1\n\nBody",
),
make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-batch-2",
search_space_id=space_id,
source_markdown="## Email 2\n\nDifferent body",
),
]
service = IndexingPipelineService(session=db_session)
results = await service.index_batch(docs, llm=mocker.Mock())
assert len(results) == 2
result = await db_session.execute(
select(Document).filter(Document.search_space_id == space_id)
)
rows = result.scalars().all()
assert len(rows) == 2
for row in rows:
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
assert row.content is not None
assert row.embedding is not None
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
async def test_index_batch_empty_returns_empty(db_session, mocker):
"""index_batch with empty input returns an empty list."""
service = IndexingPipelineService(session=db_session)
results = await service.index_batch([], llm=mocker.Mock())
assert results == []

View file

@ -0,0 +1,92 @@
"""Integration tests for IndexingPipelineService.migrate_legacy_docs()."""
import pytest
from sqlalchemy import select
from app.config import config as app_config
from app.db import Document, DocumentType
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
async def test_legacy_composio_gmail_doc_migrated_in_db(
db_session, db_search_space, db_user, make_connector_document
):
"""A Composio Gmail doc in the DB gets its hash and type updated to native."""
space_id = db_search_space.id
user_id = str(db_user.id)
unique_id = "msg-legacy-123"
legacy_hash = compute_identifier_hash(
DocumentType.COMPOSIO_GMAIL_CONNECTOR.value, unique_id, space_id
)
native_hash = compute_identifier_hash(
DocumentType.GOOGLE_GMAIL_CONNECTOR.value, unique_id, space_id
)
legacy_doc = Document(
title="Old Gmail",
document_type=DocumentType.COMPOSIO_GMAIL_CONNECTOR,
content="legacy content",
content_hash=f"ch-{legacy_hash[:12]}",
unique_identifier_hash=legacy_hash,
search_space_id=space_id,
created_by_id=user_id,
embedding=[0.1] * _EMBEDDING_DIM,
status={"state": "ready"},
)
db_session.add(legacy_doc)
await db_session.flush()
doc_id = legacy_doc.id
connector_doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id=unique_id,
search_space_id=space_id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(select(Document).filter(Document.id == doc_id))
reloaded = result.scalars().first()
assert reloaded.unique_identifier_hash == native_hash
assert reloaded.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR
async def test_no_legacy_doc_is_noop(
db_session, db_search_space, make_connector_document
):
"""When no legacy document exists, migrate_legacy_docs does nothing."""
connector_doc = make_connector_document(
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
unique_id="evt-no-legacy",
search_space_id=db_search_space.id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(
select(Document).filter(Document.search_space_id == db_search_space.id)
)
assert result.scalars().all() == []
async def test_non_google_type_is_skipped(
db_session, db_search_space, make_connector_document
):
"""migrate_legacy_docs skips ConnectorDocuments that are not Google types."""
connector_doc = make_connector_document(
document_type=DocumentType.CLICKUP_CONNECTOR,
unique_id="task-1",
search_space_id=db_search_space.id,
)
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])

View file

@ -0,0 +1,34 @@
"""Pre-register the connector_indexers package to bypass a circular import
in its ``__init__.py`` (airtable_indexer -> routes -> connector_indexers).
This lets tests import individual indexer modules (e.g.
``google_drive_indexer``) without triggering the full package init.
"""
import sys
import types
from pathlib import Path
_BACKEND = Path(__file__).resolve().parents[3]
def _stub_package(dotted: str, fs_dir: Path) -> None:
if dotted not in sys.modules:
mod = types.ModuleType(dotted)
mod.__path__ = [str(fs_dir)]
mod.__package__ = dotted
sys.modules[dotted] = mod
parts = dotted.split(".")
if len(parts) > 1:
parent_dotted = ".".join(parts[:-1])
parent = sys.modules.get(parent_dotted)
if parent is not None:
setattr(parent, parts[-1], sys.modules[dotted])
_stub_package("app.tasks", _BACKEND / "app" / "tasks")
_stub_package(
"app.tasks.connector_indexers",
_BACKEND / "app" / "tasks" / "connector_indexers",
)

View file

@ -0,0 +1,373 @@
"""Tests for Confluence indexer migrated to the unified parallel pipeline."""
from unittest.mock import AsyncMock, MagicMock
import pytest
import app.tasks.connector_indexers.confluence_indexer as _mod
from app.db import DocumentType
from app.tasks.connector_indexers.confluence_indexer import (
_build_connector_doc,
index_confluence_pages,
)
pytestmark = pytest.mark.unit
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_page(
page_id: str = "p1",
title: str = "Home",
space_id: str = "S1",
body: str = "<p>Hello</p>",
comments=None,
):
return {
"id": page_id,
"title": title,
"spaceId": space_id,
"body": {"storage": {"value": body}},
"comments": comments or [],
}
def _to_markdown(page: dict) -> str:
page_title = page.get("title", "")
page_content = page.get("body", {}).get("storage", {}).get("value", "")
comments = page.get("comments", [])
comments_content = ""
if comments:
comments_content = "\n\n## Comments\n\n"
for comment in comments:
comment_body = (
comment.get("body", {}).get("storage", {}).get("value", "")
)
comment_author = comment.get("version", {}).get("authorId", "Unknown")
comment_date = comment.get("version", {}).get("createdAt", "")
comments_content += (
f"**Comment by {comment_author}** ({comment_date}):\n"
f"{comment_body}\n\n"
)
return f"# {page_title}\n\n{page_content}{comments_content}"
# ---------------------------------------------------------------------------
# Slice 1: _build_connector_doc tracer bullet
# ---------------------------------------------------------------------------
async def test_build_connector_doc_produces_correct_fields():
page = _make_page(
page_id="abc-123",
title="Engineering Handbook",
space_id="ENG",
comments=[{"id": "c1"}],
)
markdown = _to_markdown(page)
doc = _build_connector_doc(
page,
markdown,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert doc.title == "Engineering Handbook"
assert doc.unique_id == "abc-123"
assert doc.document_type == DocumentType.CONFLUENCE_CONNECTOR
assert doc.source_markdown == markdown
assert doc.search_space_id == _SEARCH_SPACE_ID
assert doc.connector_id == _CONNECTOR_ID
assert doc.created_by_id == _USER_ID
assert doc.should_summarize is True
assert doc.metadata["page_id"] == "abc-123"
assert doc.metadata["page_title"] == "Engineering Handbook"
assert doc.metadata["space_id"] == "ENG"
assert doc.metadata["comment_count"] == 1
assert doc.metadata["connector_id"] == _CONNECTOR_ID
assert doc.metadata["document_type"] == "Confluence Page"
assert doc.metadata["connector_type"] == "Confluence"
assert doc.fallback_summary is not None
assert "Engineering Handbook" in doc.fallback_summary
assert markdown in doc.fallback_summary
async def test_build_connector_doc_summary_disabled():
doc = _build_connector_doc(
_make_page(),
_to_markdown(_make_page()),
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=False,
)
assert doc.should_summarize is False
# ---------------------------------------------------------------------------
# Shared fixtures for Slices 2-7
# ---------------------------------------------------------------------------
def _mock_connector(enable_summary: bool = True):
c = MagicMock()
c.config = {"access_token": "tok"}
c.enable_summary = enable_summary
c.last_indexed_at = None
return c
def _mock_confluence_client(pages=None, error=None):
client = MagicMock()
client.get_pages_by_date_range = AsyncMock(
return_value=(pages if pages is not None else [], error),
)
client.close = AsyncMock()
return client
@pytest.fixture
def confluence_mocks(monkeypatch):
mock_session = AsyncMock()
mock_session.no_autoflush = MagicMock()
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
)
confluence_client = _mock_confluence_client(pages=[_make_page()])
monkeypatch.setattr(
_mod, "ConfluenceHistoryConnector", MagicMock(return_value=confluence_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
mock_task_logger = MagicMock()
mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock())
mock_task_logger.log_task_progress = AsyncMock()
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
)
return {
"session": mock_session,
"connector": mock_connector,
"confluence_client": confluence_client,
"task_logger": mock_task_logger,
"pipeline_mock": pipeline_mock,
"batch_mock": batch_mock,
}
async def _run_index(mocks, **overrides):
return await index_confluence_pages(
session=mocks["session"],
connector_id=overrides.get("connector_id", _CONNECTOR_ID),
search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID),
user_id=overrides.get("user_id", _USER_ID),
start_date=overrides.get("start_date", "2025-01-01"),
end_date=overrides.get("end_date", "2025-12-31"),
update_last_indexed=overrides.get("update_last_indexed", True),
on_heartbeat_callback=overrides.get("on_heartbeat_callback"),
)
# ---------------------------------------------------------------------------
# Slice 2: Full pipeline wiring
# ---------------------------------------------------------------------------
async def test_one_page_calls_pipeline_and_returns_indexed_count(confluence_mocks):
indexed, skipped, warning = await _run_index(confluence_mocks)
assert indexed == 1
assert skipped == 0
assert warning is None
confluence_mocks["batch_mock"].assert_called_once()
connector_docs = confluence_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].document_type == DocumentType.CONFLUENCE_CONNECTOR
async def test_pipeline_called_with_max_concurrency_3(confluence_mocks):
await _run_index(confluence_mocks)
call_kwargs = confluence_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("max_concurrency") == 3
async def test_migrate_legacy_docs_called_before_indexing(confluence_mocks):
await _run_index(confluence_mocks)
confluence_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once()
# ---------------------------------------------------------------------------
# Slice 3: Page skipping (missing id/title/content)
# ---------------------------------------------------------------------------
async def test_pages_with_missing_id_are_skipped(confluence_mocks):
pages = [
_make_page(page_id="p1", title="Valid"),
_make_page(page_id="", title="Missing id"),
]
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
pages,
None,
)
_, skipped, _ = await _run_index(confluence_mocks)
connector_docs = confluence_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
async def test_pages_with_missing_title_are_skipped(confluence_mocks):
pages = [
_make_page(page_id="p1", title="Valid"),
_make_page(page_id="p2", title=""),
]
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
pages,
None,
)
_, skipped, _ = await _run_index(confluence_mocks)
connector_docs = confluence_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
async def test_pages_with_no_content_are_skipped(confluence_mocks):
pages = [
_make_page(page_id="p1", title="Valid", body="<p>ok</p>"),
_make_page(page_id="p2", title="Empty", body=""),
]
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
pages,
None,
)
_, skipped, _ = await _run_index(confluence_mocks)
connector_docs = confluence_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 4: Duplicate content skipping
# ---------------------------------------------------------------------------
async def test_duplicate_content_pages_are_skipped(confluence_mocks, monkeypatch):
pages = [
_make_page(page_id="p1", title="One"),
_make_page(page_id="p2", title="Two"),
]
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
pages,
None,
)
call_count = 0
async def _check_dup(session, content_hash):
nonlocal call_count
call_count += 1
if call_count == 2:
dup = MagicMock()
dup.id = 99
dup.document_type = "OTHER"
return dup
return None
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
_, skipped, _ = await _run_index(confluence_mocks)
connector_docs = confluence_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 5: Heartbeat callback forwarding
# ---------------------------------------------------------------------------
async def test_heartbeat_callback_forwarded_to_pipeline(confluence_mocks):
heartbeat_cb = AsyncMock()
await _run_index(confluence_mocks, on_heartbeat_callback=heartbeat_cb)
call_kwargs = confluence_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("on_heartbeat") is heartbeat_cb
# ---------------------------------------------------------------------------
# Slice 6: Empty pages and no-data success tuple
# ---------------------------------------------------------------------------
async def test_empty_pages_returns_zero_tuple(confluence_mocks):
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
[],
None,
)
indexed, skipped, warning = await _run_index(confluence_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
confluence_mocks["batch_mock"].assert_not_called()
async def test_no_pages_error_message_returns_success_tuple(confluence_mocks):
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
[],
"No pages found in date range",
)
indexed, skipped, warning = await _run_index(confluence_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
async def test_api_error_still_returns_3_tuple(confluence_mocks):
confluence_mocks["confluence_client"].get_pages_by_date_range.return_value = (
[],
"API exploded",
)
result = await _run_index(confluence_mocks)
assert len(result) == 3
assert result[0] == 0
assert result[1] == 0
assert "Failed to get Confluence pages" in result[2]
# ---------------------------------------------------------------------------
# Slice 7: Failed docs warning
# ---------------------------------------------------------------------------
async def test_failed_docs_warning_in_result(confluence_mocks):
confluence_mocks["batch_mock"].return_value = ([], 0, 2)
_, _, warning = await _run_index(confluence_mocks)
assert warning is not None
assert "2 failed" in warning

View file

@ -0,0 +1,671 @@
"""Tests for parallel download + indexing in the Google Drive indexer."""
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.tasks.connector_indexers.google_drive_indexer import (
_download_files_parallel,
_index_full_scan,
_index_selected_files,
_index_with_delta_sync,
)
pytestmark = pytest.mark.unit
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_file_dict(file_id: str, name: str, mime: str = "text/plain") -> dict:
return {"id": file_id, "name": name, "mimeType": mime}
def _mock_extract_ok(file_id: str, file_name: str):
"""Return a successful (markdown, metadata, None) tuple."""
return (
f"# Content of {file_name}",
{"google_drive_file_id": file_id, "google_drive_file_name": file_name},
None,
)
@pytest.fixture
def mock_drive_client():
return MagicMock()
@pytest.fixture
def patch_extract(monkeypatch):
"""Provide a helper to set the download_and_extract_content mock."""
def _patch(side_effect=None, return_value=None):
mock = AsyncMock(side_effect=side_effect, return_value=return_value)
monkeypatch.setattr(
"app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content",
mock,
)
return mock
return _patch
async def test_single_file_returns_one_connector_document(
mock_drive_client, patch_extract,
):
"""Tracer bullet: downloading one file produces one ConnectorDocument."""
patch_extract(return_value=_mock_extract_ok("f1", "test.txt"))
docs, failed = await _download_files_parallel(
mock_drive_client,
[_make_file_dict("f1", "test.txt")],
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert len(docs) == 1
assert failed == 0
assert docs[0].title == "test.txt"
assert docs[0].unique_id == "f1"
async def test_multiple_files_all_produce_documents(
mock_drive_client, patch_extract,
):
"""All files are downloaded and converted to ConnectorDocuments."""
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
patch_extract(
side_effect=[_mock_extract_ok(f"f{i}", f"file{i}.txt") for i in range(3)]
)
docs, failed = await _download_files_parallel(
mock_drive_client,
files,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert len(docs) == 3
assert failed == 0
assert {d.unique_id for d in docs} == {"f0", "f1", "f2"}
async def test_one_download_exception_does_not_block_others(
mock_drive_client, patch_extract,
):
"""A RuntimeError in one download still lets the other files succeed."""
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
patch_extract(
side_effect=[
_mock_extract_ok("f0", "file0.txt"),
RuntimeError("network timeout"),
_mock_extract_ok("f2", "file2.txt"),
]
)
docs, failed = await _download_files_parallel(
mock_drive_client,
files,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert len(docs) == 2
assert failed == 1
assert {d.unique_id for d in docs} == {"f0", "f2"}
async def test_etl_error_counts_as_download_failure(
mock_drive_client, patch_extract,
):
"""download_and_extract_content returning an error is counted as failed."""
files = [_make_file_dict("f0", "good.txt"), _make_file_dict("f1", "bad.txt")]
patch_extract(
side_effect=[
_mock_extract_ok("f0", "good.txt"),
(None, {}, "ETL failed"),
]
)
docs, failed = await _download_files_parallel(
mock_drive_client,
files,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert len(docs) == 1
assert failed == 1
async def test_concurrency_bounded_by_semaphore(
mock_drive_client, monkeypatch,
):
"""Peak concurrent downloads never exceeds max_concurrency."""
lock = asyncio.Lock()
active = 0
peak = 0
async def _slow_extract(client, file):
nonlocal active, peak
async with lock:
active += 1
peak = max(peak, active)
await asyncio.sleep(0.05)
async with lock:
active -= 1
fid = file["id"]
return _mock_extract_ok(fid, file["name"])
monkeypatch.setattr(
"app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content",
_slow_extract,
)
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(6)]
docs, failed = await _download_files_parallel(
mock_drive_client,
files,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
max_concurrency=2,
)
assert len(docs) == 6
assert failed == 0
assert peak <= 2, f"Peak concurrency was {peak}, expected <= 2"
async def test_heartbeat_fires_during_parallel_downloads(
mock_drive_client, monkeypatch,
):
"""on_heartbeat is called at least once when downloads take time."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
monkeypatch.setattr(_mod, "HEARTBEAT_INTERVAL_SECONDS", 0)
async def _slow_extract(client, file):
await asyncio.sleep(0.05)
return _mock_extract_ok(file["id"], file["name"])
monkeypatch.setattr(
"app.tasks.connector_indexers.google_drive_indexer.download_and_extract_content",
_slow_extract,
)
heartbeat_calls: list[int] = []
async def _on_heartbeat(count: int):
heartbeat_calls.append(count)
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
docs, failed = await _download_files_parallel(
mock_drive_client,
files,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
on_heartbeat=_on_heartbeat,
)
assert len(docs) == 3
assert failed == 0
assert len(heartbeat_calls) >= 1, "Heartbeat should have fired at least once"
# ---------------------------------------------------------------------------
# Slice 6, 6b, 6c -- _index_full_scan three-phase pipeline
# ---------------------------------------------------------------------------
def _folder_dict(file_id: str, name: str) -> dict:
return {"id": file_id, "name": name, "mimeType": "application/vnd.google-apps.folder"}
@pytest.fixture
def full_scan_mocks(mock_drive_client, monkeypatch):
"""Wire up all mocks needed to call _index_full_scan in isolation."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
mock_session = AsyncMock()
mock_connector = MagicMock()
mock_task_logger = MagicMock()
mock_task_logger.log_task_progress = AsyncMock()
mock_log_entry = MagicMock()
skip_results: dict[str, tuple[bool, str | None]] = {}
async def _fake_skip(session, file, search_space_id):
return skip_results.get(file["id"], (False, None))
monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip)
download_mock = AsyncMock(return_value=([], 0))
monkeypatch.setattr(_mod, "_download_files_parallel", download_mock)
batch_mock = AsyncMock(return_value=([], 0, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
)
monkeypatch.setattr(
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()),
)
return {
"drive_client": mock_drive_client,
"session": mock_session,
"connector": mock_connector,
"task_logger": mock_task_logger,
"log_entry": mock_log_entry,
"skip_results": skip_results,
"download_mock": download_mock,
"batch_mock": batch_mock,
"pipeline_mock": pipeline_mock,
}
async def _run_full_scan(mocks, *, max_files=500, include_subfolders=False):
return await _index_full_scan(
mocks["drive_client"],
mocks["session"],
mocks["connector"],
_CONNECTOR_ID,
_SEARCH_SPACE_ID,
_USER_ID,
"folder-root",
"My Folder",
mocks["task_logger"],
mocks["log_entry"],
max_files,
include_subfolders=include_subfolders,
enable_summary=True,
)
async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch):
"""Full scan collects files serially, downloads and indexes in parallel,
and returns correct (indexed, skipped) with renames counted as indexed."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
page_files = [
_folder_dict("folder1", "SubFolder"),
_make_file_dict("skip1", "unchanged.txt"),
_make_file_dict("rename1", "renamed.txt"),
_make_file_dict("new1", "new1.txt"),
_make_file_dict("new2", "new2.txt"),
]
monkeypatch.setattr(
_mod, "get_files_in_folder",
AsyncMock(return_value=(page_files, None, None)),
)
full_scan_mocks["skip_results"]["skip1"] = (True, "unchanged")
full_scan_mocks["skip_results"]["rename1"] = (True, "File renamed: 'old''renamed.txt'")
mock_docs = [MagicMock(), MagicMock()]
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
full_scan_mocks["batch_mock"].return_value = ([], 2, 0)
indexed, skipped = await _run_full_scan(full_scan_mocks)
assert indexed == 3 # 1 renamed + 2 from batch
assert skipped == 1 # 1 unchanged
full_scan_mocks["download_mock"].assert_called_once()
call_files = full_scan_mocks["download_mock"].call_args[0][1]
assert len(call_files) == 2
assert {f["id"] for f in call_files} == {"new1", "new2"}
async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
"""Only max_files non-folder files are processed; the rest are ignored."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
page_files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(10)]
monkeypatch.setattr(
_mod, "get_files_in_folder",
AsyncMock(return_value=(page_files, None, None)),
)
full_scan_mocks["download_mock"].return_value = ([], 0)
full_scan_mocks["batch_mock"].return_value = ([], 0, 0)
await _run_full_scan(full_scan_mocks, max_files=3)
download_call_files = full_scan_mocks["download_mock"].call_args[0][1]
assert len(download_call_files) == 3
async def test_full_scan_uses_max_concurrency_3_for_indexing(
full_scan_mocks, monkeypatch,
):
"""index_batch_parallel is called with max_concurrency=3."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
page_files = [_make_file_dict("f1", "file1.txt")]
monkeypatch.setattr(
_mod, "get_files_in_folder",
AsyncMock(return_value=(page_files, None, None)),
)
mock_docs = [MagicMock()]
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
full_scan_mocks["batch_mock"].return_value = ([], 1, 0)
await _run_full_scan(full_scan_mocks)
call_kwargs = full_scan_mocks["batch_mock"].call_args
assert call_kwargs[1].get("max_concurrency") == 3 or (
len(call_kwargs[0]) > 2 and call_kwargs[0][2] == 3
)
# ---------------------------------------------------------------------------
# Slice 7 -- _index_with_delta_sync three-phase pipeline
# ---------------------------------------------------------------------------
async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
"""Removed/trashed changes call _remove_document; the rest go through
_download_files_parallel and index_batch_parallel."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
changes = [
{"fileId": "del1", "removed": True},
{"fileId": "del2", "file": {"id": "del2", "trashed": True}},
{"fileId": "trash1", "file": {"id": "trash1", "trashed": True}},
{"fileId": "mod1", "file": _make_file_dict("mod1", "modified1.txt")},
{"fileId": "mod2", "file": _make_file_dict("mod2", "modified2.txt")},
]
monkeypatch.setattr(
_mod, "fetch_all_changes",
AsyncMock(return_value=(changes, "new-token", None)),
)
change_types = {
"del1": "removed",
"del2": "removed",
"trash1": "trashed",
"mod1": "modified",
"mod2": "modified",
}
monkeypatch.setattr(
_mod, "categorize_change",
lambda change: change_types[change["fileId"]],
)
remove_calls: list[str] = []
async def _fake_remove(session, file_id, search_space_id):
remove_calls.append(file_id)
monkeypatch.setattr(_mod, "_remove_document", _fake_remove)
monkeypatch.setattr(
_mod, "_should_skip_file",
AsyncMock(return_value=(False, None)),
)
mock_docs = [MagicMock(), MagicMock()]
download_mock = AsyncMock(return_value=(mock_docs, 0))
monkeypatch.setattr(_mod, "_download_files_parallel", download_mock)
batch_mock = AsyncMock(return_value=([], 2, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
)
monkeypatch.setattr(
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()),
)
mock_session = AsyncMock()
mock_task_logger = MagicMock()
mock_task_logger.log_task_progress = AsyncMock()
indexed, skipped = await _index_with_delta_sync(
MagicMock(),
mock_session,
MagicMock(),
_CONNECTOR_ID,
_SEARCH_SPACE_ID,
_USER_ID,
"folder-root",
"start-token-abc",
mock_task_logger,
MagicMock(),
max_files=500,
enable_summary=True,
)
assert sorted(remove_calls) == ["del1", "del2", "trash1"]
download_mock.assert_called_once()
downloaded_files = download_mock.call_args[0][1]
assert len(downloaded_files) == 2
assert {f["id"] for f in downloaded_files} == {"mod1", "mod2"}
assert indexed == 2
assert skipped == 0
# ---------------------------------------------------------------------------
# _index_selected_files -- parallel indexing of user-selected files
# ---------------------------------------------------------------------------
@pytest.fixture
def selected_files_mocks(mock_drive_client, monkeypatch):
"""Wire up mocks for _index_selected_files tests."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
mock_session = AsyncMock()
get_file_results: dict[str, tuple[dict | None, str | None]] = {}
async def _fake_get_file(client, file_id):
return get_file_results.get(file_id, (None, f"Not configured: {file_id}"))
monkeypatch.setattr(_mod, "get_file_by_id", _fake_get_file)
skip_results: dict[str, tuple[bool, str | None]] = {}
async def _fake_skip(session, file, search_space_id):
return skip_results.get(file["id"], (False, None))
monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip)
download_and_index_mock = AsyncMock(return_value=(0, 0))
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
return {
"drive_client": mock_drive_client,
"session": mock_session,
"get_file_results": get_file_results,
"skip_results": skip_results,
"download_and_index_mock": download_and_index_mock,
}
async def _run_selected(mocks, file_ids):
return await _index_selected_files(
mocks["drive_client"],
mocks["session"],
file_ids,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
async def test_selected_files_single_file_indexed(selected_files_mocks):
"""Tracer bullet: one file fetched, not skipped, indexed via parallel pipeline."""
selected_files_mocks["get_file_results"]["f1"] = (
_make_file_dict("f1", "report.pdf"),
None,
)
selected_files_mocks["download_and_index_mock"].return_value = (1, 0)
indexed, skipped, errors = await _run_selected(
selected_files_mocks, [("f1", "report.pdf")],
)
assert indexed == 1
assert skipped == 0
assert errors == []
selected_files_mocks["download_and_index_mock"].assert_called_once()
async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
"""get_file_by_id failing for one file collects an error; others still indexed."""
selected_files_mocks["get_file_results"]["f1"] = (
_make_file_dict("f1", "first.txt"), None,
)
selected_files_mocks["get_file_results"]["f2"] = (None, "HTTP 404")
selected_files_mocks["get_file_results"]["f3"] = (
_make_file_dict("f3", "third.txt"), None,
)
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
indexed, skipped, errors = await _run_selected(
selected_files_mocks,
[("f1", "first.txt"), ("f2", "mid.txt"), ("f3", "third.txt")],
)
assert indexed == 2
assert skipped == 0
assert len(errors) == 1
assert "mid.txt" in errors[0]
assert "HTTP 404" in errors[0]
async def test_selected_files_skip_rename_counting(selected_files_mocks):
"""Unchanged files are skipped, renames counted as indexed,
and only new files are sent to _download_and_index."""
for fid, fname in [("s1", "unchanged.txt"), ("r1", "renamed.txt"),
("n1", "new1.txt"), ("n2", "new2.txt")]:
selected_files_mocks["get_file_results"][fid] = (
_make_file_dict(fid, fname), None,
)
selected_files_mocks["skip_results"]["s1"] = (True, "unchanged")
selected_files_mocks["skip_results"]["r1"] = (True, "File renamed: 'old' \u2192 'renamed.txt'")
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
indexed, skipped, errors = await _run_selected(
selected_files_mocks,
[("s1", "unchanged.txt"), ("r1", "renamed.txt"),
("n1", "new1.txt"), ("n2", "new2.txt")],
)
assert indexed == 3 # 1 renamed + 2 batch
assert skipped == 1 # 1 unchanged
assert errors == []
mock = selected_files_mocks["download_and_index_mock"]
mock.assert_called_once()
call_files = mock.call_args[1].get("files") if "files" in (mock.call_args[1] or {}) else mock.call_args[0][2]
assert len(call_files) == 2
assert {f["id"] for f in call_files} == {"n1", "n2"}
# ---------------------------------------------------------------------------
# asyncio.to_thread verification — prove blocking calls run in parallel
# ---------------------------------------------------------------------------
async def test_client_download_file_runs_in_thread_parallel():
"""Calling download_file concurrently via asyncio.gather should overlap
blocking work on separate threads, proving to_thread is effective.
Strategy: patch _sync_download_file with a blocking time.sleep(0.2).
Launch 3 concurrent calls. Serial would take >=0.6s; parallel < 0.4s.
"""
from app.connectors.google_drive.client import GoogleDriveClient
BLOCK_SECONDS = 0.2
NUM_CALLS = 3
def _blocking_download(service, file_id, credentials):
time.sleep(BLOCK_SECONDS)
return b"fake-content", None
client = GoogleDriveClient.__new__(GoogleDriveClient)
client.service = MagicMock()
client._resolved_credentials = MagicMock()
client._service_lock = asyncio.Lock()
with patch.object(
GoogleDriveClient, "_sync_download_file", staticmethod(_blocking_download),
):
start = time.monotonic()
results = await asyncio.gather(
*(client.download_file(f"file-{i}") for i in range(NUM_CALLS))
)
elapsed = time.monotonic() - start
for content, error in results:
assert content == b"fake-content"
assert error is None
serial_minimum = BLOCK_SECONDS * NUM_CALLS
assert elapsed < serial_minimum, (
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
f"downloads are not running in parallel"
)
async def test_client_export_google_file_runs_in_thread_parallel():
"""Same strategy for export_google_file — verify to_thread parallelism."""
from app.connectors.google_drive.client import GoogleDriveClient
BLOCK_SECONDS = 0.2
NUM_CALLS = 3
def _blocking_export(service, file_id, mime_type, credentials):
time.sleep(BLOCK_SECONDS)
return b"exported", None
client = GoogleDriveClient.__new__(GoogleDriveClient)
client.service = MagicMock()
client._resolved_credentials = MagicMock()
client._service_lock = asyncio.Lock()
with patch.object(
GoogleDriveClient, "_sync_export_google_file", staticmethod(_blocking_export),
):
start = time.monotonic()
results = await asyncio.gather(
*(client.export_google_file(f"file-{i}", "application/pdf")
for i in range(NUM_CALLS))
)
elapsed = time.monotonic() - start
for content, error in results:
assert content == b"exported"
assert error is None
serial_minimum = BLOCK_SECONDS * NUM_CALLS
assert elapsed < serial_minimum, (
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
f"exports are not running in parallel"
)

View file

@ -0,0 +1,372 @@
"""Tests for Jira indexer migrated to the unified parallel pipeline."""
from unittest.mock import AsyncMock, MagicMock
import pytest
import app.tasks.connector_indexers.jira_indexer as _mod
from app.db import DocumentType
from app.tasks.connector_indexers.jira_indexer import (
_build_connector_doc,
index_jira_issues,
)
pytestmark = pytest.mark.unit
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_issue(
issue_key: str = "ENG-1",
issue_id: str = "10001",
title: str = "Fix login",
):
return {"key": issue_key, "id": issue_id, "title": title}
def _make_formatted_issue(
issue_key: str = "ENG-1",
issue_id: str = "10001",
title: str = "Fix login",
status: str = "In Progress",
priority: str = "High",
comments=None,
):
return {
"key": issue_key,
"id": issue_id,
"title": title,
"status": status,
"priority": priority,
"comments": comments or [],
}
# ---------------------------------------------------------------------------
# Slice 1: _build_connector_doc tracer bullet
# ---------------------------------------------------------------------------
async def test_build_connector_doc_produces_correct_fields():
issue = _make_issue(issue_key="ENG-42", issue_id="4242", title="Fix auth bug")
formatted = _make_formatted_issue(
issue_key="ENG-42",
issue_id="4242",
title="Fix auth bug",
status="Done",
priority="Urgent",
comments=[{"id": "c1"}],
)
markdown = "# ENG-42: Fix auth bug\n\nBody"
doc = _build_connector_doc(
issue,
formatted,
markdown,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert doc.title == "ENG-42: 4242"
assert doc.unique_id == "ENG-42"
assert doc.document_type == DocumentType.JIRA_CONNECTOR
assert doc.source_markdown == markdown
assert doc.search_space_id == _SEARCH_SPACE_ID
assert doc.connector_id == _CONNECTOR_ID
assert doc.created_by_id == _USER_ID
assert doc.should_summarize is True
assert doc.metadata["issue_id"] == "ENG-42"
assert doc.metadata["issue_identifier"] == "ENG-42"
assert doc.metadata["issue_title"] == "4242"
assert doc.metadata["state"] == "Done"
assert doc.metadata["priority"] == "Urgent"
assert doc.metadata["comment_count"] == 1
assert doc.metadata["connector_id"] == _CONNECTOR_ID
assert doc.metadata["document_type"] == "Jira Issue"
assert doc.metadata["connector_type"] == "Jira"
assert doc.fallback_summary is not None
assert "ENG-42" in doc.fallback_summary
assert markdown in doc.fallback_summary
async def test_build_connector_doc_summary_disabled():
doc = _build_connector_doc(
_make_issue(),
_make_formatted_issue(),
"# content",
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=False,
)
assert doc.should_summarize is False
# ---------------------------------------------------------------------------
# Shared fixtures for Slices 2-7
# ---------------------------------------------------------------------------
def _mock_connector(enable_summary: bool = True):
c = MagicMock()
c.config = {"access_token": "tok"}
c.enable_summary = enable_summary
c.last_indexed_at = None
return c
def _mock_jira_client(issues=None, error=None):
client = MagicMock()
client.get_issues_by_date_range = AsyncMock(
return_value=(issues if issues is not None else [], error),
)
client.format_issue = MagicMock(
side_effect=lambda i: _make_formatted_issue(
issue_key=i.get("key", ""),
issue_id=i.get("id", ""),
title=i.get("title", ""),
)
)
client.format_issue_to_markdown = MagicMock(
side_effect=lambda fi: f"# {fi.get('key', '')}: {fi.get('id', '')}\n\nContent"
)
client.close = AsyncMock()
return client
@pytest.fixture
def jira_mocks(monkeypatch):
mock_session = AsyncMock()
mock_session.no_autoflush = MagicMock()
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
)
jira_client = _mock_jira_client(issues=[_make_issue()])
monkeypatch.setattr(
_mod, "JiraHistoryConnector", MagicMock(return_value=jira_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
mock_task_logger = MagicMock()
mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock())
mock_task_logger.log_task_progress = AsyncMock()
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
)
return {
"session": mock_session,
"connector": mock_connector,
"jira_client": jira_client,
"task_logger": mock_task_logger,
"pipeline_mock": pipeline_mock,
"batch_mock": batch_mock,
}
async def _run_index(mocks, **overrides):
return await index_jira_issues(
session=mocks["session"],
connector_id=overrides.get("connector_id", _CONNECTOR_ID),
search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID),
user_id=overrides.get("user_id", _USER_ID),
start_date=overrides.get("start_date", "2025-01-01"),
end_date=overrides.get("end_date", "2025-12-31"),
update_last_indexed=overrides.get("update_last_indexed", True),
on_heartbeat_callback=overrides.get("on_heartbeat_callback"),
)
# ---------------------------------------------------------------------------
# Slice 2: Full pipeline wiring
# ---------------------------------------------------------------------------
async def test_one_issue_calls_pipeline_and_returns_indexed_count(jira_mocks):
indexed, skipped, warning = await _run_index(jira_mocks)
assert indexed == 1
assert skipped == 0
assert warning is None
jira_mocks["batch_mock"].assert_called_once()
connector_docs = jira_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].document_type == DocumentType.JIRA_CONNECTOR
async def test_pipeline_called_with_max_concurrency_3(jira_mocks):
await _run_index(jira_mocks)
call_kwargs = jira_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("max_concurrency") == 3
async def test_migrate_legacy_docs_called_before_indexing(jira_mocks):
await _run_index(jira_mocks)
jira_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once()
# ---------------------------------------------------------------------------
# Slice 3: Issue skipping (missing key/title/content)
# ---------------------------------------------------------------------------
async def test_issues_with_missing_key_are_skipped(jira_mocks):
issues = [
_make_issue(issue_key="ENG-1", issue_id="10001"),
{"key": "", "id": "10002", "title": "No key"},
]
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None)
_, skipped, _ = await _run_index(jira_mocks)
connector_docs = jira_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
async def test_issues_with_missing_title_are_skipped(jira_mocks):
issues = [
_make_issue(issue_key="ENG-1", issue_id="10001"),
{"key": "ENG-2", "id": "", "title": "Missing id used as title"},
]
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None)
_, skipped, _ = await _run_index(jira_mocks)
connector_docs = jira_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
async def test_issues_with_no_content_are_skipped(jira_mocks):
issues = [
_make_issue(issue_key="ENG-1", issue_id="10001"),
_make_issue(issue_key="ENG-2", issue_id="10002"),
]
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None)
jira_mocks["jira_client"].format_issue_to_markdown.side_effect = [
"# ENG-1: 10001\n\nContent",
"",
]
_, skipped, _ = await _run_index(jira_mocks)
connector_docs = jira_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 4: Duplicate content skipping
# ---------------------------------------------------------------------------
async def test_duplicate_content_issues_are_skipped(jira_mocks, monkeypatch):
issues = [
_make_issue(issue_key="ENG-1", issue_id="10001"),
_make_issue(issue_key="ENG-2", issue_id="10002"),
]
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (issues, None)
call_count = 0
async def _check_dup(session, content_hash):
nonlocal call_count
call_count += 1
if call_count == 2:
dup = MagicMock()
dup.id = 99
dup.document_type = "OTHER"
return dup
return None
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
_, skipped, _ = await _run_index(jira_mocks)
connector_docs = jira_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 5: Heartbeat callback forwarding
# ---------------------------------------------------------------------------
async def test_heartbeat_callback_forwarded_to_pipeline(jira_mocks):
heartbeat_cb = AsyncMock()
await _run_index(jira_mocks, on_heartbeat_callback=heartbeat_cb)
call_kwargs = jira_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("on_heartbeat") is heartbeat_cb
# ---------------------------------------------------------------------------
# Slice 6: Empty issues and no-data success tuple
# ---------------------------------------------------------------------------
async def test_empty_issues_returns_zero_tuple(jira_mocks):
jira_mocks["jira_client"].get_issues_by_date_range.return_value = ([], None)
indexed, skipped, warning = await _run_index(jira_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
jira_mocks["batch_mock"].assert_not_called()
async def test_no_issues_error_message_returns_success_tuple(jira_mocks):
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (
[],
"No issues found in date range",
)
indexed, skipped, warning = await _run_index(jira_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
async def test_api_error_still_returns_3_tuple(jira_mocks):
jira_mocks["jira_client"].get_issues_by_date_range.return_value = (
[],
"API exploded",
)
result = await _run_index(jira_mocks)
assert len(result) == 3
assert result[0] == 0
assert result[1] == 0
assert "Failed to get Jira issues" in result[2]
# ---------------------------------------------------------------------------
# Slice 7: Failed docs warning
# ---------------------------------------------------------------------------
async def test_failed_docs_warning_in_result(jira_mocks):
jira_mocks["batch_mock"].return_value = ([], 0, 2)
_, _, warning = await _run_index(jira_mocks)
assert warning is not None
assert "2 failed" in warning

View file

@ -0,0 +1,355 @@
"""Tests for Linear indexer migrated to the unified parallel pipeline."""
from unittest.mock import AsyncMock, MagicMock
import pytest
import app.tasks.connector_indexers.linear_indexer as _mod
from app.db import DocumentType
from app.tasks.connector_indexers.linear_indexer import (
_build_connector_doc,
index_linear_issues,
)
pytestmark = pytest.mark.unit
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_issue(
issue_id: str = "issue-1",
identifier: str = "ENG-1",
title: str = "Fix bug",
):
return {"id": issue_id, "identifier": identifier, "title": title}
def _make_formatted_issue(
issue_id: str = "issue-1",
identifier: str = "ENG-1",
title: str = "Fix bug",
state: str = "In Progress",
priority: str = "High",
comments=None,
):
return {
"id": issue_id,
"identifier": identifier,
"title": title,
"state": state,
"priority": priority,
"description": "Some description",
"comments": comments or [],
}
# ---------------------------------------------------------------------------
# Slice 1: _build_connector_doc tracer bullet
# ---------------------------------------------------------------------------
async def test_build_connector_doc_produces_correct_fields():
"""Tracer bullet: a Linear issue produces a ConnectorDocument with correct fields."""
issue = _make_issue(issue_id="abc-123", identifier="ENG-42", title="Fix login bug")
formatted = _make_formatted_issue(
issue_id="abc-123",
identifier="ENG-42",
title="Fix login bug",
state="Done",
priority="Urgent",
comments=[{"id": "c1"}],
)
markdown = "# ENG-42: Fix login bug\n\nDescription here"
doc = _build_connector_doc(
issue,
formatted,
markdown,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert doc.title == "ENG-42: Fix login bug"
assert doc.unique_id == "abc-123"
assert doc.document_type == DocumentType.LINEAR_CONNECTOR
assert doc.source_markdown == markdown
assert doc.search_space_id == _SEARCH_SPACE_ID
assert doc.connector_id == _CONNECTOR_ID
assert doc.created_by_id == _USER_ID
assert doc.should_summarize is True
assert doc.metadata["issue_id"] == "abc-123"
assert doc.metadata["issue_identifier"] == "ENG-42"
assert doc.metadata["issue_title"] == "Fix login bug"
assert doc.metadata["state"] == "Done"
assert doc.metadata["priority"] == "Urgent"
assert doc.metadata["comment_count"] == 1
assert doc.metadata["connector_id"] == _CONNECTOR_ID
assert doc.metadata["document_type"] == "Linear Issue"
assert doc.metadata["connector_type"] == "Linear"
assert doc.fallback_summary is not None
assert "ENG-42" in doc.fallback_summary
assert markdown in doc.fallback_summary
async def test_build_connector_doc_summary_disabled():
"""When enable_summary is False, should_summarize is False."""
doc = _build_connector_doc(
_make_issue(),
_make_formatted_issue(),
"# content",
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=False,
)
assert doc.should_summarize is False
# ---------------------------------------------------------------------------
# Shared fixtures for Slices 2-6
# ---------------------------------------------------------------------------
def _mock_connector(enable_summary: bool = True):
c = MagicMock()
c.config = {"access_token": "tok"}
c.enable_summary = enable_summary
c.last_indexed_at = None
return c
def _mock_linear_client(issues=None, error=None):
client = MagicMock()
client.get_issues_by_date_range = AsyncMock(
return_value=(issues if issues is not None else [], error),
)
client.format_issue = MagicMock(side_effect=lambda i: _make_formatted_issue(
issue_id=i.get("id", ""),
identifier=i.get("identifier", ""),
title=i.get("title", ""),
))
client.format_issue_to_markdown = MagicMock(
side_effect=lambda fi: f"# {fi.get('identifier', '')}: {fi.get('title', '')}\n\nContent"
)
return client
@pytest.fixture
def linear_mocks(monkeypatch):
"""Wire up all external boundary mocks for index_linear_issues."""
mock_session = AsyncMock()
mock_session.no_autoflush = MagicMock()
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
)
linear_client = _mock_linear_client(issues=[_make_issue()])
monkeypatch.setattr(
_mod, "LinearConnector", MagicMock(return_value=linear_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
mock_task_logger = MagicMock()
mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock())
mock_task_logger.log_task_progress = AsyncMock()
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
)
return {
"session": mock_session,
"connector": mock_connector,
"linear_client": linear_client,
"task_logger": mock_task_logger,
"pipeline_mock": pipeline_mock,
"batch_mock": batch_mock,
}
async def _run_index(mocks, **overrides):
return await index_linear_issues(
session=mocks["session"],
connector_id=overrides.get("connector_id", _CONNECTOR_ID),
search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID),
user_id=overrides.get("user_id", _USER_ID),
start_date=overrides.get("start_date", "2025-01-01"),
end_date=overrides.get("end_date", "2025-12-31"),
update_last_indexed=overrides.get("update_last_indexed", True),
on_heartbeat_callback=overrides.get("on_heartbeat_callback"),
)
# ---------------------------------------------------------------------------
# Slice 2: Full pipeline wiring
# ---------------------------------------------------------------------------
async def test_one_issue_calls_pipeline_and_returns_indexed_count(linear_mocks):
"""One valid issue is passed to the pipeline and the indexed count is returned."""
indexed, skipped, warning = await _run_index(linear_mocks)
assert indexed == 1
assert skipped == 0
assert warning is None
linear_mocks["batch_mock"].assert_called_once()
call_args = linear_mocks["batch_mock"].call_args
connector_docs = call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].document_type == DocumentType.LINEAR_CONNECTOR
async def test_pipeline_called_with_max_concurrency_3(linear_mocks):
"""index_batch_parallel is called with max_concurrency=3."""
await _run_index(linear_mocks)
call_kwargs = linear_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("max_concurrency") == 3
async def test_migrate_legacy_docs_called_before_indexing(linear_mocks):
"""migrate_legacy_docs is called on the pipeline before index_batch_parallel."""
await _run_index(linear_mocks)
linear_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once()
# ---------------------------------------------------------------------------
# Slice 3: Issue skipping (missing ID / title)
# ---------------------------------------------------------------------------
async def test_issues_with_missing_id_are_skipped(linear_mocks):
"""Issues without id are skipped and not passed to the pipeline."""
issues = [
_make_issue(issue_id="valid-1", identifier="ENG-1", title="Valid"),
{"id": "", "identifier": "ENG-2", "title": "No ID"},
]
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
indexed, skipped, _ = await _run_index(linear_mocks)
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].unique_id == "valid-1"
assert skipped == 1
async def test_issues_with_missing_title_are_skipped(linear_mocks):
"""Issues without title are skipped."""
issues = [
_make_issue(issue_id="valid-1", identifier="ENG-1", title="Valid"),
{"id": "id-2", "identifier": "ENG-2", "title": ""},
]
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
indexed, skipped, _ = await _run_index(linear_mocks)
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 4: Duplicate content skipping
# ---------------------------------------------------------------------------
async def test_duplicate_content_issues_are_skipped(linear_mocks, monkeypatch):
"""Issues whose content hash matches an existing document are skipped."""
issues = [
_make_issue(issue_id="new-1", identifier="ENG-1", title="New"),
_make_issue(issue_id="dup-1", identifier="ENG-2", title="Dup"),
]
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
call_count = 0
async def _check_dup(session, content_hash):
nonlocal call_count
call_count += 1
if call_count == 2:
dup = MagicMock()
dup.id = 99
dup.document_type = "OTHER"
return dup
return None
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
indexed, skipped, _ = await _run_index(linear_mocks)
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 5: Heartbeat callback forwarding
# ---------------------------------------------------------------------------
async def test_heartbeat_callback_forwarded_to_pipeline(linear_mocks):
"""on_heartbeat_callback is passed through to index_batch_parallel."""
heartbeat_cb = AsyncMock()
await _run_index(linear_mocks, on_heartbeat_callback=heartbeat_cb)
call_kwargs = linear_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("on_heartbeat") is heartbeat_cb
# ---------------------------------------------------------------------------
# Slice 6: Empty issues early return
# ---------------------------------------------------------------------------
async def test_empty_issues_returns_zero_tuple(linear_mocks):
"""When no issues are found, returns (0, 0, None) and pipeline is not called."""
linear_mocks["linear_client"].get_issues_by_date_range.return_value = ([], None)
indexed, skipped, warning = await _run_index(linear_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
linear_mocks["batch_mock"].assert_not_called()
async def test_failed_docs_warning_in_result(linear_mocks):
"""When documents fail indexing, the warning includes the count."""
linear_mocks["batch_mock"].return_value = ([], 0, 2)
_, _, warning = await _run_index(linear_mocks)
assert warning is not None
assert "2 failed" in warning

View file

@ -0,0 +1,345 @@
"""Tests for Notion indexer migrated to the unified parallel pipeline."""
from unittest.mock import AsyncMock, MagicMock
import pytest
import app.tasks.connector_indexers.notion_indexer as _mod
from app.db import DocumentType
from app.tasks.connector_indexers.notion_indexer import (
_build_connector_doc,
index_notion_pages,
)
pytestmark = pytest.mark.unit
_USER_ID = "00000000-0000-0000-0000-000000000001"
_CONNECTOR_ID = 42
_SEARCH_SPACE_ID = 1
def _make_page(page_id: str = "page-1", title: str = "Test Page", content=None):
if content is None:
content = [{"type": "paragraph", "content": "Hello world", "children": []}]
return {"page_id": page_id, "title": title, "content": content}
# ---------------------------------------------------------------------------
# Slice 1: _build_connector_doc tracer bullet
# ---------------------------------------------------------------------------
async def test_build_connector_doc_produces_correct_fields():
"""Tracer bullet: a single Notion page produces a ConnectorDocument with correct fields."""
page = _make_page(page_id="abc-123", title="My Notion Page")
markdown = "# My Notion Page\n\nHello world"
doc = _build_connector_doc(
page,
markdown,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
assert doc.title == "My Notion Page"
assert doc.unique_id == "abc-123"
assert doc.document_type == DocumentType.NOTION_CONNECTOR
assert doc.source_markdown == markdown
assert doc.search_space_id == _SEARCH_SPACE_ID
assert doc.connector_id == _CONNECTOR_ID
assert doc.created_by_id == _USER_ID
assert doc.should_summarize is True
assert doc.metadata["page_title"] == "My Notion Page"
assert doc.metadata["page_id"] == "abc-123"
assert doc.metadata["connector_id"] == _CONNECTOR_ID
assert doc.metadata["document_type"] == "Notion Page"
assert doc.metadata["connector_type"] == "Notion"
assert doc.fallback_summary is not None
assert "My Notion Page" in doc.fallback_summary
assert markdown in doc.fallback_summary
async def test_build_connector_doc_summary_disabled():
"""When enable_summary is False, should_summarize is False."""
doc = _build_connector_doc(
_make_page(),
"# content",
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=False,
)
assert doc.should_summarize is False
# ---------------------------------------------------------------------------
# Shared fixtures for Slices 2-7 (full index_notion_pages tests)
# ---------------------------------------------------------------------------
def _mock_connector(enable_summary: bool = True):
c = MagicMock()
c.config = {"access_token": "tok"}
c.enable_summary = enable_summary
c.last_indexed_at = None
return c
def _mock_notion_client(pages=None, skipped_count=0, legacy_token=False):
client = MagicMock()
client.get_all_pages = AsyncMock(return_value=pages if pages is not None else [])
client.get_skipped_content_count = MagicMock(return_value=skipped_count)
client.is_using_legacy_token = MagicMock(return_value=legacy_token)
client.close = AsyncMock()
client.set_retry_callback = MagicMock()
return client
@pytest.fixture
def notion_mocks(monkeypatch):
"""Wire up all external boundary mocks for index_notion_pages."""
mock_session = AsyncMock()
mock_session.no_autoflush = MagicMock()
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
)
notion_client = _mock_notion_client(pages=[_make_page()])
monkeypatch.setattr(
_mod, "NotionHistoryConnector", MagicMock(return_value=notion_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
monkeypatch.setattr(
_mod, "process_blocks", MagicMock(return_value="Converted markdown content"),
)
mock_task_logger = MagicMock()
mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock())
mock_task_logger.log_task_progress = AsyncMock()
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
)
return {
"session": mock_session,
"connector": mock_connector,
"notion_client": notion_client,
"task_logger": mock_task_logger,
"pipeline_mock": pipeline_mock,
"batch_mock": batch_mock,
}
async def _run_index(mocks, **overrides):
return await index_notion_pages(
session=mocks["session"],
connector_id=overrides.get("connector_id", _CONNECTOR_ID),
search_space_id=overrides.get("search_space_id", _SEARCH_SPACE_ID),
user_id=overrides.get("user_id", _USER_ID),
start_date=overrides.get("start_date", "2025-01-01"),
end_date=overrides.get("end_date", "2025-12-31"),
update_last_indexed=overrides.get("update_last_indexed", True),
on_retry_callback=overrides.get("on_retry_callback"),
on_heartbeat_callback=overrides.get("on_heartbeat_callback"),
)
# ---------------------------------------------------------------------------
# Slice 2: Full pipeline wiring
# ---------------------------------------------------------------------------
async def test_one_page_calls_pipeline_and_returns_indexed_count(notion_mocks):
"""One valid page is passed to the pipeline and the indexed count is returned."""
indexed, skipped, warning = await _run_index(notion_mocks)
assert indexed == 1
assert skipped == 0
assert warning is None
notion_mocks["batch_mock"].assert_called_once()
call_args = notion_mocks["batch_mock"].call_args
connector_docs = call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].document_type == DocumentType.NOTION_CONNECTOR
async def test_pipeline_called_with_max_concurrency_3(notion_mocks):
"""index_batch_parallel is called with max_concurrency=3."""
await _run_index(notion_mocks)
call_kwargs = notion_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("max_concurrency") == 3
async def test_migrate_legacy_docs_called_before_indexing(notion_mocks):
"""migrate_legacy_docs is called on the pipeline before index_batch_parallel."""
await _run_index(notion_mocks)
notion_mocks["pipeline_mock"].migrate_legacy_docs.assert_called_once()
# ---------------------------------------------------------------------------
# Slice 3: Page skipping (no content / missing ID)
# ---------------------------------------------------------------------------
async def test_pages_with_missing_id_are_skipped(notion_mocks, monkeypatch):
"""Pages without page_id are skipped and not passed to the pipeline."""
pages = [
_make_page(page_id="valid-1"),
{"title": "No ID page", "content": [{"type": "paragraph", "content": "text", "children": []}]},
]
notion_mocks["notion_client"].get_all_pages.return_value = pages
_, skipped, _ = await _run_index(notion_mocks)
connector_docs = notion_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert connector_docs[0].unique_id == "valid-1"
assert skipped == 1
async def test_pages_with_no_content_are_skipped(notion_mocks, monkeypatch):
"""Pages with empty content are skipped."""
pages = [
_make_page(page_id="valid-1"),
_make_page(page_id="empty-1", content=[]),
]
notion_mocks["notion_client"].get_all_pages.return_value = pages
_, skipped, _ = await _run_index(notion_mocks)
connector_docs = notion_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 4: Duplicate content skipping
# ---------------------------------------------------------------------------
async def test_duplicate_content_pages_are_skipped(notion_mocks, monkeypatch):
"""Pages whose content hash matches an existing document are skipped."""
pages = [
_make_page(page_id="new-1"),
_make_page(page_id="dup-1"),
]
notion_mocks["notion_client"].get_all_pages.return_value = pages
call_count = 0
async def _check_dup(session, content_hash):
nonlocal call_count
call_count += 1
if call_count == 2:
dup = MagicMock()
dup.id = 99
dup.document_type = "OTHER"
return dup
return None
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
_, skipped, _ = await _run_index(notion_mocks)
connector_docs = notion_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
assert skipped == 1
# ---------------------------------------------------------------------------
# Slice 5: Heartbeat callback forwarding
# ---------------------------------------------------------------------------
async def test_heartbeat_callback_forwarded_to_pipeline(notion_mocks):
"""on_heartbeat_callback is passed through to index_batch_parallel."""
heartbeat_cb = AsyncMock()
await _run_index(notion_mocks, on_heartbeat_callback=heartbeat_cb)
call_kwargs = notion_mocks["batch_mock"].call_args[1]
assert call_kwargs.get("on_heartbeat") is heartbeat_cb
# ---------------------------------------------------------------------------
# Slice 6: Notion-specific warning messages
# ---------------------------------------------------------------------------
async def test_skipped_ai_content_warning_in_result(notion_mocks):
"""When Notion AI content was skipped, the warning message includes it."""
notion_mocks["notion_client"].get_skipped_content_count.return_value = 3
_, _, warning = await _run_index(notion_mocks)
assert warning is not None
assert "API limitation" in warning
async def test_legacy_token_warning_in_result(notion_mocks):
"""When using legacy token, the warning message includes a notice."""
notion_mocks["notion_client"].is_using_legacy_token.return_value = True
_, _, warning = await _run_index(notion_mocks)
assert warning is not None
assert "legacy token" in warning.lower()
async def test_failed_docs_warning_in_result(notion_mocks):
"""When documents fail indexing, the warning includes the count."""
notion_mocks["batch_mock"].return_value = ([], 0, 2)
_, _, warning = await _run_index(notion_mocks)
assert warning is not None
assert "2 failed" in warning
# ---------------------------------------------------------------------------
# Slice 7: Empty pages early return
# ---------------------------------------------------------------------------
async def test_empty_pages_returns_zero_tuple(notion_mocks):
"""When no pages are found, returns (0, 0, None) and updates last_indexed."""
notion_mocks["notion_client"].get_all_pages.return_value = []
indexed, skipped, warning = await _run_index(notion_mocks)
assert indexed == 0
assert skipped == 0
assert warning is None
notion_mocks["batch_mock"].assert_not_called()

View file

@ -3,6 +3,7 @@ import pytest
from app.db import DocumentType from app.db import DocumentType
from app.indexing_pipeline.document_hashing import ( from app.indexing_pipeline.document_hashing import (
compute_content_hash, compute_content_hash,
compute_identifier_hash,
compute_unique_identifier_hash, compute_unique_identifier_hash,
) )
@ -61,3 +62,23 @@ def test_different_content_produces_different_content_hash(make_connector_docume
doc_a = make_connector_document(source_markdown="Original content") doc_a = make_connector_document(source_markdown="Original content")
doc_b = make_connector_document(source_markdown="Updated content") doc_b = make_connector_document(source_markdown="Updated content")
assert compute_content_hash(doc_a) != compute_content_hash(doc_b) assert compute_content_hash(doc_a) != compute_content_hash(doc_b)
def test_compute_identifier_hash_matches_connector_doc_hash(make_connector_document):
"""Raw-args hash equals ConnectorDocument hash for equivalent inputs."""
doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-123",
search_space_id=5,
)
raw_hash = compute_identifier_hash("GOOGLE_GMAIL_CONNECTOR", "msg-123", 5)
assert raw_hash == compute_unique_identifier_hash(doc)
def test_compute_identifier_hash_differs_for_different_inputs():
"""Different arguments produce different hashes."""
h1 = compute_identifier_hash("GOOGLE_DRIVE_FILE", "file-1", 1)
h2 = compute_identifier_hash("GOOGLE_DRIVE_FILE", "file-2", 1)
h3 = compute_identifier_hash("GOOGLE_DRIVE_FILE", "file-1", 2)
h4 = compute_identifier_hash("COMPOSIO_GOOGLE_DRIVE_CONNECTOR", "file-1", 1)
assert len({h1, h2, h3, h4}) == 4

View file

@ -0,0 +1,82 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.db import Document, DocumentType
from app.indexing_pipeline.document_hashing import compute_unique_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.unit
@pytest.fixture
def mock_session():
return AsyncMock()
@pytest.fixture
def pipeline(mock_session):
return IndexingPipelineService(mock_session)
async def test_calls_prepare_then_index_per_document(
pipeline, make_connector_document
):
"""index_batch calls prepare_for_indexing, then index() for each returned doc."""
doc1 = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-1",
search_space_id=1,
)
doc2 = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-2",
search_space_id=1,
)
orm1 = MagicMock(spec=Document)
orm1.unique_identifier_hash = compute_unique_identifier_hash(doc1)
orm2 = MagicMock(spec=Document)
orm2.unique_identifier_hash = compute_unique_identifier_hash(doc2)
mock_llm = MagicMock()
pipeline.prepare_for_indexing = AsyncMock(return_value=[orm1, orm2])
pipeline.index = AsyncMock(side_effect=lambda doc, cdoc, llm: doc)
results = await pipeline.index_batch([doc1, doc2], mock_llm)
pipeline.prepare_for_indexing.assert_awaited_once_with([doc1, doc2])
assert pipeline.index.await_count == 2
assert results == [orm1, orm2]
async def test_empty_input_returns_empty(pipeline):
"""Empty connector_docs list returns empty result."""
pipeline.prepare_for_indexing = AsyncMock(return_value=[])
results = await pipeline.index_batch([], MagicMock())
assert results == []
async def test_skips_document_without_matching_connector_doc(
pipeline, make_connector_document
):
"""If prepare returns a doc whose hash has no matching ConnectorDocument, it's skipped."""
doc1 = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-1",
search_space_id=1,
)
orphan_orm = MagicMock(spec=Document)
orphan_orm.unique_identifier_hash = "nonexistent-hash"
pipeline.prepare_for_indexing = AsyncMock(return_value=[orphan_orm])
pipeline.index = AsyncMock()
results = await pipeline.index_batch([doc1], MagicMock())
pipeline.index.assert_not_awaited()
assert results == []

View file

@ -0,0 +1,186 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.config import config as app_config
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.document_hashing import compute_unique_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
_EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.unit
@pytest.fixture
def mock_session():
session = AsyncMock()
session.refresh = AsyncMock()
return session
@pytest.fixture
def pipeline(mock_session):
return IndexingPipelineService(mock_session)
def _make_orm_doc(connector_doc, doc_id):
"""Create a MagicMock Document bound to a ConnectorDocument's hash."""
doc = MagicMock(spec=Document)
doc.id = doc_id
doc.unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
doc.status = DocumentStatus.pending()
return doc
async def test_index_calls_embed_and_chunk_via_to_thread(
pipeline, make_connector_document, monkeypatch
):
"""index() runs embed_texts and chunk_text via asyncio.to_thread, not blocking the loop."""
to_thread_calls = []
original_to_thread = asyncio.to_thread
async def tracking_to_thread(func, *args, **kwargs):
to_thread_calls.append(func.__name__)
return await original_to_thread(func, *args, **kwargs)
monkeypatch.setattr(asyncio, "to_thread", tracking_to_thread)
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.summarize_document",
AsyncMock(return_value="Summary."),
)
mock_chunk = MagicMock(return_value=["chunk1"])
mock_chunk.__name__ = "chunk_text"
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.chunk_text",
mock_chunk,
)
mock_embed = MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts])
mock_embed.__name__ = "embed_texts"
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.embed_texts",
mock_embed,
)
connector_doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-1",
search_space_id=1,
)
document = MagicMock(spec=Document)
document.id = 1
document.status = DocumentStatus.pending()
await pipeline.index(document, connector_doc, llm=MagicMock())
assert "chunk_text" in to_thread_calls
assert "embed_texts" in to_thread_calls
def _mock_session_factory(orm_docs_by_id):
"""Replace get_celery_session_maker with a two-level callable.
get_celery_session_maker() -> session_maker
session_maker() -> async context manager yielding a mock session
"""
def _get_maker():
def _make_session():
session = MagicMock()
session.get = AsyncMock(
side_effect=lambda model, doc_id: orm_docs_by_id.get(doc_id)
)
ctx = MagicMock()
ctx.__aenter__ = AsyncMock(return_value=session)
ctx.__aexit__ = AsyncMock(return_value=False)
return ctx
return _make_session
return _get_maker
async def test_batch_parallel_indexes_all_documents(
pipeline, make_connector_document, monkeypatch
):
"""index_batch_parallel indexes all documents and returns correct counts."""
docs = [
make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id=f"msg-{i}",
search_space_id=1,
)
for i in range(3)
]
orm_docs = [_make_orm_doc(cd, doc_id=i + 1) for i, cd in enumerate(docs)]
pipeline.prepare_for_indexing = AsyncMock(return_value=orm_docs)
orm_by_id = {d.id: d for d in orm_docs}
monkeypatch.setattr(
"app.tasks.celery_tasks.get_celery_session_maker",
_mock_session_factory(orm_by_id),
)
index_calls = []
async def fake_index(self, document, connector_doc, llm):
index_calls.append(document.id)
document.status = DocumentStatus.ready()
return document
monkeypatch.setattr(IndexingPipelineService, "index", fake_index)
async def mock_get_llm(session):
return MagicMock()
_, indexed, failed = await pipeline.index_batch_parallel(
docs, mock_get_llm, max_concurrency=2
)
assert indexed == 3
assert failed == 0
assert sorted(index_calls) == [1, 2, 3]
async def test_batch_parallel_one_failure_does_not_affect_others(
pipeline, make_connector_document, monkeypatch
):
"""One document failure doesn't prevent other documents from being indexed."""
docs = [
make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id=f"msg-{i}",
search_space_id=1,
)
for i in range(3)
]
orm_docs = [_make_orm_doc(cd, doc_id=i + 1) for i, cd in enumerate(docs)]
pipeline.prepare_for_indexing = AsyncMock(return_value=orm_docs)
orm_by_id = {d.id: d for d in orm_docs}
monkeypatch.setattr(
"app.tasks.celery_tasks.get_celery_session_maker",
_mock_session_factory(orm_by_id),
)
async def failing_index(self, document, connector_doc, llm):
if document.id == 2:
raise RuntimeError("LLM exploded")
document.status = DocumentStatus.ready()
return document
monkeypatch.setattr(IndexingPipelineService, "index", failing_index)
async def mock_get_llm(session):
return MagicMock()
_, indexed, failed = await pipeline.index_batch_parallel(
docs, mock_get_llm, max_concurrency=4
)
assert indexed == 2
assert failed == 1

View file

@ -0,0 +1,127 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.db import Document, DocumentType
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.unit
@pytest.fixture
def mock_session():
session = AsyncMock()
return session
@pytest.fixture
def pipeline(mock_session):
return IndexingPipelineService(mock_session)
def _make_execute_side_effect(doc_by_hash: dict):
"""Return a side_effect for session.execute that resolves documents by hash."""
async def _side_effect(stmt):
result = MagicMock()
for h, doc in doc_by_hash.items():
if h in str(stmt.compile(compile_kwargs={"literal_binds": True})):
result.scalars.return_value.first.return_value = doc
return result
result.scalars.return_value.first.return_value = None
return result
return _side_effect
async def test_updates_hash_and_type_for_legacy_document(
pipeline, mock_session, make_connector_document
):
"""Legacy Composio document gets unique_identifier_hash and document_type updated."""
doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-abc",
search_space_id=1,
)
legacy_hash = compute_identifier_hash("COMPOSIO_GMAIL_CONNECTOR", "msg-abc", 1)
native_hash = compute_identifier_hash("GOOGLE_GMAIL_CONNECTOR", "msg-abc", 1)
existing = MagicMock(spec=Document)
existing.unique_identifier_hash = legacy_hash
existing.document_type = DocumentType.COMPOSIO_GMAIL_CONNECTOR
result_mock = MagicMock()
result_mock.scalars.return_value.first.return_value = existing
mock_session.execute = AsyncMock(return_value=result_mock)
await pipeline.migrate_legacy_docs([doc])
assert existing.unique_identifier_hash == native_hash
assert existing.document_type == DocumentType.GOOGLE_GMAIL_CONNECTOR
mock_session.commit.assert_awaited_once()
async def test_noop_when_no_legacy_document_exists(
pipeline, mock_session, make_connector_document
):
"""No updates when no legacy Composio document is found in DB."""
doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-xyz",
search_space_id=1,
)
result_mock = MagicMock()
result_mock.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=result_mock)
await pipeline.migrate_legacy_docs([doc])
mock_session.commit.assert_awaited_once()
async def test_skips_non_google_doc_types(
pipeline, mock_session, make_connector_document
):
"""Non-Google doc types have no legacy mapping and trigger no DB query."""
doc = make_connector_document(
document_type=DocumentType.SLACK_CONNECTOR,
unique_id="slack-123",
search_space_id=1,
)
await pipeline.migrate_legacy_docs([doc])
mock_session.execute.assert_not_awaited()
mock_session.commit.assert_awaited_once()
async def test_handles_all_three_google_types(
pipeline, mock_session, make_connector_document
):
"""Each native Google type correctly maps to its Composio legacy type."""
mappings = [
(DocumentType.GOOGLE_GMAIL_CONNECTOR, "COMPOSIO_GMAIL_CONNECTOR"),
(DocumentType.GOOGLE_CALENDAR_CONNECTOR, "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"),
(DocumentType.GOOGLE_DRIVE_FILE, "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"),
]
for native_type, expected_legacy in mappings:
doc = make_connector_document(
document_type=native_type,
unique_id="id-1",
search_space_id=1,
)
existing = MagicMock(spec=Document)
existing.document_type = DocumentType(expected_legacy)
result_mock = MagicMock()
result_mock.scalars.return_value.first.return_value = existing
mock_session.execute = AsyncMock(return_value=result_mock)
mock_session.commit = AsyncMock()
await pipeline.migrate_legacy_docs([doc])
assert existing.document_type == native_type

8862
surfsense_backend/uv.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,8 +1,14 @@
import { loader } from "fumadocs-core/source"; import { loader } from "fumadocs-core/source";
import type { Metadata } from "next";
import { changelog } from "@/.source/server"; import { changelog } from "@/.source/server";
import { formatDate } from "@/lib/utils"; import { formatDate } from "@/lib/utils";
import { getMDXComponents } from "@/mdx-components"; import { getMDXComponents } from "@/mdx-components";
export const metadata: Metadata = {
title: "Changelog | SurfSense",
description: "See what's new in SurfSense.",
};
const source = loader({ const source = loader({
baseUrl: "/changelog", baseUrl: "/changelog",
source: changelog.toFumadocsSource(), source: changelog.toFumadocsSource(),

View file

@ -1,6 +1,11 @@
import React from "react"; import type { Metadata } from "next";
import { ContactFormGridWithDetails } from "@/components/contact/contact-form"; import { ContactFormGridWithDetails } from "@/components/contact/contact-form";
export const metadata: Metadata = {
title: "Contact | SurfSense",
description: "Get in touch with the SurfSense team.",
};
const page = () => { const page = () => {
return ( return (
<div> <div>

View file

@ -1,6 +1,11 @@
import React from "react"; import type { Metadata } from "next";
import PricingBasic from "@/components/pricing/pricing-section"; import PricingBasic from "@/components/pricing/pricing-section";
export const metadata: Metadata = {
title: "Pricing | SurfSense",
description: "Explore SurfSense plans and pricing options.",
};
const page = () => { const page = () => {
return ( return (
<div> <div>

View file

@ -473,14 +473,14 @@ export function DocumentsTableShell({
}, [deletableSelectedIds, bulkDeleteDocuments, deleteDocument]); }, [deletableSelectedIds, bulkDeleteDocuments, deleteDocument]);
const bulkDeleteBar = hasDeletableSelection ? ( const bulkDeleteBar = hasDeletableSelection ? (
<div className="flex items-center justify-center py-1.5 border-b border-border/50 bg-destructive/5 shrink-0 animate-in fade-in slide-in-from-top-1 duration-150"> <div className="absolute inset-x-0 top-0 z-10 flex items-center justify-center py-1 pointer-events-none animate-in fade-in duration-150">
<button <button
type="button" type="button"
onClick={() => setBulkDeleteConfirmOpen(true)} onClick={() => setBulkDeleteConfirmOpen(true)}
className="flex items-center gap-1.5 px-3 py-1 rounded-md bg-destructive text-destructive-foreground shadow-sm text-xs font-medium hover:bg-destructive/90 transition-colors" className="pointer-events-auto flex items-center gap-1.5 px-3 py-1 rounded-md bg-destructive text-destructive-foreground shadow-lg text-xs font-medium hover:bg-destructive/90 transition-colors"
> >
<Trash2 size={12} /> <Trash2 size={12} />
Delete ({deletableSelectedIds.length} selected) Delete {deletableSelectedIds.length} {deletableSelectedIds.length === 1 ? "item" : "items"}
</button> </button>
</div> </div>
) : null; ) : null;
@ -526,7 +526,6 @@ export function DocumentsTableShell({
</TableRow> </TableRow>
</TableHeader> </TableHeader>
</Table> </Table>
{bulkDeleteBar}
{loading ? ( {loading ? (
<div className="flex-1 overflow-auto"> <div className="flex-1 overflow-auto">
<Table className="table-fixed w-full"> <Table className="table-fixed w-full">
@ -594,7 +593,8 @@ export function DocumentsTableShell({
)} )}
</div> </div>
) : ( ) : (
<div ref={desktopScrollRef} className="flex-1 overflow-auto"> <div ref={desktopScrollRef} className="flex-1 overflow-auto relative">
{bulkDeleteBar}
<Table className="table-fixed w-full"> <Table className="table-fixed w-full">
<TableBody> <TableBody>
{sorted.map((doc) => { {sorted.map((doc) => {
@ -788,9 +788,6 @@ export function DocumentsTableShell({
)} )}
</div> </div>
{/* Mobile bulk delete bar */}
<div className="md:hidden">{bulkDeleteBar}</div>
{/* Mobile Card View */} {/* Mobile Card View */}
{loading ? ( {loading ? (
<div className="md:hidden divide-y divide-border/50 flex-1 overflow-auto"> <div className="md:hidden divide-y divide-border/50 flex-1 overflow-auto">
@ -846,8 +843,9 @@ export function DocumentsTableShell({
) : ( ) : (
<div <div
ref={mobileScrollRef} ref={mobileScrollRef}
className="md:hidden divide-y divide-border/50 flex-1 overflow-auto" className="md:hidden divide-y divide-border/50 flex-1 overflow-auto relative"
> >
{bulkDeleteBar}
{sorted.map((doc) => { {sorted.map((doc) => {
const isMentioned = mentionedDocIds?.has(doc.id) ?? false; const isMentioned = mentionedDocIds?.has(doc.id) ?? false;
const statusState = doc.status?.state ?? "ready"; const statusState = doc.status?.state ?? "ready";

View file

@ -595,6 +595,7 @@ function CreateInviteDialog({
}); });
} catch (error) { } catch (error) {
console.error("Failed to create invite:", error); console.error("Failed to create invite:", error);
toast.error("Failed to create invite. Please try again.");
} finally { } finally {
setCreating(false); setCreating(false);
} }

View file

@ -29,9 +29,6 @@ export const createDocumentMutationAtom = atomWithMutation((get) => {
queryClient.invalidateQueries({ queryClient.invalidateQueries({
queryKey: cacheKeys.documents.globalQueryParams(documentsQueryParams), queryKey: cacheKeys.documents.globalQueryParams(documentsQueryParams),
}); });
queryClient.invalidateQueries({
queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined),
});
}, },
}; };
}); });
@ -75,9 +72,6 @@ export const updateDocumentMutationAtom = atomWithMutation((get) => {
queryClient.invalidateQueries({ queryClient.invalidateQueries({
queryKey: cacheKeys.documents.document(String(request.id)), queryKey: cacheKeys.documents.document(String(request.id)),
}); });
queryClient.invalidateQueries({
queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined),
});
}, },
}; };
}); });
@ -109,9 +103,6 @@ export const deleteDocumentMutationAtom = atomWithMutation((get) => {
queryClient.invalidateQueries({ queryClient.invalidateQueries({
queryKey: cacheKeys.documents.document(String(request.id)), queryKey: cacheKeys.documents.document(String(request.id)),
}); });
queryClient.invalidateQueries({
queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined),
});
}, },
}; };
}); });

View file

@ -1,38 +0,0 @@
import { atomWithQuery } from "jotai-tanstack-query";
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
import type { SearchDocumentsRequest } from "@/contracts/types/document.types";
import { documentsApiService } from "@/lib/apis/documents-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
import { globalDocumentsQueryParamsAtom } from "./ui.atoms";
export const documentsAtom = atomWithQuery((get) => {
const searchSpaceId = get(activeSearchSpaceIdAtom);
const queryParams = get(globalDocumentsQueryParamsAtom);
return {
queryKey: cacheKeys.documents.globalQueryParams(queryParams),
enabled: !!searchSpaceId,
queryFn: async () => {
return documentsApiService.getDocuments({
queryParams: queryParams,
});
},
};
});
export const documentTypeCountsAtom = atomWithQuery((get) => {
const searchSpaceId = get(activeSearchSpaceIdAtom);
return {
queryKey: cacheKeys.documents.typeCounts(searchSpaceId ?? undefined),
enabled: !!searchSpaceId,
staleTime: 10 * 60 * 1000, // 10 minutes
queryFn: async () => {
return documentsApiService.getDocumentTypeCounts({
queryParams: {
search_space_id: searchSpaceId ?? undefined,
},
});
},
};
});

View file

@ -4,7 +4,7 @@ import { useAtomValue, useSetAtom } from "jotai";
import { AlertTriangle, Cable, Settings } from "lucide-react"; import { AlertTriangle, Cable, Settings } from "lucide-react";
import { forwardRef, useEffect, useImperativeHandle, useMemo, useState } from "react"; import { forwardRef, useEffect, useImperativeHandle, useMemo, useState } from "react";
import { createPortal } from "react-dom"; import { createPortal } from "react-dom";
import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms"; import { useZeroDocumentTypeCounts } from "@/hooks/use-zero-document-type-counts";
import { statusInboxItemsAtom } from "@/atoms/inbox/status-inbox.atom"; import { statusInboxItemsAtom } from "@/atoms/inbox/status-inbox.atom";
import { import {
globalNewLLMConfigsAtom, globalNewLLMConfigsAtom,
@ -72,9 +72,9 @@ export const ConnectorIndicator = forwardRef<ConnectorIndicatorHandle, Connector
const llmConfigLoading = preferencesLoading || globalConfigsLoading; const llmConfigLoading = preferencesLoading || globalConfigsLoading;
// Fetch document type counts via the lightweight /type-counts endpoint (cached 10 min) // Real-time document type counts via Zero (updates instantly as docs are indexed)
const { data: documentTypeCounts, isFetching: documentTypesLoading } = const documentTypeCounts = useZeroDocumentTypeCounts(searchSpaceId);
useAtomValue(documentTypeCountsAtom); const documentTypesLoading = documentTypeCounts === undefined;
// Read status inbox items from shared atom (populated by LayoutDataProvider) // Read status inbox items from shared atom (populated by LayoutDataProvider)
// instead of creating a duplicate useInbox("status") hook. // instead of creating a duplicate useInbox("status") hook.

View file

@ -867,6 +867,9 @@ export const useConnectorDialog = () => {
setIsOpen(false); setIsOpen(false);
setIsFromOAuth(false); setIsFromOAuth(false);
setIndexingConfig(null);
setIndexingConnector(null);
setIndexingConnectorConfig(null);
refreshConnectors(); refreshConnectors();
queryClient.invalidateQueries({ queryClient.invalidateQueries({
@ -898,6 +901,9 @@ export const useConnectorDialog = () => {
const handleSkipIndexing = useCallback(() => { const handleSkipIndexing = useCallback(() => {
setIsOpen(false); setIsOpen(false);
setIsFromOAuth(false); setIsFromOAuth(false);
setIndexingConfig(null);
setIndexingConnector(null);
setIndexingConnectorConfig(null);
}, [setIsOpen]); }, [setIsOpen]);
// Handle starting edit mode // Handle starting edit mode

View file

@ -1,6 +1,5 @@
"use client"; "use client";
import { FolderPlus } from "lucide-react";
import { useCallback, useEffect, useRef, useState } from "react"; import { useCallback, useEffect, useRef, useState } from "react";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { import {
@ -52,22 +51,25 @@ export function CreateFolderDialog({
return ( return (
<Dialog open={open} onOpenChange={onOpenChange}> <Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-sm"> <DialogContent className="select-none max-w-[90vw] sm:max-w-sm p-4 sm:p-5 data-[state=open]:animate-none data-[state=closed]:animate-none">
<DialogHeader> <DialogHeader className="space-y-2 pb-2">
<DialogTitle className="flex items-center gap-2"> <div className="flex items-center gap-2 sm:gap-3">
<FolderPlus className="size-5 text-muted-foreground" /> <div className="flex-1 min-w-0">
<DialogTitle className="text-base sm:text-lg">
{isSubfolder ? "New subfolder" : "New folder"} {isSubfolder ? "New subfolder" : "New folder"}
</DialogTitle> </DialogTitle>
<DialogDescription> <DialogDescription className="text-xs sm:text-sm mt-0.5">
{isSubfolder {isSubfolder
? `Create a new folder inside "${parentFolderName}".` ? `Create a new folder inside "${parentFolderName}".`
: "Create a new folder at the root level."} : "Create a new folder at the root level."}
</DialogDescription> </DialogDescription>
</div>
</div>
</DialogHeader> </DialogHeader>
<form onSubmit={handleSubmit} className="flex flex-col gap-4"> <form onSubmit={handleSubmit} className="flex flex-col gap-3 sm:gap-4">
<div className="flex flex-col gap-2"> <div className="flex flex-col gap-2">
<Label htmlFor="folder-name">Folder name</Label> <Label htmlFor="folder-name" className="text-sm">Folder name</Label>
<Input <Input
ref={inputRef} ref={inputRef}
id="folder-name" id="folder-name"
@ -76,14 +78,24 @@ export function CreateFolderDialog({
onChange={(e) => setName(e.target.value)} onChange={(e) => setName(e.target.value)}
maxLength={255} maxLength={255}
autoComplete="off" autoComplete="off"
className="text-sm h-9 sm:h-10"
/> />
</div> </div>
<DialogFooter> <DialogFooter className="flex-row justify-end gap-2 pt-2 sm:pt-3">
<Button type="button" variant="outline" onClick={() => onOpenChange(false)}> <Button
type="button"
variant="secondary"
onClick={() => onOpenChange(false)}
className="h-8 sm:h-9 text-xs sm:text-sm"
>
Cancel Cancel
</Button> </Button>
<Button type="submit" disabled={!name.trim()}> <Button
type="submit"
disabled={!name.trim()}
className="h-8 sm:h-9 text-xs sm:text-sm"
>
Create Create
</Button> </Button>
</DialogFooter> </DialogFooter>

View file

@ -1,25 +1,32 @@
"use client"; "use client";
import { Eye, MoreHorizontal, Move, Pencil, Trash2 } from "lucide-react"; import { AlertCircle, Clock, Download, Eye, MoreHorizontal, Move, PenLine, Trash2 } from "lucide-react";
import React, { useCallback } from "react"; import React, { useCallback, useRef, useState } from "react";
import { useDrag } from "react-dnd"; import { useDrag } from "react-dnd";
import { getDocumentTypeIcon } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon"; import { getDocumentTypeIcon } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon";
import { ExportContextItems, ExportDropdownItems } from "@/components/shared/ExportMenuItems";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { Checkbox } from "@/components/ui/checkbox"; import { Checkbox } from "@/components/ui/checkbox";
import { import {
ContextMenu, ContextMenu,
ContextMenuContent, ContextMenuContent,
ContextMenuItem, ContextMenuItem,
ContextMenuSeparator, ContextMenuSub,
ContextMenuSubContent,
ContextMenuSubTrigger,
ContextMenuTrigger, ContextMenuTrigger,
} from "@/components/ui/context-menu"; } from "@/components/ui/context-menu";
import { import {
DropdownMenu, DropdownMenu,
DropdownMenuContent, DropdownMenuContent,
DropdownMenuItem, DropdownMenuItem,
DropdownMenuSeparator, DropdownMenuSub,
DropdownMenuSubContent,
DropdownMenuSubTrigger,
DropdownMenuTrigger, DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu"; } from "@/components/ui/dropdown-menu";
import { Spinner } from "@/components/ui/spinner";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import type { DocumentTypeEnum } from "@/contracts/types/document.types"; import type { DocumentTypeEnum } from "@/contracts/types/document.types";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import { DND_TYPES } from "./FolderNode"; import { DND_TYPES } from "./FolderNode";
@ -41,6 +48,9 @@ interface DocumentNodeProps {
onEdit: (doc: DocumentNodeDoc) => void; onEdit: (doc: DocumentNodeDoc) => void;
onDelete: (doc: DocumentNodeDoc) => void; onDelete: (doc: DocumentNodeDoc) => void;
onMove: (doc: DocumentNodeDoc) => void; onMove: (doc: DocumentNodeDoc) => void;
onExport?: (doc: DocumentNodeDoc, format: string) => void;
contextMenuOpen?: boolean;
onContextMenuOpenChange?: (open: boolean) => void;
} }
export const DocumentNode = React.memo(function DocumentNode({ export const DocumentNode = React.memo(function DocumentNode({
@ -52,6 +62,9 @@ export const DocumentNode = React.memo(function DocumentNode({
onEdit, onEdit,
onDelete, onDelete,
onMove, onMove,
onExport,
contextMenuOpen,
onContextMenuOpenChange,
}: DocumentNodeProps) { }: DocumentNodeProps) {
const statusState = doc.status?.state ?? "ready"; const statusState = doc.status?.state ?? "ready";
const isSelectable = statusState !== "pending" && statusState !== "processing"; const isSelectable = statusState !== "pending" && statusState !== "processing";
@ -74,48 +87,90 @@ export const DocumentNode = React.memo(function DocumentNode({
); );
const isProcessing = statusState === "pending" || statusState === "processing"; const isProcessing = statusState === "pending" || statusState === "processing";
const [dropdownOpen, setDropdownOpen] = useState(false);
const [exporting, setExporting] = useState<string | null>(null);
const rowRef = useRef<HTMLButtonElement>(null);
const handleExport = useCallback(
(format: string) => {
if (!onExport) return;
setExporting(format);
onExport(doc, format);
setTimeout(() => setExporting(null), 2000);
},
[doc, onExport]
);
const attachRef = useCallback(
(node: HTMLButtonElement | null) => {
(rowRef as React.MutableRefObject<HTMLButtonElement | null>).current = node;
drag(node);
},
[drag]
);
return ( return (
<ContextMenu> <ContextMenu onOpenChange={onContextMenuOpenChange}>
<ContextMenuTrigger asChild> <ContextMenuTrigger asChild>
{/* biome-ignore lint/a11y/useSemanticElements: div required for drag ref */} <button
<div type="button"
ref={drag} ref={attachRef}
role="button"
tabIndex={0}
className={cn( className={cn(
"group flex h-8 items-center gap-1.5 rounded-md px-1 text-sm hover:bg-accent/50 cursor-pointer select-none", "group flex h-8 w-full items-center gap-2.5 rounded-md px-1 text-sm hover:bg-accent/50 cursor-pointer select-none text-left",
isMentioned && "bg-accent/30", isMentioned && "bg-accent/30",
isDragging && "opacity-40" isDragging && "opacity-40"
)} )}
style={{ paddingLeft: `${depth * 16 + 4}px` }} style={{ paddingLeft: `${depth * 16 + 4}px` }}
onClick={handleCheckChange} onClick={handleCheckChange}
onKeyDown={(e) => {
if (e.key === "Enter" || e.key === " ") {
e.preventDefault();
handleCheckChange();
}
}}
> >
{isSelectable ? ( {(() => {
if (statusState === "pending") {
return (
<Tooltip>
<TooltipTrigger asChild>
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
<Clock className="h-3.5 w-3.5 text-muted-foreground/60" />
</span>
</TooltipTrigger>
<TooltipContent side="top">Pending - waiting to be synced</TooltipContent>
</Tooltip>
);
}
if (statusState === "processing") {
return (
<Tooltip>
<TooltipTrigger asChild>
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
<Spinner size="xs" className="text-primary" />
</span>
</TooltipTrigger>
<TooltipContent side="top">Syncing</TooltipContent>
</Tooltip>
);
}
if (statusState === "failed") {
return (
<Tooltip>
<TooltipTrigger asChild>
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
<AlertCircle className="h-3.5 w-3.5 text-destructive" />
</span>
</TooltipTrigger>
<TooltipContent side="top" className="max-w-xs">
{doc.status?.reason || "Processing failed"}
</TooltipContent>
</Tooltip>
);
}
return (
<Checkbox <Checkbox
checked={isMentioned} checked={isMentioned}
onCheckedChange={handleCheckChange} onCheckedChange={handleCheckChange}
onClick={(e) => e.stopPropagation()} onClick={(e) => e.stopPropagation()}
className="h-3.5 w-3.5 shrink-0" className="h-3.5 w-3.5 shrink-0"
/> />
) : ( );
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center"> })()}
<span
className={cn(
"h-2 w-2 rounded-full",
statusState === "processing" && "animate-pulse bg-amber-500",
statusState === "pending" && "bg-muted-foreground/40",
statusState === "failed" && "bg-destructive"
)}
/>
</span>
)}
<span className="flex-1 min-w-0 truncate">{doc.title}</span> <span className="flex-1 min-w-0 truncate">{doc.title}</span>
@ -126,25 +181,28 @@ export const DocumentNode = React.memo(function DocumentNode({
)} )}
</span> </span>
<DropdownMenu> <DropdownMenu open={dropdownOpen} onOpenChange={setDropdownOpen}>
<DropdownMenuTrigger asChild> <DropdownMenuTrigger asChild>
<Button <Button
variant="ghost" variant="ghost"
size="icon" size="icon"
className="h-6 w-6 shrink-0 opacity-0 group-hover:opacity-100 transition-opacity" className={cn(
"hidden sm:inline-flex h-6 w-6 shrink-0 hover:bg-transparent",
dropdownOpen ? "opacity-100 bg-accent hover:bg-accent" : "opacity-0 group-hover:opacity-100"
)}
onClick={(e) => e.stopPropagation()} onClick={(e) => e.stopPropagation()}
> >
<MoreHorizontal className="h-3.5 w-3.5" /> <MoreHorizontal className="h-3.5 w-3.5" />
</Button> </Button>
</DropdownMenuTrigger> </DropdownMenuTrigger>
<DropdownMenuContent align="end" className="w-44"> <DropdownMenuContent align="end" className="w-40">
<DropdownMenuItem onClick={() => onPreview(doc)}> <DropdownMenuItem onClick={() => onPreview(doc)}>
<Eye className="mr-2 h-4 w-4" /> <Eye className="mr-2 h-4 w-4" />
Open Open
</DropdownMenuItem> </DropdownMenuItem>
{isEditable && ( {isEditable && (
<DropdownMenuItem onClick={() => onEdit(doc)}> <DropdownMenuItem onClick={() => onEdit(doc)}>
<Pencil className="mr-2 h-4 w-4" /> <PenLine className="mr-2 h-4 w-4" />
Edit Edit
</DropdownMenuItem> </DropdownMenuItem>
)} )}
@ -152,7 +210,17 @@ export const DocumentNode = React.memo(function DocumentNode({
<Move className="mr-2 h-4 w-4" /> <Move className="mr-2 h-4 w-4" />
Move to... Move to...
</DropdownMenuItem> </DropdownMenuItem>
<DropdownMenuSeparator /> {onExport && (
<DropdownMenuSub>
<DropdownMenuSubTrigger>
<Download className="mr-2 h-4 w-4" />
Export
</DropdownMenuSubTrigger>
<DropdownMenuSubContent className="min-w-[180px]">
<ExportDropdownItems onExport={handleExport} exporting={exporting} />
</DropdownMenuSubContent>
</DropdownMenuSub>
)}
<DropdownMenuItem <DropdownMenuItem
className="text-destructive focus:text-destructive" className="text-destructive focus:text-destructive"
disabled={isProcessing} disabled={isProcessing}
@ -163,17 +231,18 @@ export const DocumentNode = React.memo(function DocumentNode({
</DropdownMenuItem> </DropdownMenuItem>
</DropdownMenuContent> </DropdownMenuContent>
</DropdownMenu> </DropdownMenu>
</div> </button>
</ContextMenuTrigger> </ContextMenuTrigger>
<ContextMenuContent className="w-44"> {contextMenuOpen && (
<ContextMenuContent className="w-40">
<ContextMenuItem onClick={() => onPreview(doc)}> <ContextMenuItem onClick={() => onPreview(doc)}>
<Eye className="mr-2 h-4 w-4" /> <Eye className="mr-2 h-4 w-4" />
Open Open
</ContextMenuItem> </ContextMenuItem>
{isEditable && ( {isEditable && (
<ContextMenuItem onClick={() => onEdit(doc)}> <ContextMenuItem onClick={() => onEdit(doc)}>
<Pencil className="mr-2 h-4 w-4" /> <PenLine className="mr-2 h-4 w-4" />
Edit Edit
</ContextMenuItem> </ContextMenuItem>
)} )}
@ -181,7 +250,17 @@ export const DocumentNode = React.memo(function DocumentNode({
<Move className="mr-2 h-4 w-4" /> <Move className="mr-2 h-4 w-4" />
Move to... Move to...
</ContextMenuItem> </ContextMenuItem>
<ContextMenuSeparator /> {onExport && (
<ContextMenuSub>
<ContextMenuSubTrigger>
<Download className="mr-2 h-4 w-4" />
Export
</ContextMenuSubTrigger>
<ContextMenuSubContent className="min-w-[180px]">
<ExportContextItems onExport={handleExport} exporting={exporting} />
</ContextMenuSubContent>
</ContextMenuSub>
)}
<ContextMenuItem <ContextMenuItem
className="text-destructive focus:text-destructive" className="text-destructive focus:text-destructive"
disabled={isProcessing} disabled={isProcessing}
@ -191,6 +270,7 @@ export const DocumentNode = React.memo(function DocumentNode({
Delete Delete
</ContextMenuItem> </ContextMenuItem>
</ContextMenuContent> </ContextMenuContent>
)}
</ContextMenu> </ContextMenu>
); );
}); });

View file

@ -8,7 +8,7 @@ import {
FolderPlus, FolderPlus,
MoreHorizontal, MoreHorizontal,
Move, Move,
Pencil, PenLine,
Trash2, Trash2,
} from "lucide-react"; } from "lucide-react";
import React, { useCallback, useEffect, useRef, useState } from "react"; import React, { useCallback, useEffect, useRef, useState } from "react";
@ -18,14 +18,12 @@ import {
ContextMenu, ContextMenu,
ContextMenuContent, ContextMenuContent,
ContextMenuItem, ContextMenuItem,
ContextMenuSeparator,
ContextMenuTrigger, ContextMenuTrigger,
} from "@/components/ui/context-menu"; } from "@/components/ui/context-menu";
import { import {
DropdownMenu, DropdownMenu,
DropdownMenuContent, DropdownMenuContent,
DropdownMenuItem, DropdownMenuItem,
DropdownMenuSeparator,
DropdownMenuTrigger, DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu"; } from "@/components/ui/dropdown-menu";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
@ -66,6 +64,8 @@ interface FolderNodeProps {
onReorderFolder?: (folderId: number, beforePos: string | null, afterPos: string | null) => void; onReorderFolder?: (folderId: number, beforePos: string | null, afterPos: string | null) => void;
siblingPositions?: { before: string | null; after: string | null }; siblingPositions?: { before: string | null; after: string | null };
disabledDropIds?: Set<number>; disabledDropIds?: Set<number>;
contextMenuOpen?: boolean;
onContextMenuOpenChange?: (open: boolean) => void;
} }
function getDropZone( function getDropZone(
@ -99,6 +99,8 @@ export const FolderNode = React.memo(function FolderNode({
onReorderFolder, onReorderFolder,
siblingPositions, siblingPositions,
disabledDropIds, disabledDropIds,
contextMenuOpen,
onContextMenuOpenChange,
}: FolderNodeProps) { }: FolderNodeProps) {
const [renameValue, setRenameValue] = useState(folder.name); const [renameValue, setRenameValue] = useState(folder.name);
const inputRef = useRef<HTMLInputElement>(null); const inputRef = useRef<HTMLInputElement>(null);
@ -213,7 +215,7 @@ export const FolderNode = React.memo(function FolderNode({
const FolderIcon = isExpanded ? FolderOpen : Folder; const FolderIcon = isExpanded ? FolderOpen : Folder;
return ( return (
<ContextMenu> <ContextMenu onOpenChange={onContextMenuOpenChange}>
<ContextMenuTrigger asChild disabled={isRenaming}> <ContextMenuTrigger asChild disabled={isRenaming}>
{/* biome-ignore lint/a11y/useSemanticElements: div required for drag/drop refs */} {/* biome-ignore lint/a11y/useSemanticElements: div required for drag/drop refs */}
<div <div
@ -261,7 +263,8 @@ export const FolderNode = React.memo(function FolderNode({
onBlur={handleRenameSubmit} onBlur={handleRenameSubmit}
onKeyDown={handleRenameKeyDown} onKeyDown={handleRenameKeyDown}
onClick={(e) => e.stopPropagation()} onClick={(e) => e.stopPropagation()}
className="flex-1 min-w-0 rounded border border-primary bg-background px-1 py-0.5 text-sm outline-none" placeholder="Enter folder name"
className="flex-1 min-w-0 bg-transparent px-1 py-0.5 text-sm outline-none caret-primary placeholder:text-muted-foreground/50"
/> />
) : ( ) : (
<span className="flex-1 min-w-0 truncate">{folder.name}</span> <span className="flex-1 min-w-0 truncate">{folder.name}</span>
@ -279,13 +282,13 @@ export const FolderNode = React.memo(function FolderNode({
<Button <Button
variant="ghost" variant="ghost"
size="icon" size="icon"
className="h-6 w-6 shrink-0 opacity-0 group-hover:opacity-100 transition-opacity" className="hidden sm:inline-flex h-6 w-6 shrink-0 opacity-0 group-hover:opacity-100 transition-opacity"
onClick={(e) => e.stopPropagation()} onClick={(e) => e.stopPropagation()}
> >
<MoreHorizontal className="h-3.5 w-3.5" /> <MoreHorizontal className="h-3.5 w-3.5" />
</Button> </Button>
</DropdownMenuTrigger> </DropdownMenuTrigger>
<DropdownMenuContent align="end" className="w-48"> <DropdownMenuContent align="end" className="w-40">
<DropdownMenuItem <DropdownMenuItem
onClick={(e) => { onClick={(e) => {
e.stopPropagation(); e.stopPropagation();
@ -301,7 +304,7 @@ export const FolderNode = React.memo(function FolderNode({
startRename(); startRename();
}} }}
> >
<Pencil className="mr-2 h-4 w-4" /> <PenLine className="mr-2 h-4 w-4" />
Rename Rename
</DropdownMenuItem> </DropdownMenuItem>
<DropdownMenuItem <DropdownMenuItem
@ -313,7 +316,6 @@ export const FolderNode = React.memo(function FolderNode({
<Move className="mr-2 h-4 w-4" /> <Move className="mr-2 h-4 w-4" />
Move to... Move to...
</DropdownMenuItem> </DropdownMenuItem>
<DropdownMenuSeparator />
<DropdownMenuItem <DropdownMenuItem
className="text-destructive focus:text-destructive" className="text-destructive focus:text-destructive"
onClick={(e) => { onClick={(e) => {
@ -330,21 +332,20 @@ export const FolderNode = React.memo(function FolderNode({
</div> </div>
</ContextMenuTrigger> </ContextMenuTrigger>
{!isRenaming && ( {!isRenaming && contextMenuOpen && (
<ContextMenuContent className="w-48"> <ContextMenuContent className="w-40">
<ContextMenuItem onClick={() => onCreateSubfolder(folder.id)}> <ContextMenuItem onClick={() => onCreateSubfolder(folder.id)}>
<FolderPlus className="mr-2 h-4 w-4" /> <FolderPlus className="mr-2 h-4 w-4" />
New subfolder New subfolder
</ContextMenuItem> </ContextMenuItem>
<ContextMenuItem onClick={() => startRename()}> <ContextMenuItem onClick={() => startRename()}>
<Pencil className="mr-2 h-4 w-4" /> <PenLine className="mr-2 h-4 w-4" />
Rename Rename
</ContextMenuItem> </ContextMenuItem>
<ContextMenuItem onClick={() => onMove(folder)}> <ContextMenuItem onClick={() => onMove(folder)}>
<Move className="mr-2 h-4 w-4" /> <Move className="mr-2 h-4 w-4" />
Move to... Move to...
</ContextMenuItem> </ContextMenuItem>
<ContextMenuSeparator />
<ContextMenuItem <ContextMenuItem
className="text-destructive focus:text-destructive" className="text-destructive focus:text-destructive"
onClick={() => onDelete(folder)} onClick={() => onDelete(folder)}

View file

@ -124,10 +124,18 @@ export function FolderPickerDialog({
return ( return (
<Dialog open={open} onOpenChange={onOpenChange}> <Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-sm"> <DialogContent className="select-none max-w-[90vw] sm:max-w-sm p-4 sm:p-5 data-[state=open]:animate-none data-[state=closed]:animate-none">
<DialogHeader> <DialogHeader className="space-y-2 pb-2">
<DialogTitle>{title}</DialogTitle> <div className="flex items-center gap-2 sm:gap-3">
{description && <DialogDescription>{description}</DialogDescription>} <div className="flex-1 min-w-0">
<DialogTitle className="text-base sm:text-lg">{title}</DialogTitle>
{description && (
<DialogDescription className="text-xs sm:text-sm mt-0.5">
{description}
</DialogDescription>
)}
</div>
</div>
</DialogHeader> </DialogHeader>
<div className="max-h-[300px] overflow-y-auto rounded-md border p-1"> <div className="max-h-[300px] overflow-y-auto rounded-md border p-1">
@ -147,11 +155,17 @@ export function FolderPickerDialog({
{renderPickerLevel(null, 1)} {renderPickerLevel(null, 1)}
</div> </div>
<DialogFooter> <DialogFooter className="flex-row justify-end gap-2 pt-2 sm:pt-3">
<Button variant="outline" onClick={() => onOpenChange(false)}> <Button
variant="secondary"
onClick={() => onOpenChange(false)}
className="h-8 sm:h-9 text-xs sm:text-sm"
>
Cancel Cancel
</Button> </Button>
<Button onClick={handleConfirm}>Move here</Button> <Button onClick={handleConfirm} className="h-8 sm:h-9 text-xs sm:text-sm">
Move here
</Button>
</DialogFooter> </DialogFooter>
</DialogContent> </DialogContent>
</Dialog> </Dialog>

View file

@ -1,8 +1,8 @@
"use client"; "use client";
import { useAtom } from "jotai"; import { useAtom } from "jotai";
import { TreePine } from "lucide-react"; import { CirclePlus } from "lucide-react";
import { useCallback, useMemo } from "react"; import { useCallback, useMemo, useState } from "react";
import { DndProvider } from "react-dnd"; import { DndProvider } from "react-dnd";
import { HTML5Backend } from "react-dnd-html5-backend"; import { HTML5Backend } from "react-dnd-html5-backend";
import { renamingFolderIdAtom } from "@/atoms/documents/folder.atoms"; import { renamingFolderIdAtom } from "@/atoms/documents/folder.atoms";
@ -28,6 +28,7 @@ interface FolderTreeViewProps {
onEditDocument: (doc: DocumentNodeDoc) => void; onEditDocument: (doc: DocumentNodeDoc) => void;
onDeleteDocument: (doc: DocumentNodeDoc) => void; onDeleteDocument: (doc: DocumentNodeDoc) => void;
onMoveDocument: (doc: DocumentNodeDoc) => void; onMoveDocument: (doc: DocumentNodeDoc) => void;
onExportDocument?: (doc: DocumentNodeDoc, format: string) => void;
activeTypes: DocumentTypeEnum[]; activeTypes: DocumentTypeEnum[];
onDropIntoFolder?: ( onDropIntoFolder?: (
itemType: "folder" | "document", itemType: "folder" | "document",
@ -62,6 +63,7 @@ export function FolderTreeView({
onEditDocument, onEditDocument,
onDeleteDocument, onDeleteDocument,
onMoveDocument, onMoveDocument,
onExportDocument,
activeTypes, activeTypes,
onDropIntoFolder, onDropIntoFolder,
onReorderFolder, onReorderFolder,
@ -80,6 +82,8 @@ export function FolderTreeView({
return counts; return counts;
}, [folders, foldersByParent, docsByFolder]); }, [folders, foldersByParent, docsByFolder]);
const [openContextMenuId, setOpenContextMenuId] = useState<string | null>(null);
// Single subscription for rename state — derived boolean passed to each FolderNode // Single subscription for rename state — derived boolean passed to each FolderNode
const [renamingFolderId, setRenamingFolderId] = useAtom(renamingFolderIdAtom); const [renamingFolderId, setRenamingFolderId] = useAtom(renamingFolderIdAtom);
const handleStartRename = useCallback( const handleStartRename = useCallback(
@ -157,6 +161,8 @@ export function FolderTreeView({
onDropIntoFolder={onDropIntoFolder} onDropIntoFolder={onDropIntoFolder}
onReorderFolder={onReorderFolder} onReorderFolder={onReorderFolder}
siblingPositions={siblingPositions} siblingPositions={siblingPositions}
contextMenuOpen={openContextMenuId === `folder-${f.id}`}
onContextMenuOpenChange={(open) => setOpenContextMenuId(open ? `folder-${f.id}` : null)}
/> />
); );
@ -177,6 +183,9 @@ export function FolderTreeView({
onEdit={onEditDocument} onEdit={onEditDocument}
onDelete={onDeleteDocument} onDelete={onDeleteDocument}
onMove={onMoveDocument} onMove={onMoveDocument}
onExport={onExportDocument}
contextMenuOpen={openContextMenuId === `doc-${d.id}`}
onContextMenuOpenChange={(open) => setOpenContextMenuId(open ? `doc-${d.id}` : null)}
/> />
); );
} }
@ -189,7 +198,7 @@ export function FolderTreeView({
if (treeNodes.length === 0 && folders.length === 0 && documents.length === 0) { if (treeNodes.length === 0 && folders.length === 0 && documents.length === 0) {
return ( return (
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-4 py-12 text-muted-foreground"> <div className="flex flex-1 flex-col items-center justify-center gap-3 px-4 py-12 text-muted-foreground">
<TreePine className="h-10 w-10" /> <CirclePlus className="h-10 w-10 rotate-45" />
<p className="text-sm">No documents yet</p> <p className="text-sm">No documents yet</p>
</div> </div>
); );
@ -198,7 +207,7 @@ export function FolderTreeView({
if (treeNodes.length === 0 && activeTypes.length > 0) { if (treeNodes.length === 0 && activeTypes.length > 0) {
return ( return (
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-4 py-12 text-muted-foreground"> <div className="flex flex-1 flex-col items-center justify-center gap-3 px-4 py-12 text-muted-foreground">
<TreePine className="h-10 w-10" /> <CirclePlus className="h-10 w-10 rotate-45" />
<p className="text-sm">No matching documents</p> <p className="text-sm">No matching documents</p>
</div> </div>
); );

View file

@ -185,7 +185,7 @@ function DateTimePickerField({
type="time" type="time"
value={time} value={time}
onChange={handleTimeChange} onChange={handleTimeChange}
className="w-[120px] text-sm shrink-0 pl-1.5 [&::-webkit-calendar-picker-indicator]:order-first [&::-webkit-calendar-picker-indicator]:mr-1" className="w-[120px] text-sm shrink-0 appearance-none [&::-webkit-calendar-picker-indicator]:hidden [&::-webkit-calendar-picker-indicator]:appearance-none"
/> />
</div> </div>
); );

View file

@ -396,10 +396,13 @@ export function AllPrivateChatsSidebarContent({
variant="ghost" variant="ghost"
size="icon" size="icon"
className={cn( className={cn(
"h-6 w-6 shrink-0", "h-6 w-6 shrink-0 hover:bg-transparent",
isMobile isMobile
? "opacity-0 pointer-events-none absolute" ? "opacity-0 pointer-events-none absolute"
: openDropdownId === thread.id
? "opacity-100"
: "md:opacity-0 md:group-hover:opacity-100 md:focus:opacity-100", : "md:opacity-0 md:group-hover:opacity-100 md:focus:opacity-100",
openDropdownId === thread.id && "bg-accent hover:bg-accent",
"transition-opacity" "transition-opacity"
)} )}
disabled={isBusy} disabled={isBusy}

View file

@ -396,10 +396,13 @@ export function AllSharedChatsSidebarContent({
variant="ghost" variant="ghost"
size="icon" size="icon"
className={cn( className={cn(
"h-6 w-6 shrink-0", "h-6 w-6 shrink-0 hover:bg-transparent",
isMobile isMobile
? "opacity-0 pointer-events-none absolute" ? "opacity-0 pointer-events-none absolute"
: openDropdownId === thread.id
? "opacity-100"
: "md:opacity-0 md:group-hover:opacity-100 md:focus:opacity-100", : "md:opacity-0 md:group-hover:opacity-100 md:focus:opacity-100",
openDropdownId === thread.id && "bg-accent hover:bg-accent",
"transition-opacity" "transition-opacity"
)} )}
disabled={isBusy} disabled={isBusy}

View file

@ -79,14 +79,21 @@ export function ChatListItem({
: "bg-gradient-to-l from-sidebar from-60% to-transparent group-hover/item:from-accent", : "bg-gradient-to-l from-sidebar from-60% to-transparent group-hover/item:from-accent",
isMobile isMobile
? "opacity-0" ? "opacity-0"
: isActive : isActive || dropdownOpen
? "opacity-100" ? "opacity-100"
: "opacity-0 group-hover/item:opacity-100" : "opacity-0 group-hover/item:opacity-100"
)} )}
> >
<DropdownMenu open={dropdownOpen} onOpenChange={setDropdownOpen}> <DropdownMenu open={dropdownOpen} onOpenChange={setDropdownOpen}>
<DropdownMenuTrigger asChild> <DropdownMenuTrigger asChild>
<Button variant="ghost" size="icon" className="pointer-events-auto h-6 w-6"> <Button
variant="ghost"
size="icon"
className={cn(
"pointer-events-auto h-6 w-6 hover:bg-transparent",
dropdownOpen && "bg-accent hover:bg-accent"
)}
>
<MoreHorizontal className="h-3.5 w-3.5 text-muted-foreground" /> <MoreHorizontal className="h-3.5 w-3.5 text-muted-foreground" />
<span className="sr-only">{t("more_options")}</span> <span className="sr-only">{t("more_options")}</span>
</Button> </Button>

View file

@ -7,6 +7,7 @@ import { useParams } from "next/navigation";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner"; import { toast } from "sonner";
import { EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems";
import { DocumentsFilters } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsFilters"; import { DocumentsFilters } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsFilters";
import { import {
DocumentsTableShell, DocumentsTableShell,
@ -33,6 +34,7 @@ import { useDocumentSearch } from "@/hooks/use-document-search";
import { useDocuments } from "@/hooks/use-documents"; import { useDocuments } from "@/hooks/use-documents";
import { useMediaQuery } from "@/hooks/use-media-query"; import { useMediaQuery } from "@/hooks/use-media-query";
import { foldersApiService } from "@/lib/apis/folders-api.service"; import { foldersApiService } from "@/lib/apis/folders-api.service";
import { authenticatedFetch } from "@/lib/auth-utils";
import { queries } from "@/zero/queries/index"; import { queries } from "@/zero/queries/index";
import { SidebarSlideOutPanel } from "./SidebarSlideOutPanel"; import { SidebarSlideOutPanel } from "./SidebarSlideOutPanel";
@ -234,6 +236,43 @@ export function DocumentsSidebar({
setFolderPickerOpen(true); setFolderPickerOpen(true);
}, []); }, []);
const handleExportDocument = useCallback(
async (doc: DocumentNodeDoc, format: string) => {
const safeTitle =
doc.title
.replace(/[^a-zA-Z0-9 _-]/g, "_")
.trim()
.slice(0, 80) || "document";
const ext = EXPORT_FILE_EXTENSIONS[format] ?? format;
try {
const response = await authenticatedFetch(
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${doc.id}/export?format=${format}`,
{ method: "GET" }
);
if (!response.ok) {
const errorData = await response.json().catch(() => ({ detail: "Export failed" }));
throw new Error(errorData.detail || "Export failed");
}
const blob = await response.blob();
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `${safeTitle}.${ext}`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
} catch (err) {
console.error(`Export ${format} failed:`, err);
toast.error(err instanceof Error ? err.message : `Export failed`);
}
},
[searchSpaceId]
);
const handleFolderPickerSelect = useCallback( const handleFolderPickerSelect = useCallback(
async (targetFolderId: number | null) => { async (targetFolderId: number | null) => {
if (!folderPickerTarget) return; if (!folderPickerTarget) return;
@ -606,6 +645,7 @@ export function DocumentsSidebar({
}} }}
onDeleteDocument={(doc) => handleDeleteDocument(doc.id)} onDeleteDocument={(doc) => handleDeleteDocument(doc.id)}
onMoveDocument={handleMoveDocument} onMoveDocument={handleMoveDocument}
onExportDocument={handleExportDocument}
activeTypes={activeTypes} activeTypes={activeTypes}
onDropIntoFolder={handleDropIntoFolder} onDropIntoFolder={handleDropIntoFolder}
onReorderFolder={handleReorderFolder} onReorderFolder={handleReorderFolder}
@ -617,7 +657,7 @@ export function DocumentsSidebar({
open={folderPickerOpen} open={folderPickerOpen}
onOpenChange={setFolderPickerOpen} onOpenChange={setFolderPickerOpen}
folders={treeFolders} folders={treeFolders}
title={folderPickerTarget?.type === "folder" ? "Move folder to..." : "Move document to..."} title={folderPickerTarget?.type === "folder" ? "Move folder to" : "Move document to"}
description="Select a destination folder, or choose Root to move to the top level." description="Select a destination folder, or choose Root to move to the top level."
disabledFolderIds={folderPickerTarget?.disabledIds} disabledFolderIds={folderPickerTarget?.disabledIds}
onSelect={handleFolderPickerSelect} onSelect={handleFolderPickerSelect}

View file

@ -199,7 +199,7 @@ export function ChatShareButton({ thread, onVisibilityChange, className }: ChatS
className={cn( className={cn(
"w-full flex items-center gap-2.5 px-2.5 py-2 rounded-md transition-all", "w-full flex items-center gap-2.5 px-2.5 py-2 rounded-md transition-all",
"hover:bg-accent/50 dark:hover:bg-white/10 cursor-pointer", "hover:bg-accent/50 dark:hover:bg-white/10 cursor-pointer",
"focus:outline-none", "focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2",
isSelected && "bg-accent/80 dark:bg-white/10" isSelected && "bg-accent/80 dark:bg-white/10"
)} )}
> >
@ -248,7 +248,7 @@ export function ChatShareButton({ thread, onVisibilityChange, className }: ChatS
className={cn( className={cn(
"w-full flex items-center gap-2.5 px-2.5 py-2 rounded-md transition-all", "w-full flex items-center gap-2.5 px-2.5 py-2 rounded-md transition-all",
"hover:bg-accent/50 dark:hover:bg-white/10 cursor-pointer", "hover:bg-accent/50 dark:hover:bg-white/10 cursor-pointer",
"focus:outline-none", "focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2",
"disabled:opacity-50 disabled:cursor-not-allowed" "disabled:opacity-50 disabled:cursor-not-allowed"
)} )}
> >

View file

@ -7,7 +7,7 @@ import { useTheme } from "next-themes";
import { useCallback, useEffect, useRef, useState } from "react"; import { useCallback, useEffect, useRef, useState } from "react";
import { createPortal } from "react-dom"; import { createPortal } from "react-dom";
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms"; import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms"; import { useZeroDocumentTypeCounts } from "@/hooks/use-zero-document-type-counts";
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
import { currentUserAtom } from "@/atoms/user/user-query.atoms"; import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import { useIsMobile } from "@/hooks/use-mobile"; import { useIsMobile } from "@/hooks/use-mobile";
@ -452,8 +452,8 @@ export function OnboardingTour() {
enabled: !!searchSpaceId, enabled: !!searchSpaceId,
}); });
// Get document type counts // Real-time document type counts via Zero
const { data: documentTypeCounts } = useAtomValue(documentTypeCountsAtom); const documentTypeCounts = useZeroDocumentTypeCounts(searchSpaceId);
// Get connectors // Get connectors
const { data: connectors = [] } = useAtomValue(connectorsAtom); const { data: connectors = [] } = useAtomValue(connectorsAtom);

View file

@ -15,10 +15,9 @@ import {
DropdownMenu, DropdownMenu,
DropdownMenuContent, DropdownMenuContent,
DropdownMenuItem, DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuSeparator,
DropdownMenuTrigger, DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu"; } from "@/components/ui/dropdown-menu";
import { ExportDropdownItems, EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems";
import { useMediaQuery } from "@/hooks/use-media-query"; import { useMediaQuery } from "@/hooks/use-media-query";
import { baseApiService } from "@/lib/apis/base-api.service"; import { baseApiService } from "@/lib/apis/base-api.service";
import { authenticatedFetch } from "@/lib/auth-utils"; import { authenticatedFetch } from "@/lib/auth-utils";
@ -198,19 +197,6 @@ export function ReportPanelContent({
} }
}, [currentMarkdown]); }, [currentMarkdown]);
// Maps backend format values to download file extensions
const FILE_EXTENSIONS: Record<string, string> = {
pdf: "pdf",
docx: "docx",
html: "html",
latex: "tex",
epub: "epub",
odt: "odt",
plain: "txt",
md: "md",
};
// Export report
const handleExport = useCallback( const handleExport = useCallback(
async (format: string) => { async (format: string) => {
setExporting(format); setExporting(format);
@ -219,7 +205,7 @@ export function ReportPanelContent({
.replace(/[^a-zA-Z0-9 _-]/g, "_") .replace(/[^a-zA-Z0-9 _-]/g, "_")
.trim() .trim()
.slice(0, 80) || "report"; .slice(0, 80) || "report";
const ext = FILE_EXTENSIONS[format] ?? format; const ext = EXPORT_FILE_EXTENSIONS[format] ?? format;
try { try {
if (format === "md") { if (format === "md") {
if (!currentMarkdown) return; if (!currentMarkdown) return;
@ -329,68 +315,11 @@ export function ReportPanelContent({
align="start" align="start"
className={`min-w-[200px] select-none${insideDrawer ? " z-[100]" : ""}`} className={`min-w-[200px] select-none${insideDrawer ? " z-[100]" : ""}`}
> >
{!shareToken && ( <ExportDropdownItems
<> onExport={handleExport}
<DropdownMenuLabel className="text-xs text-muted-foreground"> exporting={exporting}
Documents showAllFormats={!shareToken}
</DropdownMenuLabel> />
<DropdownMenuItem
onClick={() => handleExport("pdf")}
disabled={exporting !== null}
>
PDF (.pdf)
</DropdownMenuItem>
<DropdownMenuItem
onClick={() => handleExport("docx")}
disabled={exporting !== null}
>
Word (.docx)
</DropdownMenuItem>
<DropdownMenuItem
onClick={() => handleExport("odt")}
disabled={exporting !== null}
>
OpenDocument (.odt)
</DropdownMenuItem>
<DropdownMenuSeparator />
<DropdownMenuLabel className="text-xs text-muted-foreground">
Web &amp; E-Book
</DropdownMenuLabel>
<DropdownMenuItem
onClick={() => handleExport("html")}
disabled={exporting !== null}
>
HTML (.html)
</DropdownMenuItem>
<DropdownMenuItem
onClick={() => handleExport("epub")}
disabled={exporting !== null}
>
EPUB (.epub)
</DropdownMenuItem>
<DropdownMenuSeparator />
<DropdownMenuLabel className="text-xs text-muted-foreground">
Source &amp; Plain
</DropdownMenuLabel>
<DropdownMenuItem
onClick={() => handleExport("latex")}
disabled={exporting !== null}
>
LaTeX (.tex)
</DropdownMenuItem>
</>
)}
<DropdownMenuItem onClick={() => handleExport("md")} disabled={exporting !== null}>
Markdown (.md)
</DropdownMenuItem>
{!shareToken && (
<DropdownMenuItem
onClick={() => handleExport("plain")}
disabled={exporting !== null}
>
Plain Text (.txt)
</DropdownMenuItem>
)}
</DropdownMenuContent> </DropdownMenuContent>
</DropdownMenu> </DropdownMenu>

View file

@ -27,6 +27,7 @@ export function GeneralSettingsManager({ searchSpaceId }: GeneralSettingsManager
const { const {
data: searchSpace, data: searchSpace,
isLoading: loading, isLoading: loading,
isError,
refetch: fetchSearchSpace, refetch: fetchSearchSpace,
} = useQuery({ } = useQuery({
queryKey: cacheKeys.searchSpaces.detail(searchSpaceId.toString()), queryKey: cacheKeys.searchSpaces.detail(searchSpaceId.toString()),
@ -104,6 +105,17 @@ export function GeneralSettingsManager({ searchSpaceId }: GeneralSettingsManager
); );
} }
if (isError) {
return (
<div className="flex flex-col items-center justify-center gap-3 py-8 text-center">
<p className="text-sm text-destructive">Failed to load settings.</p>
<Button variant="outline" size="sm" onClick={() => fetchSearchSpace()}>
Retry
</Button>
</div>
);
}
return ( return (
<div className="space-y-4 md:space-y-6"> <div className="space-y-4 md:space-y-6">
<Alert className="bg-muted/50 py-3 md:py-4"> <Alert className="bg-muted/50 py-3 md:py-4">

View file

@ -0,0 +1,142 @@
"use client";
import { Loader2 } from "lucide-react";
import { DropdownMenuItem, DropdownMenuLabel, DropdownMenuSeparator } from "@/components/ui/dropdown-menu";
import { ContextMenuItem } from "@/components/ui/context-menu";
export const EXPORT_FILE_EXTENSIONS: Record<string, string> = {
pdf: "pdf",
docx: "docx",
html: "html",
latex: "tex",
epub: "epub",
odt: "odt",
plain: "txt",
md: "md",
};
interface ExportMenuItemsProps {
onExport: (format: string) => void;
exporting: string | null;
/** Hide server-side formats (PDF, DOCX, etc.) — only show md */
showAllFormats?: boolean;
}
export function ExportDropdownItems({
onExport,
exporting,
showAllFormats = true,
}: ExportMenuItemsProps) {
const handle = (format: string) => (e: React.MouseEvent) => {
e.stopPropagation();
onExport(format);
};
return (
<>
{showAllFormats && (
<>
<DropdownMenuLabel className="text-xs text-muted-foreground">
Documents
</DropdownMenuLabel>
<DropdownMenuItem onClick={handle("pdf")} disabled={exporting !== null}>
{exporting === "pdf" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
PDF (.pdf)
</DropdownMenuItem>
<DropdownMenuItem onClick={handle("docx")} disabled={exporting !== null}>
{exporting === "docx" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
Word (.docx)
</DropdownMenuItem>
<DropdownMenuItem onClick={handle("odt")} disabled={exporting !== null}>
{exporting === "odt" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
OpenDocument (.odt)
</DropdownMenuItem>
<DropdownMenuSeparator />
<DropdownMenuLabel className="text-xs text-muted-foreground">
Web &amp; E-Book
</DropdownMenuLabel>
<DropdownMenuItem onClick={handle("html")} disabled={exporting !== null}>
{exporting === "html" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
HTML (.html)
</DropdownMenuItem>
<DropdownMenuItem onClick={handle("epub")} disabled={exporting !== null}>
{exporting === "epub" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
EPUB (.epub)
</DropdownMenuItem>
<DropdownMenuSeparator />
<DropdownMenuLabel className="text-xs text-muted-foreground">
Source &amp; Plain
</DropdownMenuLabel>
<DropdownMenuItem onClick={handle("latex")} disabled={exporting !== null}>
{exporting === "latex" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
LaTeX (.tex)
</DropdownMenuItem>
</>
)}
<DropdownMenuItem onClick={handle("md")} disabled={exporting !== null}>
{exporting === "md" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
Markdown (.md)
</DropdownMenuItem>
{showAllFormats && (
<DropdownMenuItem onClick={handle("plain")} disabled={exporting !== null}>
{exporting === "plain" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
Plain Text (.txt)
</DropdownMenuItem>
)}
</>
);
}
export function ExportContextItems({
onExport,
exporting,
showAllFormats = true,
}: ExportMenuItemsProps) {
const handle = (format: string) => (e: React.MouseEvent) => {
e.stopPropagation();
onExport(format);
};
return (
<>
{showAllFormats && (
<>
<ContextMenuItem onClick={handle("pdf")} disabled={exporting !== null}>
{exporting === "pdf" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
PDF (.pdf)
</ContextMenuItem>
<ContextMenuItem onClick={handle("docx")} disabled={exporting !== null}>
{exporting === "docx" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
Word (.docx)
</ContextMenuItem>
<ContextMenuItem onClick={handle("odt")} disabled={exporting !== null}>
{exporting === "odt" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
OpenDocument (.odt)
</ContextMenuItem>
<ContextMenuItem onClick={handle("html")} disabled={exporting !== null}>
{exporting === "html" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
HTML (.html)
</ContextMenuItem>
<ContextMenuItem onClick={handle("epub")} disabled={exporting !== null}>
{exporting === "epub" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
EPUB (.epub)
</ContextMenuItem>
<ContextMenuItem onClick={handle("latex")} disabled={exporting !== null}>
{exporting === "latex" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
LaTeX (.tex)
</ContextMenuItem>
</>
)}
<ContextMenuItem onClick={handle("md")} disabled={exporting !== null}>
{exporting === "md" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
Markdown (.md)
</ContextMenuItem>
{showAllFormats && (
<ContextMenuItem onClick={handle("plain")} disabled={exporting !== null}>
{exporting === "plain" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
Plain Text (.txt)
</ContextMenuItem>
)}
</>
);
}

View file

@ -253,6 +253,12 @@ function ApprovalCard({
String(effectiveNewDescription ?? "") !== (event?.description ?? ""); String(effectiveNewDescription ?? "") !== (event?.description ?? "");
const buildFinalArgs = useCallback(() => { const buildFinalArgs = useCallback(() => {
const base = {
event_id: event?.event_id,
document_id: event?.document_id,
connector_id: account?.id,
};
if (pendingEdits) { if (pendingEdits) {
const attendeesArr = pendingEdits.attendees const attendeesArr = pendingEdits.attendees
? pendingEdits.attendees ? pendingEdits.attendees
@ -260,22 +266,38 @@ function ApprovalCard({
.map((e) => e.trim()) .map((e) => e.trim())
.filter(Boolean) .filter(Boolean)
: null; : null;
const origAttendees = event?.attendees?.map((a) => a.email) ?? [];
return { return {
event_id: event?.event_id, ...base,
document_id: event?.document_id, new_summary:
connector_id: account?.id, pendingEdits.summary && pendingEdits.summary !== (event?.summary ?? "")
new_summary: pendingEdits.summary || null, ? pendingEdits.summary
new_description: pendingEdits.description || null, : null,
new_start_datetime: pendingEdits.start_datetime || null, new_description:
new_end_datetime: pendingEdits.end_datetime || null, pendingEdits.description !== (event?.description ?? "")
new_location: pendingEdits.location || null, ? pendingEdits.description || null
new_attendees: attendeesArr, : null,
new_start_datetime:
pendingEdits.start_datetime && pendingEdits.start_datetime !== (event?.start ?? "")
? pendingEdits.start_datetime
: null,
new_end_datetime:
pendingEdits.end_datetime && pendingEdits.end_datetime !== (event?.end ?? "")
? pendingEdits.end_datetime
: null,
new_location:
pendingEdits.location !== (event?.location ?? "")
? pendingEdits.location || null
: null,
new_attendees:
attendeesArr && attendeesArr.join(",") !== origAttendees.join(",")
? attendeesArr
: null,
}; };
} }
return { return {
event_id: event?.event_id, ...base,
document_id: event?.document_id,
connector_id: account?.id,
new_summary: actionArgs.new_summary ?? null, new_summary: actionArgs.new_summary ?? null,
new_description: actionArgs.new_description ?? null, new_description: actionArgs.new_description ?? null,
new_start_datetime: actionArgs.new_start_datetime ?? null, new_start_datetime: actionArgs.new_start_datetime ?? null,

View file

@ -1,7 +1,7 @@
"use client"; "use client";
import { AnimatePresence, motion } from "motion/react"; import { AnimatePresence, motion } from "motion/react";
import { useCallback, useEffect, useState } from "react"; import { useCallback, useEffect, useRef, useState } from "react";
import { createPortal } from "react-dom"; import { createPortal } from "react-dom";
function isVideoSrc(src: string) { function isVideoSrc(src: string) {
@ -17,6 +17,12 @@ function ExpandedMediaOverlay({
alt: string; alt: string;
onClose: () => void; onClose: () => void;
}) { }) {
const overlayRef = useRef<HTMLDivElement>(null);
useEffect(() => {
overlayRef.current?.focus();
}, []);
useEffect(() => { useEffect(() => {
const handleKey = (e: KeyboardEvent) => { const handleKey = (e: KeyboardEvent) => {
if (e.key === "Escape") onClose(); if (e.key === "Escape") onClose();
@ -52,12 +58,20 @@ function ExpandedMediaOverlay({
return createPortal( return createPortal(
<motion.div <motion.div
role="dialog"
aria-modal="true"
aria-label="Expanded media view"
tabIndex={-1}
ref={overlayRef}
initial={{ opacity: 0 }} initial={{ opacity: 0 }}
animate={{ opacity: 1 }} animate={{ opacity: 1 }}
exit={{ opacity: 0 }} exit={{ opacity: 0 }}
transition={{ duration: 0.2 }} transition={{ duration: 0.2 }}
className="fixed inset-0 z-100 flex items-center justify-center bg-black/70 p-4 backdrop-blur-sm sm:p-8" className="fixed inset-0 z-100 flex items-center justify-center bg-black/70 p-4 backdrop-blur-sm sm:p-8"
onClick={onClose} onClick={onClose}
onKeyDown={(e) => {
if (e.key === "Escape") onClose();
}}
> >
{mediaElement} {mediaElement}
</motion.div>, </motion.div>,

View file

@ -5,10 +5,10 @@ import {
FileText, FileText,
Film, Film,
Globe, Globe,
ImageIcon,
type LucideIcon, type LucideIcon,
Podcast, Podcast,
ScanLine, ScanLine,
Sparkles,
Wrench, Wrench,
} from "lucide-react"; } from "lucide-react";
@ -17,7 +17,7 @@ const TOOL_ICONS: Record<string, LucideIcon> = {
generate_podcast: Podcast, generate_podcast: Podcast,
generate_video_presentation: Film, generate_video_presentation: Film,
generate_report: FileText, generate_report: FileText,
generate_image: Sparkles, generate_image: ImageIcon,
scrape_webpage: ScanLine, scrape_webpage: ScanLine,
web_search: Globe, web_search: Globe,
search_surfsense_docs: BookOpen, search_surfsense_docs: BookOpen,

View file

@ -0,0 +1,31 @@
"use client";
import { useQuery } from "@rocicorp/zero/react";
import { useMemo } from "react";
import { queries } from "@/zero/queries";
/**
* Real-time document type counts derived from Zero's live document sync.
* Updates instantly as documents are created, deleted, or change type.
*/
export function useZeroDocumentTypeCounts(
searchSpaceId: number | string | null
): Record<string, number> | undefined {
const numericId = searchSpaceId != null ? Number(searchSpaceId) : null;
const [zeroDocuments] = useQuery(
queries.documents.bySpace({ searchSpaceId: numericId ?? -1 })
);
return useMemo(() => {
if (!zeroDocuments || numericId == null) return undefined;
const counts: Record<string, number> = {};
for (const doc of zeroDocuments) {
if (doc.id != null && doc.title != null && doc.title !== "") {
counts[doc.documentType] = (counts[doc.documentType] || 0) + 1;
}
}
return counts;
}, [zeroDocuments, numericId]);
}

View file

@ -17,7 +17,6 @@ export const cacheKeys = {
withQueryParams: (queries: GetDocumentsRequest["queryParams"]) => withQueryParams: (queries: GetDocumentsRequest["queryParams"]) =>
["documents-with-queries", ...(queries ? Object.values(queries) : [])] as const, ["documents-with-queries", ...(queries ? Object.values(queries) : [])] as const,
document: (documentId: string) => ["document", documentId] as const, document: (documentId: string) => ["document", documentId] as const,
typeCounts: (searchSpaceId?: string) => ["documents", "type-counts", searchSpaceId] as const,
byChunk: (chunkId: string) => ["documents", "by-chunk", chunkId] as const, byChunk: (chunkId: string) => ["documents", "by-chunk", chunkId] as const,
}, },
logs: { logs: {