Compare commits

...

53 commits

Author SHA1 Message Date
DESKTOP-RTLN3BA$punk
80f775581b feat: implement quick assist mode detection in AssistantActionBar
- Added state management for quick assist mode using the Electron API.
- Introduced a useEffect hook to asynchronously check and set the quick assist mode based on the API response, enhancing the component's interactivity.
2026-04-07 05:11:41 -07:00
DESKTOP-RTLN3BA$punk
518cacf56e refactor: improve AssistantActionBar functionality and UI elements
- Removed unused quick ask mode state and effect, simplifying the component logic.
- Updated tooltip descriptions for action buttons to provide clearer user guidance.
- Enhanced the conditional rendering for the quick assist feature, improving user interaction with the clipboard functionality.
2026-04-07 05:03:20 -07:00
Rohan Verma
be98b395b2
Merge pull request #1158 from MODSetter/dev_mod
feat: add active search space management to Electron API and UI
2026-04-07 04:48:35 -07:00
DESKTOP-RTLN3BA$punk
7c6e52a0a5 feat: add active search space management to Electron API and UI
- Introduced IPC channels for getting and setting the active search space, enhancing user experience across the application.
- Updated the preload script to expose new API methods for active search space management.
- Modified the main window and quick ask functionalities to sync the active search space based on user navigation.
- Enhanced the desktop and web applications to allow users to select and manage their default search space seamlessly.
- Implemented automatic synchronization of the active search space during login and navigation events.
2026-04-07 04:45:48 -07:00
DESKTOP-RTLN3BA$punk
b74ac8a608 feat: update shortcut icons and descriptions for improved clarity
- Replaced icons for "General Assist," "Quick Assist," and "Extreme Assist" shortcuts to better represent their functionalities.
- Updated descriptions for each shortcut to enhance user understanding of their actions.
- Refactored the layout of the shortcut recorder for a more streamlined user experience.
2026-04-07 04:22:22 -07:00
Rohan Verma
a4a4deeda0
Merge pull request #1157 from MODSetter/dev_mod
feat: added tray table and general assist mode
2026-04-07 03:44:18 -07:00
DESKTOP-RTLN3BA$punk
27e9e8d873 feat: add general assist feature and enhance shortcut management
- Introduced a new "General Assist" shortcut, allowing users to open SurfSense from anywhere.
- Updated shortcut management to include the new general assist functionality in both the desktop and web applications.
- Enhanced the UI to reflect changes in shortcut labels and descriptions for better clarity.
- Improved the Electron API to support the new shortcut configuration.
2026-04-07 03:42:46 -07:00
DESKTOP-RTLN3BA$punk
e574b5ec4a refactor: remove prompt picker display on quick ask text retrieval
- Eliminated the automatic display of the prompt picker when quick ask text is retrieved from the Electron API, streamlining the user experience.
2026-04-07 03:17:10 -07:00
Rohan Verma
a05bb4ae0c
Merge pull request #1156 from MODSetter/dev_mod
refactor: streamlined desktop app
2026-04-07 03:10:45 -07:00
DESKTOP-RTLN3BA$punk
91ea293fa2 chore: linting 2026-04-07 03:10:06 -07:00
DESKTOP-RTLN3BA$punk
82b5c7f19e Merge commit '056fc0e7ff' into dev_mod 2026-04-07 02:56:46 -07:00
DESKTOP-RTLN3BA$punk
bb1dcd32b6 feat: enhance vision autocomplete service and UI feedback
- Optimized the vision autocomplete service by starting the SSE stream immediately and deriving KB search queries directly from window titles.
- Refactored the service to run KB filesystem pre-computation and agent graph compilation in parallel, improving performance.
- Updated the SuggestionPage component to handle new agent step data, displaying progress indicators for each step.
- Enhanced the CSS for the suggestion tooltip and agent activity indicators, improving the user interface and experience.
2026-04-07 02:49:24 -07:00
DESKTOP-RTLN3BA$punk
49441233e7 feat: enhance keyboard shortcut management and improve app responsiveness
- Updated the development script to include a build step before launching the app.
- Refactored the registration of quick ask and autocomplete functionalities to be asynchronous, ensuring proper initialization.
- Introduced IPC channels for getting and setting keyboard shortcuts, allowing users to customize their experience.
- Enhanced the platform module to support better interaction with the Electron API for clipboard operations.
- Improved the user interface for managing keyboard shortcuts in the settings dialog, providing a more intuitive experience.
2026-04-07 00:43:40 -07:00
DESKTOP-RTLN3BA$punk
e920923fa4 feat: implement auth token synchronization between Electron and web app
- Added IPC channels for getting and setting auth tokens in Electron.
- Implemented functions to sync tokens from localStorage to Electron and vice versa.
- Updated components to ensure tokens are retrieved from Electron when not available locally.
- Enhanced user authentication flow by integrating token management across windows.
2026-04-06 23:02:25 -07:00
Rohan Verma
056fc0e7ff
Merge pull request #1137 from AnishSarkar22/feat/unified-etl-pipeline
feat: Unified ETL pipeline
2026-04-06 20:44:35 -07:00
Anish Sarkar
8d810467dd refactor: add support for XHTML file conversion to markdown in document processors 2026-04-07 05:57:13 +05:30
Anish Sarkar
0a26a6c5bb chore: ran linting 2026-04-07 05:55:39 +05:30
Anish Sarkar
5803fe79da refactor: update filename handling in Google Drive connector to include Google Workspace file extensions, improving content extraction accuracy 2026-04-07 05:43:34 +05:30
Anish Sarkar
7f32dd068f refactor: update button rendering logic in connector views to improve loading state handling 2026-04-07 05:40:40 +05:30
Anish Sarkar
1b87719a92 refactor: enhance file skipping logic in Google Drive connector to check for Google Workspace files before unsupported extensions 2026-04-07 05:36:29 +05:30
Anish Sarkar
e4462292e4 refactor: update Google Drive indexer to return an additional unsupported file count, enhancing error reporting consistency 2026-04-07 05:30:10 +05:30
Anish Sarkar
aba5f6a124 refactor: improve file handling logic in Dropbox and OneDrive connectors to include unsupported file extension information 2026-04-07 05:19:23 +05:30
Anish Sarkar
a624c86b04 refactor: update file skipping logic in Dropbox, Google Drive, and OneDrive connectors to return unsupported extension information 2026-04-07 05:11:15 +05:30
Anish Sarkar
122be76133 refactor: update _index_selected_files method signatures in Dropbox, Google Drive, and OneDrive indexers to include unsupported file count, enhancing error reporting and consistency across connectors 2026-04-07 03:16:46 +05:30
Anish Sarkar
3a1d700817 refactor: enhance file skipping logic across Dropbox, Google Drive, and OneDrive connectors to return unsupported extensions, improving error reporting and maintainability 2026-04-07 03:16:34 +05:30
Anish Sarkar
e7beeb2a36 refactor: unify file skipping logic across Dropbox, Google Drive, and OneDrive connectors by replacing classification checks with a centralized service-based approach, enhancing maintainability and consistency in file handling 2026-04-07 02:19:31 +05:30
Anish Sarkar
f03bf05aaa refactor: enhance Google Drive indexer to support file extension filtering, improving file handling and error reporting 2026-04-06 22:34:49 +05:30
Anish Sarkar
0fb92b7c56 refactor: streamline file skipping logic in Dropbox indexer by removing redundant checks, improving code clarity 2026-04-06 22:17:50 +05:30
Anish Sarkar
63a75052ca Merge remote-tracking branch 'upstream/dev' into feat/unified-etl-pipeline 2026-04-06 22:04:51 +05:30
Anish Sarkar
dc7047f64d refactor: implement file type classification for supported extensions across Dropbox, Google Drive, and OneDrive connectors, enhancing file handling and error management 2026-04-06 22:03:47 +05:30
Anish Sarkar
47f4be08d9 refactor: remove allowed_formats from DocumentConverter initialization in DoclingService to allow acceptance of all supported formats 2026-04-06 19:31:42 +05:30
Anish Sarkar
caca491774 test: add unit tests for Dropbox integration, covering delta sync methods, file type filtering, and re-authentication behavior 2026-04-06 18:36:48 +05:30
Anish Sarkar
b5a15b7681 feat: implement cursor-based delta sync for Dropbox integration, enhancing file indexing efficiency and preserving folder cursors during re-authentication 2026-04-06 18:36:29 +05:30
Anish Sarkar
be622c417c refactor: update loading skeleton in PlateEditor and clean up dark mode styles in various components 2026-04-06 17:07:26 +05:30
Anish Sarkar
be7e73e615 refactor: enhance DocumentsFilters and FolderTreeView components for improved filter handling and search functionality 2026-04-06 14:41:53 +05:30
Anish Sarkar
3251f0e98d refactor: remove childCount prop from FolderNode and optimize FolderTreeView by eliminating unnecessary child count calculations 2026-04-06 13:56:28 +05:30
Anish Sarkar
8259fab254 refactor: update connector tabs to include search feedback and improve icon usage for better user experience 2026-04-06 13:27:49 +05:30
Anish Sarkar
02323e7b55 refactor: enhance DocumentsFilters component with ToggleGroup for folder creation and improve search functionality 2026-04-06 12:56:29 +05:30
Anish Sarkar
46c15c11da refactor: update layout and styling in DocumentUploadPopup for improved visual hierarchy and spacing 2026-04-06 12:29:55 +05:30
Anish Sarkar
742548847a refactor: optimize navigation items in LayoutDataProvider, enhance button layout in InboxSidebar with tooltip support, full width in PageUsageDisplay 2026-04-06 12:14:17 +05:30
Anish Sarkar
7fa1810d50 refactor: simplify CommentComposer button layout and update placeholder text in CommentItem 2026-04-05 23:14:54 +05:30
Anish Sarkar
c9e5fe9cdb refactor: update icon usage in CommentActions and enhance Tooltip component for mobile responsiveness 2026-04-05 23:02:17 +05:30
Anish Sarkar
1f162f52c3 feat: add tooltip functionality to DocumentNode for title overflow handling and refactor ChatShareButton by removing unnecessary Tooltip wrapper 2026-04-05 22:50:36 +05:30
Anish Sarkar
c6e94188eb refactor: remove destructive text classes from DocumentNode and enhance CreateSearchSpaceDialog with select-none and select-text classes 2026-04-05 18:23:32 +05:30
Anish Sarkar
f8913adaa3 test: add unit tests for content extraction from cloud connectors and ETL pipeline functionality 2026-04-05 17:46:04 +05:30
Anish Sarkar
87af012a60 refactor: streamline file processing by integrating ETL pipeline for all file types and removing redundant functions 2026-04-05 17:45:18 +05:30
Anish Sarkar
8224360afa refactor: unify file parsing logic across Dropbox, Google Drive, and OneDrive using the ETL pipeline 2026-04-05 17:30:29 +05:30
Anish Sarkar
1248363ca9 refactor: consolidate document processing logic and remove unused files and ETL strategies 2026-04-05 17:29:24 +05:30
Anish Sarkar
f40de6b695 feat: add parsers for Docling, LlamaCloud, and Unstructured to ETL pipeline 2026-04-05 17:27:24 +05:30
Anish Sarkar
2824410be2 feat: add plaintext parser to ETL pipeline for reading text files 2026-04-05 17:26:42 +05:30
Anish Sarkar
35582c9389 feat: add direct_convert module to ETL pipeline for file conversion 2026-04-05 17:26:29 +05:30
Anish Sarkar
02fc6f1d16 feat: add audio transcription functionality to ETL pipeline 2026-04-05 17:26:03 +05:30
Anish Sarkar
5d22349dc1 feat: implement ETL pipeline with file classification and extraction services 2026-04-05 17:25:25 +05:30
142 changed files with 6229 additions and 2838 deletions

View file

@ -0,0 +1,11 @@
"""Agent-based vision autocomplete with scoped filesystem exploration."""
from app.agents.autocomplete.autocomplete_agent import (
create_autocomplete_agent,
stream_autocomplete_agent,
)
__all__ = [
"create_autocomplete_agent",
"stream_autocomplete_agent",
]

View file

@ -0,0 +1,442 @@
"""Vision autocomplete agent with scoped filesystem exploration.
Converts the stateless single-shot vision autocomplete into an agent that
seeds a virtual filesystem from KB search results and lets the vision LLM
explore documents via ``ls``, ``read_file``, ``glob``, ``grep``, etc.
before generating the final completion.
Performance: KB search and agent graph compilation run in parallel so
the only sequential latency is KB-search (or agent compile, whichever is
slower) + the agent's LLM turns. There is no separate "query extraction"
LLM call; the window title is used directly as the KB search query.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any
from deepagents.graph import BASE_AGENT_PROMPT
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
from langchain.agents import create_agent
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, ToolMessage
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
from app.agents.new_chat.middleware.knowledge_search import (
build_scoped_filesystem,
search_knowledge_base,
)
from app.services.new_streaming_service import VercelStreamingService
logger = logging.getLogger(__name__)
KB_TOP_K = 10
# ---------------------------------------------------------------------------
# System prompt
# ---------------------------------------------------------------------------
AUTOCOMPLETE_SYSTEM_PROMPT = """You are a smart writing assistant that analyzes the user's screen to draft or complete text.
You will receive a screenshot of the user's screen. Your PRIMARY source of truth is the screenshot itself — the visual context determines what to write.
Your job:
1. Analyze the ENTIRE screenshot to understand what the user is working on (email thread, chat conversation, document, code editor, form, etc.).
2. Identify the text area where the user will type.
3. Generate the text the user most likely wants to write based on the visual context.
You also have access to the user's knowledge base documents via filesystem tools. However:
- ONLY consult the knowledge base if the screenshot clearly involves a topic where your KB documents are DIRECTLY relevant (e.g., the user is writing about a specific project/topic that matches a document title).
- Do NOT explore documents just because they exist. Most autocomplete requests can be answered purely from the screenshot.
- If you do read a document, only incorporate information that is 100% relevant to what the user is typing RIGHT NOW. Do not add extra details, background, or tangential information from the KB.
- Keep your output SHORT; autocomplete should feel like a natural continuation, not an essay.
Key behavior:
- If the text area is EMPTY, draft a concise response or message based on what you see on screen (e.g., reply to an email, respond to a chat message, continue a document).
- If the text area already has text, continue it naturally, typically just a sentence or two.
Rules:
- Output ONLY the text to be inserted. No quotes, no explanations, no meta-commentary.
- Be CONCISE. Prefer a single paragraph or a few sentences. Autocomplete is a quick assist, not a full draft.
- Match the tone and formality of the surrounding context.
- If the screen shows code, write code. If it shows a casual chat, be casual. If it shows a formal email, be formal.
- Do NOT describe the screenshot or explain your reasoning.
- Do NOT cite or reference documents explicitly; just let the knowledge inform your writing naturally.
- If you cannot determine what to write, output nothing.
## Filesystem Tools: `ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`
All file paths must start with a `/`.
- ls: list files and directories at a given path.
- read_file: read a file from the filesystem.
- write_file: create a temporary file in the session (not persisted).
- edit_file: edit a file in the session (not persisted for /documents/ files).
- glob: find files matching a pattern (e.g., "**/*.xml").
- grep: search for text within files.
## When to Use Filesystem Tools
BEFORE reaching for any tool, ask yourself: "Can I write a good completion purely from the screenshot?" If yes, just write it; do NOT explore the KB.
Only use tools when:
- The user is clearly writing about a specific topic that likely has detailed information in their KB.
- You need a specific fact, name, number, or reference that the screenshot doesn't provide.
When you do use tools, be surgical:
- Check the `ls` output first. If no document title looks relevant, stop; do not read files just to see what's there.
- If a title looks relevant, read only the `<chunk_index>` (first ~20 lines) and jump to matched chunks. Do not read entire documents.
- Extract only the specific information you need and move on to generating the completion.
## Reading Documents Efficiently
Documents are formatted as XML. Each document contains:
- `<document_metadata>`: title, type, URL, etc.
- `<chunk_index>`: a table of every chunk with its **line range** and a
`matched="true"` flag for chunks that matched the search query.
- `<document_content>`: the actual chunks in original document order.
**Workflow**: read the first ~20 lines to see the `<chunk_index>`, identify
chunks marked `matched="true"`, then use `read_file(path, offset=<start_line>,
limit=<lines>)` to jump directly to those sections."""
APP_CONTEXT_BLOCK = """
The user is currently working in "{app_name}" (window: "{window_title}"). Use this to understand the type of application and adapt your tone and format accordingly."""
def _build_autocomplete_system_prompt(app_name: str, window_title: str) -> str:
prompt = AUTOCOMPLETE_SYSTEM_PROMPT
if app_name:
prompt += APP_CONTEXT_BLOCK.format(app_name=app_name, window_title=window_title)
return prompt
# ---------------------------------------------------------------------------
# Pre-compute KB filesystem (runs in parallel with agent compilation)
# ---------------------------------------------------------------------------
class _KBResult:
"""Container for pre-computed KB filesystem results."""
__slots__ = ("files", "ls_ai_msg", "ls_tool_msg")
def __init__(
self,
files: dict[str, Any] | None = None,
ls_ai_msg: AIMessage | None = None,
ls_tool_msg: ToolMessage | None = None,
) -> None:
self.files = files
self.ls_ai_msg = ls_ai_msg
self.ls_tool_msg = ls_tool_msg
@property
def has_documents(self) -> bool:
return bool(self.files)
async def precompute_kb_filesystem(
search_space_id: int,
query: str,
top_k: int = KB_TOP_K,
) -> _KBResult:
"""Search the KB and build the scoped filesystem outside the agent.
This is designed to be called via ``asyncio.gather`` alongside agent
graph compilation so the two run concurrently.
"""
if not query:
return _KBResult()
try:
search_results = await search_knowledge_base(
query=query,
search_space_id=search_space_id,
top_k=top_k,
)
if not search_results:
return _KBResult()
new_files, _ = await build_scoped_filesystem(
documents=search_results,
search_space_id=search_space_id,
)
if not new_files:
return _KBResult()
doc_paths = [
p
for p, v in new_files.items()
if p.startswith("/documents/") and v is not None
]
tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}"
ai_msg = AIMessage(
content="",
tool_calls=[
{"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id}
],
)
tool_msg = ToolMessage(
content=str(doc_paths) if doc_paths else "No documents found.",
tool_call_id=tool_call_id,
)
return _KBResult(files=new_files, ls_ai_msg=ai_msg, ls_tool_msg=tool_msg)
except Exception:
logger.warning(
"KB pre-computation failed, proceeding without KB", exc_info=True
)
return _KBResult()
# ---------------------------------------------------------------------------
# Filesystem middleware — no save_document, no persistence
# ---------------------------------------------------------------------------
class AutocompleteFilesystemMiddleware(SurfSenseFilesystemMiddleware):
"""Filesystem middleware for autocomplete — read-only exploration only.
Strips ``save_document`` (permanent KB persistence) and passes
``search_space_id=None`` so ``write_file`` / ``edit_file`` stay ephemeral.
"""
def __init__(self) -> None:
super().__init__(search_space_id=None, created_by_id=None)
self.tools = [t for t in self.tools if t.name != "save_document"]
# ---------------------------------------------------------------------------
# Agent factory
# ---------------------------------------------------------------------------
async def _compile_agent(
llm: BaseChatModel,
app_name: str,
window_title: str,
) -> Any:
"""Compile the agent graph (CPU-bound, runs in a thread)."""
system_prompt = _build_autocomplete_system_prompt(app_name, window_title)
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
middleware = [
AutocompleteFilesystemMiddleware(),
PatchToolCallsMiddleware(),
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
agent = await asyncio.to_thread(
create_agent,
llm,
system_prompt=final_system_prompt,
tools=[],
middleware=middleware,
)
return agent.with_config({"recursion_limit": 200})
async def create_autocomplete_agent(
llm: BaseChatModel,
*,
search_space_id: int,
kb_query: str,
app_name: str = "",
window_title: str = "",
) -> tuple[Any, _KBResult]:
"""Create the autocomplete agent and pre-compute KB in parallel.
Returns ``(agent, kb_result)`` so the caller can inject the pre-computed
filesystem into the agent's initial state without any middleware delay.
"""
agent, kb = await asyncio.gather(
_compile_agent(llm, app_name, window_title),
precompute_kb_filesystem(search_space_id, kb_query),
)
return agent, kb
# ---------------------------------------------------------------------------
# Streaming helper
# ---------------------------------------------------------------------------
async def stream_autocomplete_agent(
agent: Any,
input_data: dict[str, Any],
streaming_service: VercelStreamingService,
*,
emit_message_start: bool = True,
) -> AsyncGenerator[str, None]:
"""Stream agent events as Vercel SSE, with thinking steps for tool calls.
When ``emit_message_start`` is False the caller has already sent the
``message_start`` event (e.g. to show preparation steps before the agent
runs).
"""
thread_id = uuid.uuid4().hex
config = {"configurable": {"thread_id": thread_id}}
current_text_id: str | None = None
active_tool_depth = 0
thinking_step_counter = 0
tool_step_ids: dict[str, str] = {}
step_titles: dict[str, str] = {}
completed_step_ids: set[str] = set()
last_active_step_id: str | None = None
def next_thinking_step_id() -> str:
nonlocal thinking_step_counter
thinking_step_counter += 1
return f"autocomplete-step-{thinking_step_counter}"
def complete_current_step() -> str | None:
nonlocal last_active_step_id
if last_active_step_id and last_active_step_id not in completed_step_ids:
completed_step_ids.add(last_active_step_id)
title = step_titles.get(last_active_step_id, "Done")
event = streaming_service.format_thinking_step(
step_id=last_active_step_id,
title=title,
status="complete",
)
last_active_step_id = None
return event
return None
if emit_message_start:
yield streaming_service.format_message_start()
# Emit an initial "Generating completion" step so the UI immediately
# shows activity once the agent starts its first LLM call.
gen_step_id = next_thinking_step_id()
last_active_step_id = gen_step_id
step_titles[gen_step_id] = "Generating completion"
yield streaming_service.format_thinking_step(
step_id=gen_step_id,
title="Generating completion",
status="in_progress",
)
try:
async for event in agent.astream_events(
input_data, config=config, version="v2"
):
event_type = event.get("event", "")
if event_type == "on_chat_model_stream":
if active_tool_depth > 0:
continue
if "surfsense:internal" in event.get("tags", []):
continue
chunk = event.get("data", {}).get("chunk")
if chunk and hasattr(chunk, "content"):
content = chunk.content
if content and isinstance(content, str):
if current_text_id is None:
step_event = complete_current_step()
if step_event:
yield step_event
current_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(current_text_id)
yield streaming_service.format_text_delta(
current_text_id, content
)
elif event_type == "on_tool_start":
active_tool_depth += 1
tool_name = event.get("name", "unknown_tool")
run_id = event.get("run_id", "")
tool_input = event.get("data", {}).get("input", {})
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
current_text_id = None
step_event = complete_current_step()
if step_event:
yield step_event
tool_step_id = next_thinking_step_id()
tool_step_ids[run_id] = tool_step_id
last_active_step_id = tool_step_id
title, items = _describe_tool_call(tool_name, tool_input)
step_titles[tool_step_id] = title
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title=title,
status="in_progress",
items=items,
)
elif event_type == "on_tool_end":
active_tool_depth = max(0, active_tool_depth - 1)
run_id = event.get("run_id", "")
step_id = tool_step_ids.pop(run_id, None)
if step_id and step_id not in completed_step_ids:
completed_step_ids.add(step_id)
title = step_titles.get(step_id, "Done")
yield streaming_service.format_thinking_step(
step_id=step_id,
title=title,
status="complete",
)
if last_active_step_id == step_id:
last_active_step_id = None
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
step_event = complete_current_step()
if step_event:
yield step_event
yield streaming_service.format_finish()
yield streaming_service.format_done()
except Exception as e:
logger.error(f"Autocomplete agent streaming error: {e}", exc_info=True)
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
yield streaming_service.format_error("Autocomplete failed. Please try again.")
yield streaming_service.format_done()
def _describe_tool_call(tool_name: str, tool_input: Any) -> tuple[str, list[str]]:
"""Return a human-readable (title, items) for a tool call thinking step."""
inp = tool_input if isinstance(tool_input, dict) else {}
if tool_name == "ls":
path = inp.get("path", "/")
return "Listing files", [path]
if tool_name == "read_file":
fp = inp.get("file_path", "")
display = fp if len(fp) <= 80 else "..." + fp[-77:]
return "Reading file", [display]
if tool_name == "write_file":
fp = inp.get("file_path", "")
display = fp if len(fp) <= 80 else "..." + fp[-77:]
return "Writing file", [display]
if tool_name == "edit_file":
fp = inp.get("file_path", "")
display = fp if len(fp) <= 80 else "..." + fp[-77:]
return "Editing file", [display]
if tool_name == "glob":
pat = inp.get("pattern", "")
base = inp.get("path", "/")
return "Searching files", [f"{pat} in {base}"]
if tool_name == "grep":
pat = inp.get("pattern", "")
path = inp.get("path", "")
display_pat = pat[:60] + ("..." if len(pat) > 60 else "")
return "Searching content", [
f'"{display_pat}"' + (f" in {path}" if path else "")
]
return f"Using {tool_name}", []

View file

@ -225,6 +225,55 @@ class DropboxClient:
return all_items, None
async def get_latest_cursor(self, path: str = "") -> tuple[str | None, str | None]:
"""Get a cursor representing the current state of a folder.
Uses /2/files/list_folder/get_latest_cursor so we can later call
get_changes to receive only incremental updates.
"""
resp = await self._request(
"/2/files/list_folder/get_latest_cursor",
{"path": path, "recursive": False, "include_non_downloadable_files": True},
)
if resp.status_code != 200:
return None, f"Failed to get cursor: {resp.status_code} - {resp.text}"
return resp.json().get("cursor"), None
async def get_changes(
self, cursor: str
) -> tuple[list[dict[str, Any]], str | None, str | None]:
"""Fetch incremental changes since the given cursor.
Calls /2/files/list_folder/continue and handles pagination.
Returns (entries, new_cursor, error).
"""
all_entries: list[dict[str, Any]] = []
resp = await self._request("/2/files/list_folder/continue", {"cursor": cursor})
if resp.status_code == 401:
return [], None, "Dropbox authentication expired (401)"
if resp.status_code != 200:
return [], None, f"Failed to get changes: {resp.status_code} - {resp.text}"
data = resp.json()
all_entries.extend(data.get("entries", []))
while data.get("has_more"):
cursor = data["cursor"]
resp = await self._request(
"/2/files/list_folder/continue", {"cursor": cursor}
)
if resp.status_code != 200:
return (
all_entries,
data.get("cursor"),
f"Pagination failed: {resp.status_code}",
)
data = resp.json()
all_entries.extend(data.get("entries", []))
return all_entries, data.get("cursor"), None
async def get_metadata(self, path: str) -> tuple[dict[str, Any] | None, str | None]:
resp = await self._request("/2/files/get_metadata", {"path": path})
if resp.status_code != 200:

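The two new methods enable cursor-based delta sync. A rough sketch of the intended flow follows (not part of the diff; the `connector_config` dict used for persistence here is a stand-in for the connector's real storage):

async def sync_dropbox_changes(client, connector_config: dict, path: str = ""):
    cursor = connector_config.get("cursor")
    if cursor is None:
        # First run: do a full listing elsewhere, then remember "now" for next time.
        cursor, err = await client.get_latest_cursor(path)
        if err:
            raise RuntimeError(err)
        connector_config["cursor"] = cursor
        return []

    # Later runs: fetch only entries added, changed, or deleted since the cursor.
    entries, new_cursor, err = await client.get_changes(cursor)
    if err:
        raise RuntimeError(err)
    if new_cursor:
        connector_config["cursor"] = new_cursor
    return entries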
View file

@ -53,7 +53,8 @@ async def download_and_extract_content(
file_name = file.get("name", "Unknown")
file_id = file.get("id", "")
if should_skip_file(file):
skip, _unsup_ext = should_skip_file(file)
if skip:
return None, {}, "Skipping non-indexable item"
logger.info(f"Downloading file for content extraction: {file_name}")
@ -87,9 +88,13 @@ async def download_and_extract_content(
if error:
return None, metadata, error
from app.connectors.onedrive.content_extractor import _parse_file_to_markdown
from app.etl_pipeline.etl_document import EtlRequest
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
markdown = await _parse_file_to_markdown(temp_file_path, file_name)
result = await EtlPipelineService().extract(
EtlRequest(file_path=temp_file_path, filename=file_name)
)
markdown = result.markdown_content
return markdown, metadata, None
except Exception as e:

View file

@ -1,8 +1,8 @@
"""File type handlers for Dropbox."""
PAPER_EXTENSION = ".paper"
from app.etl_pipeline.file_classifier import should_skip_for_service
SKIP_EXTENSIONS: frozenset[str] = frozenset()
PAPER_EXTENSION = ".paper"
MIME_TO_EXTENSION: dict[str, str] = {
"application/pdf": ".pdf",
@ -42,17 +42,25 @@ def is_paper_file(item: dict) -> bool:
return ext == PAPER_EXTENSION
def should_skip_file(item: dict) -> bool:
def should_skip_file(item: dict) -> tuple[bool, str | None]:
"""Skip folders and truly non-indexable files.
Paper docs are non-downloadable but exportable, so they are NOT skipped.
Returns (should_skip, unsupported_extension_or_None).
"""
if is_folder(item):
return True
return True, None
if is_paper_file(item):
return False
return False, None
if not item.get("is_downloadable", True):
return True
return True, None
from pathlib import PurePosixPath
from app.config import config as app_config
name = item.get("name", "")
ext = get_extension_from_name(name).lower()
return ext in SKIP_EXTENSIONS
if should_skip_for_service(name, app_config.ETL_SERVICE):
ext = PurePosixPath(name).suffix.lower()
return True, ext
return False, None

View file

@ -64,8 +64,10 @@ async def get_files_in_folder(
)
continue
files.extend(sub_files)
elif not should_skip_file(item):
files.append(item)
else:
skip, _unsup_ext = should_skip_file(item)
if not skip:
files.append(item)
return files, None

View file

@ -1,12 +1,9 @@
"""Content extraction for Google Drive files."""
import asyncio
import contextlib
import logging
import os
import tempfile
import threading
import time
from pathlib import Path
from typing import Any
@ -20,6 +17,7 @@ from .file_types import (
get_export_mime_type,
get_extension_from_mime,
is_google_workspace_file,
should_skip_by_extension,
should_skip_file,
)
@ -45,6 +43,11 @@ async def download_and_extract_content(
if should_skip_file(mime_type):
return None, {}, f"Skipping {mime_type}"
if not is_google_workspace_file(mime_type):
ext_skip, _unsup_ext = should_skip_by_extension(file_name)
if ext_skip:
return None, {}, f"Skipping unsupported extension: {file_name}"
logger.info(f"Downloading file for content extraction: {file_name} ({mime_type})")
drive_metadata: dict[str, Any] = {
@ -97,7 +100,10 @@ async def download_and_extract_content(
if error:
return None, drive_metadata, error
markdown = await _parse_file_to_markdown(temp_file_path, file_name)
etl_filename = (
file_name + extension if is_google_workspace_file(mime_type) else file_name
)
markdown = await _parse_file_to_markdown(temp_file_path, etl_filename)
return markdown, drive_metadata, None
except Exception as e:
@ -110,99 +116,14 @@ async def download_and_extract_content(
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
"""Parse a local file to markdown using the configured ETL service."""
lower = filename.lower()
"""Parse a local file to markdown using the unified ETL pipeline."""
from app.etl_pipeline.etl_document import EtlRequest
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
if lower.endswith((".md", ".markdown", ".txt")):
with open(file_path, encoding="utf-8") as f:
return f.read()
if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
from litellm import atranscription
from app.config import config as app_config
stt_service_type = (
"local"
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
else "external"
)
if stt_service_type == "local":
from app.services.stt_service import stt_service
t0 = time.monotonic()
logger.info(
f"[local-stt] START file={filename} thread={threading.current_thread().name}"
)
result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
logger.info(
f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
)
text = result.get("text", "")
else:
with open(file_path, "rb") as audio_file:
kwargs: dict[str, Any] = {
"model": app_config.STT_SERVICE,
"file": audio_file,
"api_key": app_config.STT_SERVICE_API_KEY,
}
if app_config.STT_SERVICE_API_BASE:
kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
resp = await atranscription(**kwargs)
text = resp.get("text", "")
if not text:
raise ValueError("Transcription returned empty text")
return f"# Transcription of {filename}\n\n{text}"
# Document files -- use configured ETL service
from app.config import config as app_config
if app_config.ETL_SERVICE == "UNSTRUCTURED":
from langchain_unstructured import UnstructuredLoader
from app.utils.document_converters import convert_document_to_markdown
loader = UnstructuredLoader(
file_path,
mode="elements",
post_processors=[],
languages=["eng"],
include_orig_elements=False,
include_metadata=False,
strategy="auto",
)
docs = await loader.aload()
return await convert_document_to_markdown(docs)
if app_config.ETL_SERVICE == "LLAMACLOUD":
from app.tasks.document_processors.file_processors import (
parse_with_llamacloud_retry,
)
result = await parse_with_llamacloud_retry(
file_path=file_path, estimated_pages=50
)
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
if not markdown_documents:
raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
return markdown_documents[0].text
if app_config.ETL_SERVICE == "DOCLING":
from docling.document_converter import DocumentConverter
converter = DocumentConverter()
t0 = time.monotonic()
logger.info(
f"[docling] START file={filename} thread={threading.current_thread().name}"
)
result = await asyncio.to_thread(converter.convert, file_path)
logger.info(
f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
)
return result.document.export_to_markdown()
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
result = await EtlPipelineService().extract(
EtlRequest(file_path=file_path, filename=filename)
)
return result.markdown_content
async def download_and_process_file(
@ -236,10 +157,14 @@ async def download_and_process_file(
file_name = file.get("name", "Unknown")
mime_type = file.get("mimeType", "")
# Skip folders and shortcuts
if should_skip_file(mime_type):
return None, f"Skipping {mime_type}", None
if not is_google_workspace_file(mime_type):
ext_skip, _unsup_ext = should_skip_by_extension(file_name)
if ext_skip:
return None, f"Skipping unsupported extension: {file_name}", None
logger.info(f"Downloading file: {file_name} ({mime_type})")
temp_file_path = None
@ -310,10 +235,13 @@ async def download_and_process_file(
"."
)[-1]
etl_filename = (
file_name + extension if is_google_workspace_file(mime_type) else file_name
)
logger.info(f"Processing {file_name} with Surfsense's file processor")
await process_file_in_background(
file_path=temp_file_path,
filename=file_name,
filename=etl_filename,
search_space_id=search_space_id,
user_id=user_id,
session=session,

View file

@ -1,5 +1,7 @@
"""File type handlers for Google Drive."""
from app.etl_pipeline.file_classifier import should_skip_for_service
GOOGLE_DOC = "application/vnd.google-apps.document"
GOOGLE_SHEET = "application/vnd.google-apps.spreadsheet"
GOOGLE_SLIDE = "application/vnd.google-apps.presentation"
@ -46,6 +48,21 @@ def should_skip_file(mime_type: str) -> bool:
return mime_type in [GOOGLE_FOLDER, GOOGLE_SHORTCUT]
def should_skip_by_extension(filename: str) -> tuple[bool, str | None]:
"""Check if the file extension is not parseable by the configured ETL service.
Returns (should_skip, unsupported_extension_or_None).
"""
from pathlib import PurePosixPath
from app.config import config as app_config
if should_skip_for_service(filename, app_config.ETL_SERVICE):
ext = PurePosixPath(filename).suffix.lower()
return True, ext
return False, None
def get_export_mime_type(mime_type: str) -> str | None:
"""Get export MIME type for Google Workspace files."""
return EXPORT_FORMATS.get(mime_type)

View file

@ -1,16 +1,9 @@
"""Content extraction for OneDrive files.
"""Content extraction for OneDrive files."""
Reuses the same ETL parsing logic as Google Drive since file parsing is
extension-based, not provider-specific.
"""
import asyncio
import contextlib
import logging
import os
import tempfile
import threading
import time
from pathlib import Path
from typing import Any
@ -31,7 +24,8 @@ async def download_and_extract_content(
item_id = file.get("id")
file_name = file.get("name", "Unknown")
if should_skip_file(file):
skip, _unsup_ext = should_skip_file(file)
if skip:
return None, {}, "Skipping non-indexable item"
file_info = file.get("file", {})
@ -84,98 +78,11 @@ async def download_and_extract_content(
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
"""Parse a local file to markdown using the configured ETL service.
"""Parse a local file to markdown using the unified ETL pipeline."""
from app.etl_pipeline.etl_document import EtlRequest
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
Same logic as Google Drive -- file parsing is extension-based.
"""
lower = filename.lower()
if lower.endswith((".md", ".markdown", ".txt")):
with open(file_path, encoding="utf-8") as f:
return f.read()
if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
from litellm import atranscription
from app.config import config as app_config
stt_service_type = (
"local"
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
else "external"
)
if stt_service_type == "local":
from app.services.stt_service import stt_service
t0 = time.monotonic()
logger.info(
f"[local-stt] START file={filename} thread={threading.current_thread().name}"
)
result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
logger.info(
f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
)
text = result.get("text", "")
else:
with open(file_path, "rb") as audio_file:
kwargs: dict[str, Any] = {
"model": app_config.STT_SERVICE,
"file": audio_file,
"api_key": app_config.STT_SERVICE_API_KEY,
}
if app_config.STT_SERVICE_API_BASE:
kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
resp = await atranscription(**kwargs)
text = resp.get("text", "")
if not text:
raise ValueError("Transcription returned empty text")
return f"# Transcription of {filename}\n\n{text}"
from app.config import config as app_config
if app_config.ETL_SERVICE == "UNSTRUCTURED":
from langchain_unstructured import UnstructuredLoader
from app.utils.document_converters import convert_document_to_markdown
loader = UnstructuredLoader(
file_path,
mode="elements",
post_processors=[],
languages=["eng"],
include_orig_elements=False,
include_metadata=False,
strategy="auto",
)
docs = await loader.aload()
return await convert_document_to_markdown(docs)
if app_config.ETL_SERVICE == "LLAMACLOUD":
from app.tasks.document_processors.file_processors import (
parse_with_llamacloud_retry,
)
result = await parse_with_llamacloud_retry(
file_path=file_path, estimated_pages=50
)
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
if not markdown_documents:
raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
return markdown_documents[0].text
if app_config.ETL_SERVICE == "DOCLING":
from docling.document_converter import DocumentConverter
converter = DocumentConverter()
t0 = time.monotonic()
logger.info(
f"[docling] START file={filename} thread={threading.current_thread().name}"
)
result = await asyncio.to_thread(converter.convert, file_path)
logger.info(
f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
)
return result.document.export_to_markdown()
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
result = await EtlPipelineService().extract(
EtlRequest(file_path=file_path, filename=filename)
)
return result.markdown_content

View file

@ -1,5 +1,7 @@
"""File type handlers for Microsoft OneDrive."""
from app.etl_pipeline.file_classifier import should_skip_for_service
ONEDRIVE_FOLDER_FACET = "folder"
ONENOTE_MIME = "application/msonenote"
@ -38,13 +40,28 @@ def is_folder(item: dict) -> bool:
return ONEDRIVE_FOLDER_FACET in item
def should_skip_file(item: dict) -> bool:
"""Skip folders, OneNote files, remote items (shared links), and packages."""
def should_skip_file(item: dict) -> tuple[bool, str | None]:
"""Skip folders, OneNote files, remote items, packages, and unsupported extensions.
Returns (should_skip, unsupported_extension_or_None).
The second element is only set when the skip is due to an unsupported extension.
"""
if is_folder(item):
return True
return True, None
if "remoteItem" in item:
return True
return True, None
if "package" in item:
return True
return True, None
mime = item.get("file", {}).get("mimeType", "")
return mime in SKIP_MIME_TYPES
if mime in SKIP_MIME_TYPES:
return True, None
from pathlib import PurePosixPath
from app.config import config as app_config
name = item.get("name", "")
if should_skip_for_service(name, app_config.ETL_SERVICE):
ext = PurePosixPath(name).suffix.lower()
return True, ext
return False, None

View file

@ -71,8 +71,10 @@ async def get_files_in_folder(
)
continue
files.extend(sub_files)
elif not should_skip_file(item):
files.append(item)
else:
skip, _unsup_ext = should_skip_file(item)
if not skip:
files.append(item)
return files, None

View file

@ -0,0 +1,39 @@
import ssl
import httpx
LLAMACLOUD_MAX_RETRIES = 5
LLAMACLOUD_BASE_DELAY = 10
LLAMACLOUD_MAX_DELAY = 120
LLAMACLOUD_RETRYABLE_EXCEPTIONS = (
ssl.SSLError,
httpx.ConnectError,
httpx.ConnectTimeout,
httpx.ReadError,
httpx.ReadTimeout,
httpx.WriteError,
httpx.WriteTimeout,
httpx.RemoteProtocolError,
httpx.LocalProtocolError,
ConnectionError,
ConnectionResetError,
TimeoutError,
OSError,
)
UPLOAD_BYTES_PER_SECOND_SLOW = 100 * 1024
MIN_UPLOAD_TIMEOUT = 120
MAX_UPLOAD_TIMEOUT = 1800
BASE_JOB_TIMEOUT = 600
PER_PAGE_JOB_TIMEOUT = 60
def calculate_upload_timeout(file_size_bytes: int) -> float:
estimated_time = (file_size_bytes / UPLOAD_BYTES_PER_SECOND_SLOW) * 1.5
return max(MIN_UPLOAD_TIMEOUT, min(estimated_time, MAX_UPLOAD_TIMEOUT))
def calculate_job_timeout(estimated_pages: int, file_size_bytes: int) -> float:
page_based_timeout = BASE_JOB_TIMEOUT + (estimated_pages * PER_PAGE_JOB_TIMEOUT)
size_based_timeout = BASE_JOB_TIMEOUT + (file_size_bytes / (10 * 1024 * 1024)) * 60
return max(page_based_timeout, size_based_timeout)
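
A quick worked example of the timeout arithmetic above (not part of the diff), assuming a 50 MB file estimated at 20 pages:

# Upload: (50 MiB / 100 KiB per second) * 1.5 = 768 s, clamped to [120, 1800] -> 768 s
calculate_upload_timeout(50 * 1024 * 1024)   # 768.0

# Job: max(page-based, size-based) = max(600 + 20*60, 600 + 5*60) = max(1800, 900) -> 1800 s
calculate_job_timeout(20, 50 * 1024 * 1024)  # 1800.0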

View file

@ -0,0 +1,21 @@
from pydantic import BaseModel, field_validator
class EtlRequest(BaseModel):
file_path: str
filename: str
estimated_pages: int = 0
@field_validator("filename")
@classmethod
def filename_must_not_be_empty(cls, v: str) -> str:
if not v.strip():
raise ValueError("filename must not be empty")
return v
class EtlResult(BaseModel):
markdown_content: str
etl_service: str
actual_pages: int = 0
content_type: str

View file

@ -0,0 +1,90 @@
from app.config import config as app_config
from app.etl_pipeline.etl_document import EtlRequest, EtlResult
from app.etl_pipeline.exceptions import (
EtlServiceUnavailableError,
EtlUnsupportedFileError,
)
from app.etl_pipeline.file_classifier import FileCategory, classify_file
from app.etl_pipeline.parsers.audio import transcribe_audio
from app.etl_pipeline.parsers.direct_convert import convert_file_directly
from app.etl_pipeline.parsers.plaintext import read_plaintext
class EtlPipelineService:
"""Single pipeline for extracting markdown from files. All callers use this."""
async def extract(self, request: EtlRequest) -> EtlResult:
category = classify_file(request.filename)
if category == FileCategory.UNSUPPORTED:
raise EtlUnsupportedFileError(
f"File type not supported for parsing: {request.filename}"
)
if category == FileCategory.PLAINTEXT:
content = read_plaintext(request.file_path)
return EtlResult(
markdown_content=content,
etl_service="PLAINTEXT",
content_type="plaintext",
)
if category == FileCategory.DIRECT_CONVERT:
content = convert_file_directly(request.file_path, request.filename)
return EtlResult(
markdown_content=content,
etl_service="DIRECT_CONVERT",
content_type="direct_convert",
)
if category == FileCategory.AUDIO:
content = await transcribe_audio(request.file_path, request.filename)
return EtlResult(
markdown_content=content,
etl_service="AUDIO",
content_type="audio",
)
return await self._extract_document(request)
async def _extract_document(self, request: EtlRequest) -> EtlResult:
from pathlib import PurePosixPath
from app.utils.file_extensions import get_document_extensions_for_service
etl_service = app_config.ETL_SERVICE
if not etl_service:
raise EtlServiceUnavailableError(
"No ETL_SERVICE configured. "
"Set ETL_SERVICE to UNSTRUCTURED, LLAMACLOUD, or DOCLING in your .env"
)
ext = PurePosixPath(request.filename).suffix.lower()
supported = get_document_extensions_for_service(etl_service)
if ext not in supported:
raise EtlUnsupportedFileError(
f"File type {ext} is not supported by {etl_service}"
)
if etl_service == "DOCLING":
from app.etl_pipeline.parsers.docling import parse_with_docling
content = await parse_with_docling(request.file_path, request.filename)
elif etl_service == "UNSTRUCTURED":
from app.etl_pipeline.parsers.unstructured import parse_with_unstructured
content = await parse_with_unstructured(request.file_path)
elif etl_service == "LLAMACLOUD":
from app.etl_pipeline.parsers.llamacloud import parse_with_llamacloud
content = await parse_with_llamacloud(
request.file_path, request.estimated_pages
)
else:
raise EtlServiceUnavailableError(f"Unknown ETL_SERVICE: {etl_service}")
return EtlResult(
markdown_content=content,
etl_service=etl_service,
content_type="document",
)
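
The connector changes above all funnel through this one entry point. A hedged usage sketch, including the error path the connectors normally avoid by calling `should_skip_for_service` first (paths and names are placeholders, not part of the diff):

from app.etl_pipeline.etl_document import EtlRequest
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
from app.etl_pipeline.exceptions import EtlUnsupportedFileError

async def extract_markdown(temp_path: str, original_name: str) -> str | None:
    try:
        result = await EtlPipelineService().extract(
            EtlRequest(file_path=temp_path, filename=original_name)
        )
    except EtlUnsupportedFileError:
        # No parser (plaintext, audio, direct-convert, or the configured
        # ETL_SERVICE) accepts this extension; callers record it as unsupported.
        return None
    return result.markdown_content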

View file

@ -0,0 +1,10 @@
class EtlParseError(Exception):
"""Raised when an ETL parser fails to produce content."""
class EtlServiceUnavailableError(Exception):
"""Raised when the configured ETL_SERVICE is not recognised."""
class EtlUnsupportedFileError(Exception):
"""Raised when a file type cannot be parsed by any ETL pipeline."""

View file

@ -0,0 +1,137 @@
from enum import Enum
from pathlib import PurePosixPath
from app.utils.file_extensions import (
DOCUMENT_EXTENSIONS,
get_document_extensions_for_service,
)
PLAINTEXT_EXTENSIONS = frozenset(
{
".md",
".markdown",
".txt",
".text",
".json",
".jsonl",
".yaml",
".yml",
".toml",
".ini",
".cfg",
".conf",
".xml",
".css",
".scss",
".less",
".sass",
".py",
".pyw",
".pyi",
".pyx",
".js",
".jsx",
".ts",
".tsx",
".mjs",
".cjs",
".java",
".kt",
".kts",
".scala",
".groovy",
".c",
".h",
".cpp",
".cxx",
".cc",
".hpp",
".hxx",
".cs",
".fs",
".fsx",
".go",
".rs",
".rb",
".php",
".pl",
".pm",
".lua",
".swift",
".m",
".mm",
".r",
".jl",
".sh",
".bash",
".zsh",
".fish",
".bat",
".cmd",
".ps1",
".sql",
".graphql",
".gql",
".env",
".gitignore",
".dockerignore",
".editorconfig",
".makefile",
".cmake",
".log",
".rst",
".tex",
".bib",
".org",
".adoc",
".asciidoc",
".vue",
".svelte",
".astro",
".tf",
".hcl",
".proto",
}
)
AUDIO_EXTENSIONS = frozenset(
{".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm"}
)
DIRECT_CONVERT_EXTENSIONS = frozenset({".csv", ".tsv", ".html", ".htm", ".xhtml"})
class FileCategory(Enum):
PLAINTEXT = "plaintext"
AUDIO = "audio"
DIRECT_CONVERT = "direct_convert"
UNSUPPORTED = "unsupported"
DOCUMENT = "document"
def classify_file(filename: str) -> FileCategory:
suffix = PurePosixPath(filename).suffix.lower()
if suffix in PLAINTEXT_EXTENSIONS:
return FileCategory.PLAINTEXT
if suffix in AUDIO_EXTENSIONS:
return FileCategory.AUDIO
if suffix in DIRECT_CONVERT_EXTENSIONS:
return FileCategory.DIRECT_CONVERT
if suffix in DOCUMENT_EXTENSIONS:
return FileCategory.DOCUMENT
return FileCategory.UNSUPPORTED
def should_skip_for_service(filename: str, etl_service: str | None) -> bool:
"""Return True if *filename* cannot be processed by *etl_service*.
Plaintext, audio, and direct-convert files are parser-agnostic and never
skipped. Document files are checked against the per-parser extension set.
"""
category = classify_file(filename)
if category == FileCategory.UNSUPPORTED:
return True
if category == FileCategory.DOCUMENT:
suffix = PurePosixPath(filename).suffix.lower()
return suffix not in get_document_extensions_for_service(etl_service)
return False
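
A few illustrative calls showing how the tiers above behave (not part of the diff; `.zip` being unsupported is an assumption about `DOCUMENT_EXTENSIONS`, which is defined elsewhere):

# Classification is purely extension-based; these follow from the sets above.
classify_file("notes.md")      # FileCategory.PLAINTEXT
classify_file("standup.m4a")   # FileCategory.AUDIO
classify_file("report.html")   # FileCategory.DIRECT_CONVERT
classify_file("archive.zip")   # FileCategory.UNSUPPORTED (assumed not a DOCUMENT extension)

# Plaintext, audio, and direct-convert files are never skipped, whatever the ETL service.
should_skip_for_service("notes.md", "DOCLING")   # False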

View file

@ -0,0 +1,34 @@
from litellm import atranscription
from app.config import config as app_config
async def transcribe_audio(file_path: str, filename: str) -> str:
stt_service_type = (
"local"
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
else "external"
)
if stt_service_type == "local":
from app.services.stt_service import stt_service
result = stt_service.transcribe_file(file_path)
text = result.get("text", "")
if not text:
raise ValueError("Transcription returned empty text")
else:
with open(file_path, "rb") as audio_file:
kwargs: dict = {
"model": app_config.STT_SERVICE,
"file": audio_file,
"api_key": app_config.STT_SERVICE_API_KEY,
}
if app_config.STT_SERVICE_API_BASE:
kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
response = await atranscription(**kwargs)
text = response.get("text", "")
if not text:
raise ValueError("Transcription returned empty text")
return f"# Transcription of {filename}\n\n{text}"

View file

@ -0,0 +1,3 @@
from app.tasks.document_processors._direct_converters import convert_file_directly
__all__ = ["convert_file_directly"]

View file

@ -0,0 +1,26 @@
import warnings
from logging import ERROR, getLogger
async def parse_with_docling(file_path: str, filename: str) -> str:
from app.services.docling_service import create_docling_service
docling_service = create_docling_service()
pdfminer_logger = getLogger("pdfminer")
original_level = pdfminer_logger.level
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, module="pdfminer")
warnings.filterwarnings(
"ignore", message=".*Cannot set gray non-stroke color.*"
)
warnings.filterwarnings("ignore", message=".*invalid float value.*")
pdfminer_logger.setLevel(ERROR)
try:
result = await docling_service.process_document(file_path, filename)
finally:
pdfminer_logger.setLevel(original_level)
return result["content"]

View file

@ -0,0 +1,123 @@
import asyncio
import logging
import os
import random
import httpx
from app.config import config as app_config
from app.etl_pipeline.constants import (
LLAMACLOUD_BASE_DELAY,
LLAMACLOUD_MAX_DELAY,
LLAMACLOUD_MAX_RETRIES,
LLAMACLOUD_RETRYABLE_EXCEPTIONS,
PER_PAGE_JOB_TIMEOUT,
calculate_job_timeout,
calculate_upload_timeout,
)
async def parse_with_llamacloud(file_path: str, estimated_pages: int) -> str:
from llama_cloud_services import LlamaParse
from llama_cloud_services.parse.utils import ResultType
file_size_bytes = os.path.getsize(file_path)
file_size_mb = file_size_bytes / (1024 * 1024)
upload_timeout = calculate_upload_timeout(file_size_bytes)
job_timeout = calculate_job_timeout(estimated_pages, file_size_bytes)
custom_timeout = httpx.Timeout(
connect=120.0,
read=upload_timeout,
write=upload_timeout,
pool=120.0,
)
logging.info(
f"LlamaCloud upload configured: file_size={file_size_mb:.1f}MB, "
f"pages={estimated_pages}, upload_timeout={upload_timeout:.0f}s, "
f"job_timeout={job_timeout:.0f}s"
)
last_exception = None
attempt_errors: list[str] = []
for attempt in range(1, LLAMACLOUD_MAX_RETRIES + 1):
try:
async with httpx.AsyncClient(timeout=custom_timeout) as custom_client:
parser = LlamaParse(
api_key=app_config.LLAMA_CLOUD_API_KEY,
num_workers=1,
verbose=True,
language="en",
result_type=ResultType.MD,
max_timeout=int(max(2000, job_timeout + upload_timeout)),
job_timeout_in_seconds=job_timeout,
job_timeout_extra_time_per_page_in_seconds=PER_PAGE_JOB_TIMEOUT,
custom_client=custom_client,
)
result = await parser.aparse(file_path)
if attempt > 1:
logging.info(
f"LlamaCloud upload succeeded on attempt {attempt} after "
f"{len(attempt_errors)} failures"
)
if hasattr(result, "get_markdown_documents"):
markdown_docs = result.get_markdown_documents(split_by_page=False)
if markdown_docs and hasattr(markdown_docs[0], "text"):
return markdown_docs[0].text
if hasattr(result, "pages") and result.pages:
return "\n\n".join(
p.md for p in result.pages if hasattr(p, "md") and p.md
)
return str(result)
if isinstance(result, list):
if result and hasattr(result[0], "text"):
return result[0].text
return "\n\n".join(
doc.page_content if hasattr(doc, "page_content") else str(doc)
for doc in result
)
return str(result)
except LLAMACLOUD_RETRYABLE_EXCEPTIONS as e:
last_exception = e
error_type = type(e).__name__
error_msg = str(e)[:200]
attempt_errors.append(f"Attempt {attempt}: {error_type} - {error_msg}")
if attempt < LLAMACLOUD_MAX_RETRIES:
base_delay = min(
LLAMACLOUD_BASE_DELAY * (2 ** (attempt - 1)),
LLAMACLOUD_MAX_DELAY,
)
jitter = base_delay * 0.25 * (2 * random.random() - 1)
delay = base_delay + jitter
logging.warning(
f"LlamaCloud upload failed "
f"(attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}): "
f"{error_type}. File: {file_size_mb:.1f}MB. "
f"Retrying in {delay:.0f}s..."
)
await asyncio.sleep(delay)
else:
logging.error(
f"LlamaCloud upload failed after {LLAMACLOUD_MAX_RETRIES} "
f"attempts. File size: {file_size_mb:.1f}MB, "
f"Pages: {estimated_pages}. "
f"Errors: {'; '.join(attempt_errors)}"
)
except Exception:
raise
raise last_exception or RuntimeError(
f"LlamaCloud parsing failed after {LLAMACLOUD_MAX_RETRIES} retries. "
f"File size: {file_size_mb:.1f}MB"
)

View file

@ -0,0 +1,8 @@
def read_plaintext(file_path: str) -> str:
with open(file_path, encoding="utf-8", errors="replace") as f:
content = f.read()
if "\x00" in content:
raise ValueError(
f"File contains null bytes — likely a binary file opened as text: {file_path}"
)
return content

View file

@ -0,0 +1,14 @@
async def parse_with_unstructured(file_path: str) -> str:
from langchain_unstructured import UnstructuredLoader
loader = UnstructuredLoader(
file_path,
mode="elements",
post_processors=[],
languages=["eng"],
include_orig_elements=False,
include_metadata=False,
strategy="auto",
)
docs = await loader.aload()
return "\n\n".join(doc.page_content for doc in docs if doc.page_content)

View file

@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
@ -31,8 +31,11 @@ async def vision_autocomplete_stream(
return StreamingResponse(
stream_vision_autocomplete(
body.screenshot, body.search_space_id, session,
app_name=body.app_name, window_title=body.window_title,
body.screenshot,
body.search_space_id,
session,
app_name=body.app_name,
window_title=body.window_title,
),
media_type="text/event-stream",
headers={

View file

@ -311,9 +311,11 @@ async def dropbox_callback(
)
existing_cursor = db_connector.config.get("cursor")
existing_folder_cursors = db_connector.config.get("folder_cursors")
db_connector.config = {
**connector_config,
"cursor": existing_cursor,
"folder_cursors": existing_folder_cursors,
"auth_expired": False,
}
flag_modified(db_connector, "config")

View file

@ -2477,6 +2477,8 @@ async def run_google_drive_indexing(
stage="fetching",
)
total_unsupported = 0
# Index each folder with indexing options
for folder in items.folders:
try:
@ -2484,6 +2486,7 @@ async def run_google_drive_indexing(
indexed_count,
skipped_count,
error_message,
unsupported_count,
) = await index_google_drive_files(
session,
connector_id,
@ -2497,6 +2500,7 @@ async def run_google_drive_indexing(
include_subfolders=indexing_options.include_subfolders,
)
total_skipped += skipped_count
total_unsupported += unsupported_count
if error_message:
errors.append(f"Folder '{folder.name}': {error_message}")
else:
@ -2572,6 +2576,7 @@ async def run_google_drive_indexing(
indexed_count=total_indexed,
error_message=error_message,
skipped_count=total_skipped,
unsupported_count=total_unsupported,
)
except Exception as e:
@ -2642,7 +2647,12 @@ async def run_onedrive_indexing(
stage="fetching",
)
total_indexed, total_skipped, error_message = await index_onedrive_files(
(
total_indexed,
total_skipped,
error_message,
total_unsupported,
) = await index_onedrive_files(
session,
connector_id,
search_space_id,
@ -2683,6 +2693,7 @@ async def run_onedrive_indexing(
indexed_count=total_indexed,
error_message=error_message,
skipped_count=total_skipped,
unsupported_count=total_unsupported,
)
except Exception as e:
@ -2750,7 +2761,12 @@ async def run_dropbox_indexing(
stage="fetching",
)
total_indexed, total_skipped, error_message = await index_dropbox_files(
(
total_indexed,
total_skipped,
error_message,
total_unsupported,
) = await index_dropbox_files(
session,
connector_id,
search_space_id,
@ -2791,6 +2807,7 @@ async def run_dropbox_indexing(
indexed_count=total_indexed,
error_message=error_message,
skipped_count=total_skipped,
unsupported_count=total_unsupported,
)
except Exception as e:

View file

@ -111,9 +111,8 @@ class DoclingService:
pipeline_options=pipeline_options, backend=PyPdfiumDocumentBackend
)
# Initialize DocumentConverter
self.converter = DocumentConverter(
format_options={InputFormat.PDF: pdf_format_option}
format_options={InputFormat.PDF: pdf_format_option},
)
acceleration_type = "GPU (WSL2)" if self.use_gpu else "CPU"

View file

@ -421,6 +421,7 @@ class ConnectorIndexingNotificationHandler(BaseNotificationHandler):
error_message: str | None = None,
is_warning: bool = False,
skipped_count: int | None = None,
unsupported_count: int | None = None,
) -> Notification:
"""
Update notification when connector indexing completes.
@ -428,10 +429,11 @@ class ConnectorIndexingNotificationHandler(BaseNotificationHandler):
Args:
session: Database session
notification: Notification to update
indexed_count: Total number of items indexed
indexed_count: Total number of files indexed
error_message: Error message if indexing failed, or warning message (optional)
is_warning: If True, treat error_message as a warning (success case) rather than an error
skipped_count: Number of items skipped (e.g., duplicates) - optional
skipped_count: Number of files skipped (e.g., unchanged) - optional
unsupported_count: Number of files skipped because the ETL parser doesn't support them
Returns:
Updated notification
@ -440,52 +442,45 @@ class ConnectorIndexingNotificationHandler(BaseNotificationHandler):
"connector_name", "Connector"
)
# Build the skipped text if there are skipped items
skipped_text = ""
if skipped_count and skipped_count > 0:
skipped_item_text = "item" if skipped_count == 1 else "items"
skipped_text = (
f" ({skipped_count} {skipped_item_text} skipped - already indexed)"
)
unsupported_text = ""
if unsupported_count and unsupported_count > 0:
file_word = "file was" if unsupported_count == 1 else "files were"
unsupported_text = f" {unsupported_count} {file_word} not supported."
# If there's an error message but items were indexed, treat it as a warning (partial success)
# If is_warning is True, treat it as success even with 0 items (e.g., duplicates found)
# Otherwise, treat it as a failure
if error_message:
if indexed_count > 0:
# Partial success with warnings (e.g., duplicate content from other connectors)
title = f"Ready: {connector_name}"
item_text = "item" if indexed_count == 1 else "items"
message = f"Now searchable! {indexed_count} {item_text} synced{skipped_text}. Note: {error_message}"
file_text = "file" if indexed_count == 1 else "files"
message = f"Now searchable! {indexed_count} {file_text} synced.{unsupported_text} Note: {error_message}"
status = "completed"
elif is_warning:
# Warning case (e.g., duplicates found) - treat as success
title = f"Ready: {connector_name}"
message = f"Sync completed{skipped_text}. {error_message}"
message = f"Sync complete.{unsupported_text} {error_message}"
status = "completed"
else:
# Complete failure
title = f"Failed: {connector_name}"
message = f"Sync failed: {error_message}"
if unsupported_text:
message += unsupported_text
status = "failed"
else:
title = f"Ready: {connector_name}"
if indexed_count == 0:
if skipped_count and skipped_count > 0:
skipped_item_text = "item" if skipped_count == 1 else "items"
message = f"Already up to date! {skipped_count} {skipped_item_text} skipped (already indexed)."
if unsupported_count and unsupported_count > 0:
message = f"Sync complete.{unsupported_text}"
else:
message = "Already up to date! No new items to sync."
message = "Already up to date!"
else:
item_text = "item" if indexed_count == 1 else "items"
message = (
f"Now searchable! {indexed_count} {item_text} synced{skipped_text}."
)
file_text = "file" if indexed_count == 1 else "files"
message = f"Now searchable! {indexed_count} {file_text} synced."
if unsupported_text:
message += unsupported_text
status = "completed"
metadata_updates = {
"indexed_count": indexed_count,
"skipped_count": skipped_count or 0,
"unsupported_count": unsupported_count or 0,
"sync_stage": "completed"
if (not error_message or is_warning or indexed_count > 0)
else "failed",

View file

@ -1,139 +1,40 @@
import logging
from typing import AsyncGenerator
"""Vision autocomplete service — agent-based with scoped filesystem.
from langchain_core.messages import HumanMessage, SystemMessage
Optimized pipeline:
1. Start the SSE stream immediately so the UI shows progress.
2. Derive a KB search query from window_title (no separate LLM call).
3. Run KB filesystem pre-computation and agent graph compilation in PARALLEL.
4. Inject pre-computed KB files as initial state and stream the agent.
"""
import logging
from collections.abc import AsyncGenerator
from langchain_core.messages import HumanMessage
from sqlalchemy.ext.asyncio import AsyncSession
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.agents.autocomplete import create_autocomplete_agent, stream_autocomplete_agent
from app.services.llm_service import get_vision_llm
from app.services.new_streaming_service import VercelStreamingService
logger = logging.getLogger(__name__)
KB_TOP_K = 5
KB_MAX_CHARS = 4000
EXTRACT_QUERY_PROMPT = """Look at this screenshot and describe in 1-2 short sentences what the user is working on and what topic they need to write about. Be specific about the subject matter. Output ONLY the description, nothing else."""
EXTRACT_QUERY_PROMPT_WITH_APP = """The user is currently in the application "{app_name}" with the window titled "{window_title}".
Look at this screenshot and describe in 1-2 short sentences what the user is working on and what topic they need to write about. Be specific about the subject matter. Output ONLY the description, nothing else."""
VISION_SYSTEM_PROMPT = """You are a smart writing assistant that analyzes the user's screen to draft or complete text.
You will receive a screenshot of the user's screen. Your job:
1. Analyze the ENTIRE screenshot to understand what the user is working on (email thread, chat conversation, document, code editor, form, etc.).
2. Identify the text area where the user will type.
3. Based on the full visual context, generate the text the user most likely wants to write.
Key behavior:
- If the text area is EMPTY, draft a full response or message based on what you see on screen (e.g., reply to an email, respond to a chat message, continue a document).
- If the text area already has text, continue it naturally.
Rules:
- Output ONLY the text to be inserted. No quotes, no explanations, no meta-commentary.
- Be concise but complete a full thought, not a fragment.
- Match the tone and formality of the surrounding context.
- If the screen shows code, write code. If it shows a casual chat, be casual. If it shows a formal email, be formal.
- Do NOT describe the screenshot or explain your reasoning.
- If you cannot determine what to write, output nothing."""
APP_CONTEXT_BLOCK = """
The user is currently working in "{app_name}" (window: "{window_title}"). Use this to understand the type of application and adapt your tone and format accordingly."""
KB_CONTEXT_BLOCK = """
You also have access to the user's knowledge base documents below. Use them to write more accurate, informed, and contextually relevant text. Do NOT cite or reference the documents explicitly — just let the knowledge inform your writing naturally.
<knowledge_base>
{kb_context}
</knowledge_base>"""
PREP_STEP_ID = "autocomplete-prep"
def _build_system_prompt(app_name: str, window_title: str, kb_context: str) -> str:
"""Assemble the system prompt from optional context blocks."""
prompt = VISION_SYSTEM_PROMPT
if app_name:
prompt += APP_CONTEXT_BLOCK.format(app_name=app_name, window_title=window_title)
if kb_context:
prompt += KB_CONTEXT_BLOCK.format(kb_context=kb_context)
return prompt
def _derive_kb_query(app_name: str, window_title: str) -> str:
parts = [p for p in (window_title, app_name) if p]
return " ".join(parts)
def _is_vision_unsupported_error(e: Exception) -> bool:
"""Check if an exception indicates the model doesn't support vision/images."""
msg = str(e).lower()
return "content must be a string" in msg or "does not support image" in msg
async def _extract_query_from_screenshot(
llm, screenshot_data_url: str,
app_name: str = "", window_title: str = "",
) -> str | None:
"""Ask the Vision LLM to describe what the user is working on.
Raises vision-unsupported errors so the caller can return a
friendly message immediately instead of retrying with astream.
"""
if app_name:
prompt_text = EXTRACT_QUERY_PROMPT_WITH_APP.format(
app_name=app_name, window_title=window_title,
)
else:
prompt_text = EXTRACT_QUERY_PROMPT
try:
response = await llm.ainvoke([
HumanMessage(content=[
{"type": "text", "text": prompt_text},
{"type": "image_url", "image_url": {"url": screenshot_data_url}},
]),
])
query = response.content.strip() if hasattr(response, "content") else ""
return query if query else None
except Exception as e:
if _is_vision_unsupported_error(e):
raise
logger.warning(f"Failed to extract query from screenshot: {e}")
return None
async def _search_knowledge_base(
session: AsyncSession, search_space_id: int, query: str
) -> str:
"""Search the KB and return formatted context string."""
try:
retriever = ChucksHybridSearchRetriever(session)
results = await retriever.hybrid_search(
query_text=query,
top_k=KB_TOP_K,
search_space_id=search_space_id,
)
if not results:
return ""
parts: list[str] = []
char_count = 0
for doc in results:
title = doc.get("document", {}).get("title", "Untitled")
for chunk in doc.get("chunks", []):
content = chunk.get("content", "").strip()
if not content:
continue
entry = f"[{title}]\n{content}"
if char_count + len(entry) > KB_MAX_CHARS:
break
parts.append(entry)
char_count += len(entry)
if char_count >= KB_MAX_CHARS:
break
return "\n\n---\n\n".join(parts)
except Exception as e:
logger.warning(f"KB search failed, proceeding without context: {e}")
return ""
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
async def stream_vision_autocomplete(
@ -144,13 +45,7 @@ async def stream_vision_autocomplete(
app_name: str = "",
window_title: str = "",
) -> AsyncGenerator[str, None]:
"""Analyze a screenshot with the vision LLM and stream a text completion.
Pipeline:
1. Extract a search query from the screenshot (non-streaming)
2. Search the knowledge base for relevant context
3. Stream the final completion with screenshot + KB + app context
"""
"""Analyze a screenshot with a vision-LLM agent and stream a text completion."""
streaming = VercelStreamingService()
vision_error_msg = (
"The selected model does not support vision. "
@ -164,62 +59,100 @@ async def stream_vision_autocomplete(
yield streaming.format_done()
return
kb_context = ""
# Start SSE stream immediately so the UI has something to show
yield streaming.format_message_start()
kb_query = _derive_kb_query(app_name, window_title)
# Show a preparation step while KB search + agent compile run
yield streaming.format_thinking_step(
step_id=PREP_STEP_ID,
title="Searching knowledge base",
status="in_progress",
items=[kb_query] if kb_query else [],
)
try:
query = await _extract_query_from_screenshot(
llm, screenshot_data_url, app_name=app_name, window_title=window_title,
agent, kb = await create_autocomplete_agent(
llm,
search_space_id=search_space_id,
kb_query=kb_query,
app_name=app_name,
window_title=window_title,
)
except Exception as e:
logger.warning(f"Vision autocomplete: selected model does not support vision: {e}")
yield streaming.format_message_start()
yield streaming.format_error(vision_error_msg)
if _is_vision_unsupported_error(e):
logger.warning("Vision autocomplete: model does not support vision: %s", e)
yield streaming.format_error(vision_error_msg)
yield streaming.format_done()
return
logger.error("Failed to create autocomplete agent: %s", e, exc_info=True)
yield streaming.format_error("Autocomplete failed. Please try again.")
yield streaming.format_done()
return
if query:
kb_context = await _search_knowledge_base(session, search_space_id, query)
has_kb = kb.has_documents
doc_count = len(kb.files) if has_kb else 0 # type: ignore[arg-type]
system_prompt = _build_system_prompt(app_name, window_title, kb_context)
yield streaming.format_thinking_step(
step_id=PREP_STEP_ID,
title="Searching knowledge base",
status="complete",
items=[f"Found {doc_count} document{'s' if doc_count != 1 else ''}"]
if kb_query
else ["Skipped"],
)
messages = [
SystemMessage(content=system_prompt),
HumanMessage(content=[
{
"type": "text",
"text": "Analyze this screenshot. Understand the full context of what the user is working on, then generate the text they most likely want to write in the active text area.",
},
{
"type": "image_url",
"image_url": {"url": screenshot_data_url},
},
]),
]
# Build agent input with pre-computed KB as initial state
if has_kb:
instruction = (
"Analyze this screenshot, then explore the knowledge base documents "
"listed above — read the chunk index of any document whose title "
"looks relevant and check matched chunks for useful facts. "
"Finally, generate a concise autocomplete for the active text area, "
"enhanced with any relevant KB information you found."
)
else:
instruction = (
"Analyze this screenshot and generate a concise autocomplete "
"for the active text area based on what you see."
)
text_started = False
text_id = ""
user_message = HumanMessage(
content=[
{"type": "text", "text": instruction},
{"type": "image_url", "image_url": {"url": screenshot_data_url}},
]
)
input_data: dict = {"messages": [user_message]}
if has_kb:
input_data["files"] = kb.files
input_data["messages"] = [kb.ls_ai_msg, kb.ls_tool_msg, user_message]
logger.info(
"Autocomplete: injected %d KB files into agent initial state", doc_count
)
else:
logger.info(
"Autocomplete: no KB documents found, proceeding with screenshot only"
)
# Stream the agent (message_start already sent above)
try:
yield streaming.format_message_start()
text_id = streaming.generate_text_id()
yield streaming.format_text_start(text_id)
text_started = True
async for chunk in llm.astream(messages):
token = chunk.content if hasattr(chunk, "content") else str(chunk)
if token:
yield streaming.format_text_delta(text_id, token)
yield streaming.format_text_end(text_id)
yield streaming.format_finish()
yield streaming.format_done()
async for sse in stream_autocomplete_agent(
agent,
input_data,
streaming,
emit_message_start=False,
):
yield sse
except Exception as e:
if text_started:
yield streaming.format_text_end(text_id)
if _is_vision_unsupported_error(e):
logger.warning(f"Vision autocomplete: selected model does not support vision: {e}")
logger.warning("Vision autocomplete: model does not support vision: %s", e)
yield streaming.format_error(vision_error_msg)
yield streaming.format_done()
else:
logger.error(f"Vision autocomplete streaming error: {e}", exc_info=True)
logger.error("Vision autocomplete streaming error: %s", e, exc_info=True)
yield streaming.format_error("Autocomplete failed. Please try again.")
yield streaming.format_done()
yield streaming.format_done()
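The module docstring above says KB filesystem pre-computation and agent graph compilation run in parallel inside create_autocomplete_agent, but that function's body is not part of this diff. The sketch below shows one plausible shape for it using asyncio.gather, with stand-in helpers and a simplified KbState; treat every name here as hypothetical illustration, not the actual implementation.

import asyncio
from dataclasses import dataclass, field

@dataclass
class KbState:
    # Stand-in for the real KB object used above (exposes has_documents and files).
    has_documents: bool = False
    files: dict = field(default_factory=dict)

async def precompute_kb_files(search_space_id: int, kb_query: str) -> KbState:
    # Placeholder for the hybrid KB search + scoped-filesystem build.
    await asyncio.sleep(0.1)
    return KbState(has_documents=bool(kb_query), files={"doc-1.md": "..."} if kb_query else {})

async def compile_agent_graph(llm: object) -> object:
    # Placeholder for compiling the agent graph around the vision LLM.
    await asyncio.sleep(0.1)
    return object()

async def create_autocomplete_agent_sketch(llm: object, *, search_space_id: int, kb_query: str):
    # Run both preparation steps concurrently, as the pipeline docstring describes,
    # then hand back (agent, kb) the way the streaming code above consumes it.
    kb, agent = await asyncio.gather(
        precompute_kb_files(search_space_id, kb_query),
        compile_agent_graph(llm),
    )
    return agent, kb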

View file

@ -51,7 +51,10 @@ async def _should_skip_file(
file_id = file.get("id", "")
file_name = file.get("name", "Unknown")
if skip_item(file):
skip, unsup_ext = skip_item(file)
if skip:
if unsup_ext:
return True, f"unsupported:{unsup_ext}"
return True, "folder/non-downloadable"
if not file_id:
return True, "missing file_id"
@ -251,6 +254,121 @@ async def _download_and_index(
return batch_indexed, download_failed + batch_failed
async def _remove_document(session: AsyncSession, file_id: str, search_space_id: int):
"""Remove a document that was deleted in Dropbox."""
primary_hash = compute_identifier_hash(
DocumentType.DROPBOX_FILE.value, file_id, search_space_id
)
existing = await check_document_by_unique_identifier(session, primary_hash)
if not existing:
result = await session.execute(
select(Document).where(
Document.search_space_id == search_space_id,
Document.document_type == DocumentType.DROPBOX_FILE,
cast(Document.document_metadata["dropbox_file_id"], String) == file_id,
)
)
existing = result.scalar_one_or_none()
if existing:
await session.delete(existing)
async def _index_with_delta_sync(
dropbox_client: DropboxClient,
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
cursor: str,
task_logger: TaskLoggingService,
log_entry: object,
max_files: int,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
enable_summary: bool = True,
) -> tuple[int, int, int, str]:
"""Delta sync using Dropbox cursor-based change tracking.
Returns (indexed_count, skipped_count, unsupported_count, new_cursor).
"""
await task_logger.log_task_progress(
log_entry,
f"Starting delta sync from cursor: {cursor[:20]}...",
{"stage": "delta_sync", "cursor_prefix": cursor[:20]},
)
entries, new_cursor, error = await dropbox_client.get_changes(cursor)
if error:
err_lower = error.lower()
if "401" in error or "authentication expired" in err_lower:
raise Exception(
f"Dropbox authentication failed. Please re-authenticate. (Error: {error})"
)
raise Exception(f"Failed to fetch Dropbox changes: {error}")
if not entries:
logger.info("No changes detected since last sync")
return 0, 0, 0, new_cursor or cursor
logger.info(f"Processing {len(entries)} change entries")
renamed_count = 0
skipped = 0
unsupported_count = 0
files_to_download: list[dict] = []
files_processed = 0
for entry in entries:
if files_processed >= max_files:
break
files_processed += 1
tag = entry.get(".tag")
if tag == "deleted":
path_lower = entry.get("path_lower", "")
name = entry.get("name", "")
file_id = entry.get("id", "")
if file_id:
await _remove_document(session, file_id, search_space_id)
logger.debug(f"Processed deletion: {name or path_lower}")
continue
if tag != "file":
continue
skip, msg = await _should_skip_file(session, entry, search_space_id)
if skip:
if msg and msg.startswith("unsupported:"):
unsupported_count += 1
elif msg and "renamed" in msg.lower():
renamed_count += 1
else:
skipped += 1
continue
files_to_download.append(entry)
batch_indexed, failed = await _download_and_index(
dropbox_client,
session,
files_to_download,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=enable_summary,
on_heartbeat=on_heartbeat_callback,
)
indexed = renamed_count + batch_indexed
logger.info(
f"Delta sync complete: {indexed} indexed, {skipped} skipped, "
f"{unsupported_count} unsupported, {failed} failed"
)
return indexed, skipped, unsupported_count, new_cursor or cursor
async def _index_full_scan(
dropbox_client: DropboxClient,
session: AsyncSession,
@ -266,8 +384,11 @@ async def _index_full_scan(
incremental_sync: bool = True,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
enable_summary: bool = True,
) -> tuple[int, int]:
"""Full scan indexing of a folder."""
) -> tuple[int, int, int]:
"""Full scan indexing of a folder.
Returns (indexed, skipped, unsupported_count).
"""
await task_logger.log_task_progress(
log_entry,
f"Starting full scan of folder: {folder_name}",
@ -287,6 +408,7 @@ async def _index_full_scan(
renamed_count = 0
skipped = 0
unsupported_count = 0
files_to_download: list[dict] = []
all_files, error = await get_files_in_folder(
@ -306,14 +428,21 @@ async def _index_full_scan(
if incremental_sync:
skip, msg = await _should_skip_file(session, file, search_space_id)
if skip:
if msg and "renamed" in msg.lower():
if msg and msg.startswith("unsupported:"):
unsupported_count += 1
elif msg and "renamed" in msg.lower():
renamed_count += 1
else:
skipped += 1
continue
elif skip_item(file):
skipped += 1
continue
else:
item_skip, item_unsup = skip_item(file)
if item_skip:
if item_unsup:
unsupported_count += 1
else:
skipped += 1
continue
file_pages = PageLimitService.estimate_pages_from_metadata(
file.get("name", ""), file.get("size")
@ -352,9 +481,10 @@ async def _index_full_scan(
indexed = renamed_count + batch_indexed
logger.info(
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
f"Full scan complete: {indexed} indexed, {skipped} skipped, "
f"{unsupported_count} unsupported, {failed} failed"
)
return indexed, skipped
return indexed, skipped, unsupported_count
async def _index_selected_files(
@ -368,7 +498,7 @@ async def _index_selected_files(
enable_summary: bool,
incremental_sync: bool = True,
on_heartbeat: HeartbeatCallbackType | None = None,
) -> tuple[int, int, list[str]]:
) -> tuple[int, int, int, list[str]]:
"""Index user-selected files using the parallel pipeline."""
page_limit_service = PageLimitService(session)
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
@ -379,6 +509,7 @@ async def _index_selected_files(
errors: list[str] = []
renamed_count = 0
skipped = 0
unsupported_count = 0
for file_path, file_name in file_paths:
file, error = await get_file_by_path(dropbox_client, file_path)
@ -390,14 +521,21 @@ async def _index_selected_files(
if incremental_sync:
skip, msg = await _should_skip_file(session, file, search_space_id)
if skip:
if msg and "renamed" in msg.lower():
if msg and msg.startswith("unsupported:"):
unsupported_count += 1
elif msg and "renamed" in msg.lower():
renamed_count += 1
else:
skipped += 1
continue
elif skip_item(file):
skipped += 1
continue
else:
item_skip, item_unsup = skip_item(file)
if item_skip:
if item_unsup:
unsupported_count += 1
else:
skipped += 1
continue
file_pages = PageLimitService.estimate_pages_from_metadata(
file.get("name", ""), file.get("size")
@ -429,7 +567,7 @@ async def _index_selected_files(
user_id, pages_to_deduct, allow_exceed=True
)
return renamed_count + batch_indexed, skipped, errors
return renamed_count + batch_indexed, skipped, unsupported_count, errors
async def index_dropbox_files(
@ -438,7 +576,7 @@ async def index_dropbox_files(
search_space_id: int,
user_id: str,
items_dict: dict,
) -> tuple[int, int, str | None]:
) -> tuple[int, int, str | None, int]:
"""Index Dropbox files for a specific connector.
items_dict format:
@ -469,7 +607,7 @@ async def index_dropbox_files(
await task_logger.log_task_failure(
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
)
return 0, 0, error_msg
return 0, 0, error_msg, 0
token_encrypted = connector.config.get("_token_encrypted", False)
if token_encrypted and not config.SECRET_KEY:
@ -480,7 +618,7 @@ async def index_dropbox_files(
"Missing SECRET_KEY",
{"error_type": "MissingSecretKey"},
)
return 0, 0, error_msg
return 0, 0, error_msg, 0
connector_enable_summary = getattr(connector, "enable_summary", True)
dropbox_client = DropboxClient(session, connector_id)
@ -489,9 +627,13 @@ async def index_dropbox_files(
max_files = indexing_options.get("max_files", 500)
incremental_sync = indexing_options.get("incremental_sync", True)
include_subfolders = indexing_options.get("include_subfolders", True)
use_delta_sync = indexing_options.get("use_delta_sync", True)
folder_cursors: dict = connector.config.get("folder_cursors", {})
total_indexed = 0
total_skipped = 0
total_unsupported = 0
selected_files = items_dict.get("files", [])
if selected_files:
@ -499,7 +641,7 @@ async def index_dropbox_files(
(f.get("path", f.get("path_lower", f.get("id", ""))), f.get("name"))
for f in selected_files
]
indexed, skipped, file_errors = await _index_selected_files(
indexed, skipped, unsupported, file_errors = await _index_selected_files(
dropbox_client,
session,
file_tuples,
@ -511,6 +653,7 @@ async def index_dropbox_files(
)
total_indexed += indexed
total_skipped += skipped
total_unsupported += unsupported
if file_errors:
logger.warning(
f"File indexing errors for connector {connector_id}: {file_errors}"
@ -523,25 +666,66 @@ async def index_dropbox_files(
)
folder_name = folder.get("name", "Root")
logger.info(f"Using full scan for folder {folder_name}")
indexed, skipped = await _index_full_scan(
dropbox_client,
session,
connector_id,
search_space_id,
user_id,
folder_path,
folder_name,
task_logger,
log_entry,
max_files,
include_subfolders,
incremental_sync=incremental_sync,
enable_summary=connector_enable_summary,
saved_cursor = folder_cursors.get(folder_path)
can_use_delta = (
use_delta_sync and saved_cursor and connector.last_indexed_at
)
if can_use_delta:
logger.info(f"Using delta sync for folder {folder_name}")
indexed, skipped, unsup, new_cursor = await _index_with_delta_sync(
dropbox_client,
session,
connector_id,
search_space_id,
user_id,
saved_cursor,
task_logger,
log_entry,
max_files,
enable_summary=connector_enable_summary,
)
folder_cursors[folder_path] = new_cursor
total_unsupported += unsup
else:
logger.info(f"Using full scan for folder {folder_name}")
indexed, skipped, unsup = await _index_full_scan(
dropbox_client,
session,
connector_id,
search_space_id,
user_id,
folder_path,
folder_name,
task_logger,
log_entry,
max_files,
include_subfolders,
incremental_sync=incremental_sync,
enable_summary=connector_enable_summary,
)
total_unsupported += unsup
total_indexed += indexed
total_skipped += skipped
# Persist latest cursor for this folder
try:
latest_cursor, cursor_err = await dropbox_client.get_latest_cursor(
folder_path
)
if latest_cursor and not cursor_err:
folder_cursors[folder_path] = latest_cursor
except Exception as e:
logger.warning(f"Failed to get latest cursor for {folder_path}: {e}")
# Persist folder cursors to connector config
if folders:
cfg = dict(connector.config)
cfg["folder_cursors"] = folder_cursors
connector.config = cfg
flag_modified(connector, "config")
if total_indexed > 0 or folders:
await update_connector_last_indexed(session, connector, True)
@ -550,12 +734,18 @@ async def index_dropbox_files(
await task_logger.log_task_success(
log_entry,
f"Successfully completed Dropbox indexing for connector {connector_id}",
{"files_processed": total_indexed, "files_skipped": total_skipped},
{
"files_processed": total_indexed,
"files_skipped": total_skipped,
"files_unsupported": total_unsupported,
},
)
logger.info(
f"Dropbox indexing completed: {total_indexed} indexed, {total_skipped} skipped"
f"Dropbox indexing completed: {total_indexed} indexed, "
f"{total_skipped} skipped, {total_unsupported} unsupported"
)
return total_indexed, total_skipped, None
return total_indexed, total_skipped, None, total_unsupported
except SQLAlchemyError as db_error:
await session.rollback()
@ -566,7 +756,7 @@ async def index_dropbox_files(
{"error_type": "SQLAlchemyError"},
)
logger.error(f"Database error: {db_error!s}", exc_info=True)
return 0, 0, f"Database error: {db_error!s}"
return 0, 0, f"Database error: {db_error!s}", 0
except Exception as e:
await session.rollback()
await task_logger.log_task_failure(
@ -576,4 +766,4 @@ async def index_dropbox_files(
{"error_type": type(e).__name__},
)
logger.error(f"Failed to index Dropbox files: {e!s}", exc_info=True)
return 0, 0, f"Failed to index Dropbox files: {e!s}"
return 0, 0, f"Failed to index Dropbox files: {e!s}", 0

View file

@ -25,7 +25,11 @@ from app.connectors.google_drive import (
get_files_in_folder,
get_start_page_token,
)
from app.connectors.google_drive.file_types import should_skip_file as skip_mime
from app.connectors.google_drive.file_types import (
is_google_workspace_file,
should_skip_by_extension,
should_skip_file as skip_mime,
)
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_identifier_hash
@ -78,6 +82,10 @@ async def _should_skip_file(
if skip_mime(mime_type):
return True, "folder/shortcut"
if not is_google_workspace_file(mime_type):
ext_skip, unsup_ext = should_skip_by_extension(file_name)
if ext_skip:
return True, f"unsupported:{unsup_ext}"
if not file_id:
return True, "missing file_id"
@ -468,13 +476,13 @@ async def _index_selected_files(
user_id: str,
enable_summary: bool,
on_heartbeat: HeartbeatCallbackType | None = None,
) -> tuple[int, int, list[str]]:
) -> tuple[int, int, int, list[str]]:
"""Index user-selected files using the parallel pipeline.
Phase 1 (serial): fetch metadata + skip checks.
Phase 2+3 (parallel): download, ETL, index via _download_and_index.
Returns (indexed_count, skipped_count, errors).
Returns (indexed_count, skipped_count, unsupported_count, errors).
"""
page_limit_service = PageLimitService(session)
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
@ -485,6 +493,7 @@ async def _index_selected_files(
errors: list[str] = []
renamed_count = 0
skipped = 0
unsupported_count = 0
for file_id, file_name in file_ids:
file, error = await get_file_by_id(drive_client, file_id)
@ -495,7 +504,9 @@ async def _index_selected_files(
skip, msg = await _should_skip_file(session, file, search_space_id)
if skip:
if msg and "renamed" in msg.lower():
if msg and msg.startswith("unsupported:"):
unsupported_count += 1
elif msg and "renamed" in msg.lower():
renamed_count += 1
else:
skipped += 1
@ -539,7 +550,7 @@ async def _index_selected_files(
user_id, pages_to_deduct, allow_exceed=True
)
return renamed_count + batch_indexed, skipped, errors
return renamed_count + batch_indexed, skipped, unsupported_count, errors
# ---------------------------------------------------------------------------
@ -562,8 +573,11 @@ async def _index_full_scan(
include_subfolders: bool = False,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
enable_summary: bool = True,
) -> tuple[int, int]:
"""Full scan indexing of a folder."""
) -> tuple[int, int, int]:
"""Full scan indexing of a folder.
Returns (indexed, skipped, unsupported_count).
"""
await task_logger.log_task_progress(
log_entry,
f"Starting full scan of folder: {folder_name} (include_subfolders={include_subfolders})",
@ -585,6 +599,7 @@ async def _index_full_scan(
renamed_count = 0
skipped = 0
unsupported_count = 0
files_processed = 0
files_to_download: list[dict] = []
folders_to_process = [(folder_id, folder_name)]
@ -625,7 +640,9 @@ async def _index_full_scan(
skip, msg = await _should_skip_file(session, file, search_space_id)
if skip:
if msg and "renamed" in msg.lower():
if msg and msg.startswith("unsupported:"):
unsupported_count += 1
elif msg and "renamed" in msg.lower():
renamed_count += 1
else:
skipped += 1
@ -698,9 +715,10 @@ async def _index_full_scan(
indexed = renamed_count + batch_indexed
logger.info(
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
f"Full scan complete: {indexed} indexed, {skipped} skipped, "
f"{unsupported_count} unsupported, {failed} failed"
)
return indexed, skipped
return indexed, skipped, unsupported_count
async def _index_with_delta_sync(
@ -718,8 +736,11 @@ async def _index_with_delta_sync(
include_subfolders: bool = False,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
enable_summary: bool = True,
) -> tuple[int, int]:
"""Delta sync using change tracking."""
) -> tuple[int, int, int]:
"""Delta sync using change tracking.
Returns (indexed, skipped, unsupported_count).
"""
await task_logger.log_task_progress(
log_entry,
f"Starting delta sync from token: {start_page_token[:20]}...",
@ -739,7 +760,7 @@ async def _index_with_delta_sync(
if not changes:
logger.info("No changes detected since last sync")
return 0, 0
return 0, 0, 0
logger.info(f"Processing {len(changes)} changes")
@ -754,6 +775,7 @@ async def _index_with_delta_sync(
renamed_count = 0
skipped = 0
unsupported_count = 0
files_to_download: list[dict] = []
files_processed = 0
@ -775,7 +797,9 @@ async def _index_with_delta_sync(
skip, msg = await _should_skip_file(session, file, search_space_id)
if skip:
if msg and "renamed" in msg.lower():
if msg and msg.startswith("unsupported:"):
unsupported_count += 1
elif msg and "renamed" in msg.lower():
renamed_count += 1
else:
skipped += 1
@ -832,9 +856,10 @@ async def _index_with_delta_sync(
indexed = renamed_count + batch_indexed
logger.info(
f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed"
f"Delta sync complete: {indexed} indexed, {skipped} skipped, "
f"{unsupported_count} unsupported, {failed} failed"
)
return indexed, skipped
return indexed, skipped, unsupported_count
# ---------------------------------------------------------------------------
@ -854,8 +879,11 @@ async def index_google_drive_files(
max_files: int = 500,
include_subfolders: bool = False,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, int, str | None]:
"""Index Google Drive files for a specific connector."""
) -> tuple[int, int, str | None, int]:
"""Index Google Drive files for a specific connector.
Returns (indexed, skipped, error_or_none, unsupported_count).
"""
task_logger = TaskLoggingService(session, search_space_id)
log_entry = await task_logger.log_task_start(
task_name="google_drive_files_indexing",
@ -881,7 +909,7 @@ async def index_google_drive_files(
await task_logger.log_task_failure(
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
)
return 0, 0, error_msg
return 0, 0, error_msg, 0
await task_logger.log_task_progress(
log_entry,
@ -900,7 +928,7 @@ async def index_google_drive_files(
"Missing Composio account",
{"error_type": "MissingComposioAccount"},
)
return 0, 0, error_msg
return 0, 0, error_msg, 0
pre_built_credentials = build_composio_credentials(connected_account_id)
else:
token_encrypted = connector.config.get("_token_encrypted", False)
@ -915,6 +943,7 @@ async def index_google_drive_files(
0,
0,
"SECRET_KEY not configured but credentials are marked as encrypted",
0,
)
connector_enable_summary = getattr(connector, "enable_summary", True)
@ -927,7 +956,7 @@ async def index_google_drive_files(
await task_logger.log_task_failure(
log_entry, error_msg, {"error_type": "MissingParameter"}
)
return 0, 0, error_msg
return 0, 0, error_msg, 0
target_folder_id = folder_id
target_folder_name = folder_name or "Selected Folder"
@ -938,9 +967,11 @@ async def index_google_drive_files(
use_delta_sync and start_page_token and connector.last_indexed_at
)
documents_unsupported = 0
if can_use_delta:
logger.info(f"Using delta sync for connector {connector_id}")
documents_indexed, documents_skipped = await _index_with_delta_sync(
documents_indexed, documents_skipped, du = await _index_with_delta_sync(
drive_client,
session,
connector,
@ -956,8 +987,9 @@ async def index_google_drive_files(
on_heartbeat_callback,
connector_enable_summary,
)
documents_unsupported += du
logger.info("Running reconciliation scan after delta sync")
ri, rs = await _index_full_scan(
ri, rs, ru = await _index_full_scan(
drive_client,
session,
connector,
@ -975,9 +1007,14 @@ async def index_google_drive_files(
)
documents_indexed += ri
documents_skipped += rs
documents_unsupported += ru
else:
logger.info(f"Using full scan for connector {connector_id}")
documents_indexed, documents_skipped = await _index_full_scan(
(
documents_indexed,
documents_skipped,
documents_unsupported,
) = await _index_full_scan(
drive_client,
session,
connector,
@ -1012,14 +1049,17 @@ async def index_google_drive_files(
{
"files_processed": documents_indexed,
"files_skipped": documents_skipped,
"files_unsupported": documents_unsupported,
"sync_type": "delta" if can_use_delta else "full",
"folder": target_folder_name,
},
)
logger.info(
f"Google Drive indexing completed: {documents_indexed} indexed, {documents_skipped} skipped"
f"Google Drive indexing completed: {documents_indexed} indexed, "
f"{documents_skipped} skipped, {documents_unsupported} unsupported"
)
return documents_indexed, documents_skipped, None
return documents_indexed, documents_skipped, None, documents_unsupported
except SQLAlchemyError as db_error:
await session.rollback()
@ -1030,7 +1070,7 @@ async def index_google_drive_files(
{"error_type": "SQLAlchemyError"},
)
logger.error(f"Database error: {db_error!s}", exc_info=True)
return 0, 0, f"Database error: {db_error!s}"
return 0, 0, f"Database error: {db_error!s}", 0
except Exception as e:
await session.rollback()
await task_logger.log_task_failure(
@ -1040,7 +1080,7 @@ async def index_google_drive_files(
{"error_type": type(e).__name__},
)
logger.error(f"Failed to index Google Drive files: {e!s}", exc_info=True)
return 0, 0, f"Failed to index Google Drive files: {e!s}"
return 0, 0, f"Failed to index Google Drive files: {e!s}", 0
async def index_google_drive_single_file(
@ -1242,7 +1282,7 @@ async def index_google_drive_selected_files(
session, connector_id, credentials=pre_built_credentials
)
indexed, skipped, errors = await _index_selected_files(
indexed, skipped, unsupported, errors = await _index_selected_files(
drive_client,
session,
files,
@ -1253,6 +1293,11 @@ async def index_google_drive_selected_files(
on_heartbeat=on_heartbeat_callback,
)
if unsupported > 0:
file_text = "file was" if unsupported == 1 else "files were"
unsup_msg = f"{unsupported} {file_text} not supported"
errors.append(unsup_msg)
await session.commit()
if errors:
@ -1260,7 +1305,12 @@ async def index_google_drive_selected_files(
log_entry,
f"Batch file indexing completed with {len(errors)} error(s)",
"; ".join(errors),
{"indexed": indexed, "skipped": skipped, "error_count": len(errors)},
{
"indexed": indexed,
"skipped": skipped,
"unsupported": unsupported,
"error_count": len(errors),
},
)
else:
await task_logger.log_task_success(

View file

@ -23,7 +23,6 @@ from sqlalchemy import select
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import (
Document,
DocumentStatus,
@ -44,132 +43,6 @@ from .base import (
logger,
)
PLAINTEXT_EXTENSIONS = frozenset(
{
".md",
".markdown",
".txt",
".text",
".json",
".jsonl",
".yaml",
".yml",
".toml",
".ini",
".cfg",
".conf",
".xml",
".css",
".scss",
".less",
".sass",
".py",
".pyw",
".pyi",
".pyx",
".js",
".jsx",
".ts",
".tsx",
".mjs",
".cjs",
".java",
".kt",
".kts",
".scala",
".groovy",
".c",
".h",
".cpp",
".cxx",
".cc",
".hpp",
".hxx",
".cs",
".fs",
".fsx",
".go",
".rs",
".rb",
".php",
".pl",
".pm",
".lua",
".swift",
".m",
".mm",
".r",
".R",
".jl",
".sh",
".bash",
".zsh",
".fish",
".bat",
".cmd",
".ps1",
".sql",
".graphql",
".gql",
".env",
".gitignore",
".dockerignore",
".editorconfig",
".makefile",
".cmake",
".log",
".rst",
".tex",
".bib",
".org",
".adoc",
".asciidoc",
".vue",
".svelte",
".astro",
".tf",
".hcl",
".proto",
}
)
AUDIO_EXTENSIONS = frozenset(
{
".mp3",
".mp4",
".mpeg",
".mpga",
".m4a",
".wav",
".webm",
}
)
DIRECT_CONVERT_EXTENSIONS = frozenset({".csv", ".tsv", ".html", ".htm"})
def _is_plaintext_file(filename: str) -> bool:
return Path(filename).suffix.lower() in PLAINTEXT_EXTENSIONS
def _is_audio_file(filename: str) -> bool:
return Path(filename).suffix.lower() in AUDIO_EXTENSIONS
def _is_direct_convert_file(filename: str) -> bool:
return Path(filename).suffix.lower() in DIRECT_CONVERT_EXTENSIONS
def _needs_etl(filename: str) -> bool:
"""File is not plaintext, not audio, and not direct-convert — requires ETL."""
return (
not _is_plaintext_file(filename)
and not _is_audio_file(filename)
and not _is_direct_convert_file(filename)
)
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
@ -279,57 +152,19 @@ def scan_folder(
return files
def _read_plaintext_file(file_path: str) -> str:
"""Read a plaintext/text-based file as UTF-8."""
with open(file_path, encoding="utf-8", errors="replace") as f:
content = f.read()
if "\x00" in content:
raise ValueError(
f"File contains null bytes — likely a binary file opened as text: {file_path}"
)
return content
async def _read_file_content(file_path: str, filename: str) -> str:
"""Read file content, using ETL for binary formats.
"""Read file content via the unified ETL pipeline.
Plaintext files are read directly. Audio and document files (PDF, DOCX, etc.)
are routed through the configured ETL service (same as Google Drive / OneDrive).
Raises ValueError if the file cannot be parsed (e.g. no ETL service configured
for a binary file).
All file types (plaintext, audio, direct-convert, document) are handled
by ``EtlPipelineService``.
"""
if _is_plaintext_file(filename):
return _read_plaintext_file(file_path)
from app.etl_pipeline.etl_document import EtlRequest
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
if _is_direct_convert_file(filename):
from app.tasks.document_processors._direct_converters import (
convert_file_directly,
)
return convert_file_directly(file_path, filename)
if _is_audio_file(filename):
etl_service = config.ETL_SERVICE if hasattr(config, "ETL_SERVICE") else None
stt_service_val = config.STT_SERVICE if hasattr(config, "STT_SERVICE") else None
if not stt_service_val and not etl_service:
raise ValueError(
f"No STT_SERVICE configured — cannot transcribe audio file: {filename}"
)
if _needs_etl(filename):
etl_service = getattr(config, "ETL_SERVICE", None)
if not etl_service:
raise ValueError(
f"No ETL_SERVICE configured — cannot parse binary file: {filename}. "
f"Set ETL_SERVICE to UNSTRUCTURED, LLAMACLOUD, or DOCLING in your .env"
)
from app.connectors.onedrive.content_extractor import (
_parse_file_to_markdown,
result = await EtlPipelineService().extract(
EtlRequest(file_path=file_path, filename=filename)
)
return await _parse_file_to_markdown(file_path, filename)
return result.markdown_content
def _content_hash(content: str, search_space_id: int) -> str:

View file

@ -56,7 +56,10 @@ async def _should_skip_file(
file_id = file.get("id")
file_name = file.get("name", "Unknown")
if skip_item(file):
skip, unsup_ext = skip_item(file)
if skip:
if unsup_ext:
return True, f"unsupported:{unsup_ext}"
return True, "folder/onenote/remote"
if not file_id:
return True, "missing file_id"
@ -290,7 +293,7 @@ async def _index_selected_files(
user_id: str,
enable_summary: bool,
on_heartbeat: HeartbeatCallbackType | None = None,
) -> tuple[int, int, list[str]]:
) -> tuple[int, int, int, list[str]]:
"""Index user-selected files using the parallel pipeline."""
page_limit_service = PageLimitService(session)
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
@ -301,6 +304,7 @@ async def _index_selected_files(
errors: list[str] = []
renamed_count = 0
skipped = 0
unsupported_count = 0
for file_id, file_name in file_ids:
file, error = await get_file_by_id(onedrive_client, file_id)
@ -311,7 +315,9 @@ async def _index_selected_files(
skip, msg = await _should_skip_file(session, file, search_space_id)
if skip:
if msg and "renamed" in msg.lower():
if msg and msg.startswith("unsupported:"):
unsupported_count += 1
elif msg and "renamed" in msg.lower():
renamed_count += 1
else:
skipped += 1
@ -347,7 +353,7 @@ async def _index_selected_files(
user_id, pages_to_deduct, allow_exceed=True
)
return renamed_count + batch_indexed, skipped, errors
return renamed_count + batch_indexed, skipped, unsupported_count, errors
# ---------------------------------------------------------------------------
@ -369,8 +375,11 @@ async def _index_full_scan(
include_subfolders: bool = True,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
enable_summary: bool = True,
) -> tuple[int, int]:
"""Full scan indexing of a folder."""
) -> tuple[int, int, int]:
"""Full scan indexing of a folder.
Returns (indexed, skipped, unsupported_count).
"""
await task_logger.log_task_progress(
log_entry,
f"Starting full scan of folder: {folder_name}",
@ -389,6 +398,7 @@ async def _index_full_scan(
renamed_count = 0
skipped = 0
unsupported_count = 0
files_to_download: list[dict] = []
all_files, error = await get_files_in_folder(
@ -407,7 +417,9 @@ async def _index_full_scan(
for file in all_files[:max_files]:
skip, msg = await _should_skip_file(session, file, search_space_id)
if skip:
if msg and "renamed" in msg.lower():
if msg and msg.startswith("unsupported:"):
unsupported_count += 1
elif msg and "renamed" in msg.lower():
renamed_count += 1
else:
skipped += 1
@ -450,9 +462,10 @@ async def _index_full_scan(
indexed = renamed_count + batch_indexed
logger.info(
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
f"Full scan complete: {indexed} indexed, {skipped} skipped, "
f"{unsupported_count} unsupported, {failed} failed"
)
return indexed, skipped
return indexed, skipped, unsupported_count
async def _index_with_delta_sync(
@ -468,8 +481,11 @@ async def _index_with_delta_sync(
max_files: int,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
enable_summary: bool = True,
) -> tuple[int, int, str | None]:
"""Delta sync using OneDrive change tracking. Returns (indexed, skipped, new_delta_link)."""
) -> tuple[int, int, int, str | None]:
"""Delta sync using OneDrive change tracking.
Returns (indexed, skipped, unsupported_count, new_delta_link).
"""
await task_logger.log_task_progress(
log_entry,
"Starting delta sync",
@ -489,7 +505,7 @@ async def _index_with_delta_sync(
if not changes:
logger.info("No changes detected since last sync")
return 0, 0, new_delta_link
return 0, 0, 0, new_delta_link
logger.info(f"Processing {len(changes)} delta changes")
@ -501,6 +517,7 @@ async def _index_with_delta_sync(
renamed_count = 0
skipped = 0
unsupported_count = 0
files_to_download: list[dict] = []
files_processed = 0
@ -523,7 +540,9 @@ async def _index_with_delta_sync(
skip, msg = await _should_skip_file(session, change, search_space_id)
if skip:
if msg and "renamed" in msg.lower():
if msg and msg.startswith("unsupported:"):
unsupported_count += 1
elif msg and "renamed" in msg.lower():
renamed_count += 1
else:
skipped += 1
@ -566,9 +585,10 @@ async def _index_with_delta_sync(
indexed = renamed_count + batch_indexed
logger.info(
f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed"
f"Delta sync complete: {indexed} indexed, {skipped} skipped, "
f"{unsupported_count} unsupported, {failed} failed"
)
return indexed, skipped, new_delta_link
return indexed, skipped, unsupported_count, new_delta_link
# ---------------------------------------------------------------------------
@ -582,7 +602,7 @@ async def index_onedrive_files(
search_space_id: int,
user_id: str,
items_dict: dict,
) -> tuple[int, int, str | None]:
) -> tuple[int, int, str | None, int]:
"""Index OneDrive files for a specific connector.
items_dict format:
@ -609,7 +629,7 @@ async def index_onedrive_files(
await task_logger.log_task_failure(
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
)
return 0, 0, error_msg
return 0, 0, error_msg, 0
token_encrypted = connector.config.get("_token_encrypted", False)
if token_encrypted and not config.SECRET_KEY:
@ -620,7 +640,7 @@ async def index_onedrive_files(
"Missing SECRET_KEY",
{"error_type": "MissingSecretKey"},
)
return 0, 0, error_msg
return 0, 0, error_msg, 0
connector_enable_summary = getattr(connector, "enable_summary", True)
onedrive_client = OneDriveClient(session, connector_id)
@ -632,12 +652,13 @@ async def index_onedrive_files(
total_indexed = 0
total_skipped = 0
total_unsupported = 0
# Index selected individual files
selected_files = items_dict.get("files", [])
if selected_files:
file_tuples = [(f["id"], f.get("name")) for f in selected_files]
indexed, skipped, _errors = await _index_selected_files(
indexed, skipped, unsupported, _errors = await _index_selected_files(
onedrive_client,
session,
file_tuples,
@ -648,6 +669,7 @@ async def index_onedrive_files(
)
total_indexed += indexed
total_skipped += skipped
total_unsupported += unsupported
# Index selected folders
folders = items_dict.get("folders", [])
@ -661,7 +683,7 @@ async def index_onedrive_files(
if can_use_delta:
logger.info(f"Using delta sync for folder {folder_name}")
indexed, skipped, new_delta_link = await _index_with_delta_sync(
indexed, skipped, unsup, new_delta_link = await _index_with_delta_sync(
onedrive_client,
session,
connector_id,
@ -676,6 +698,7 @@ async def index_onedrive_files(
)
total_indexed += indexed
total_skipped += skipped
total_unsupported += unsup
if new_delta_link:
await session.refresh(connector)
@ -685,7 +708,7 @@ async def index_onedrive_files(
flag_modified(connector, "config")
# Reconciliation full scan
ri, rs = await _index_full_scan(
ri, rs, ru = await _index_full_scan(
onedrive_client,
session,
connector_id,
@ -701,9 +724,10 @@ async def index_onedrive_files(
)
total_indexed += ri
total_skipped += rs
total_unsupported += ru
else:
logger.info(f"Using full scan for folder {folder_name}")
indexed, skipped = await _index_full_scan(
indexed, skipped, unsup = await _index_full_scan(
onedrive_client,
session,
connector_id,
@ -719,6 +743,7 @@ async def index_onedrive_files(
)
total_indexed += indexed
total_skipped += skipped
total_unsupported += unsup
# Store new delta link for this folder
_, new_delta_link, _ = await onedrive_client.get_delta(folder_id=folder_id)
@ -737,12 +762,18 @@ async def index_onedrive_files(
await task_logger.log_task_success(
log_entry,
f"Successfully completed OneDrive indexing for connector {connector_id}",
{"files_processed": total_indexed, "files_skipped": total_skipped},
{
"files_processed": total_indexed,
"files_skipped": total_skipped,
"files_unsupported": total_unsupported,
},
)
logger.info(
f"OneDrive indexing completed: {total_indexed} indexed, {total_skipped} skipped"
f"OneDrive indexing completed: {total_indexed} indexed, "
f"{total_skipped} skipped, {total_unsupported} unsupported"
)
return total_indexed, total_skipped, None
return total_indexed, total_skipped, None, total_unsupported
except SQLAlchemyError as db_error:
await session.rollback()
@ -753,7 +784,7 @@ async def index_onedrive_files(
{"error_type": "SQLAlchemyError"},
)
logger.error(f"Database error: {db_error!s}", exc_info=True)
return 0, 0, f"Database error: {db_error!s}"
return 0, 0, f"Database error: {db_error!s}", 0
except Exception as e:
await session.rollback()
await task_logger.log_task_failure(
@ -763,4 +794,4 @@ async def index_onedrive_files(
{"error_type": type(e).__name__},
)
logger.error(f"Failed to index OneDrive files: {e!s}", exc_info=True)
return 0, 0, f"Failed to index OneDrive files: {e!s}"
return 0, 0, f"Failed to index OneDrive files: {e!s}", 0

View file

@ -1,41 +1,17 @@
"""
Document processors module for background tasks.
This module provides a collection of document processors for different content types
and sources. Each processor is responsible for handling a specific type of document
processing task in the background.
Available processors:
- Extension processor: Handle documents from browser extension
- Markdown processor: Process markdown files
- File processors: Handle files using different ETL services (Unstructured, LlamaCloud, Docling)
- YouTube processor: Process YouTube videos and extract transcripts
Content extraction is handled by ``app.etl_pipeline.EtlPipelineService``.
This package keeps orchestration (save, notify, page-limit) and
non-ETL processors (extension, markdown, youtube).
"""
# Extension processor
# File processors (backward-compatible re-exports from _save)
from ._save import (
add_received_file_document_using_docling,
add_received_file_document_using_llamacloud,
add_received_file_document_using_unstructured,
)
from .extension_processor import add_extension_received_document
# Markdown processor
from .markdown_processor import add_received_markdown_file_document
# YouTube processor
from .youtube_processor import add_youtube_video_document
__all__ = [
# Extension processing
"add_extension_received_document",
# File processing with different ETL services
"add_received_file_document_using_docling",
"add_received_file_document_using_llamacloud",
"add_received_file_document_using_unstructured",
# Markdown file processing
"add_received_markdown_file_document",
# YouTube video processing
"add_youtube_video_document",
]

View file

@ -1,74 +0,0 @@
"""
Constants for file document processing.
Centralizes file type classification, LlamaCloud retry configuration,
and timeout calculation parameters.
"""
import ssl
from enum import Enum
import httpx
# ---------------------------------------------------------------------------
# File type classification
# ---------------------------------------------------------------------------
MARKDOWN_EXTENSIONS = (".md", ".markdown", ".txt")
AUDIO_EXTENSIONS = (".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")
DIRECT_CONVERT_EXTENSIONS = (".csv", ".tsv", ".html", ".htm")
class FileCategory(Enum):
MARKDOWN = "markdown"
AUDIO = "audio"
DIRECT_CONVERT = "direct_convert"
DOCUMENT = "document"
def classify_file(filename: str) -> FileCategory:
"""Classify a file by its extension into a processing category."""
lower = filename.lower()
if lower.endswith(MARKDOWN_EXTENSIONS):
return FileCategory.MARKDOWN
if lower.endswith(AUDIO_EXTENSIONS):
return FileCategory.AUDIO
if lower.endswith(DIRECT_CONVERT_EXTENSIONS):
return FileCategory.DIRECT_CONVERT
return FileCategory.DOCUMENT
# ---------------------------------------------------------------------------
# LlamaCloud retry configuration
# ---------------------------------------------------------------------------
LLAMACLOUD_MAX_RETRIES = 5
LLAMACLOUD_BASE_DELAY = 10 # seconds (exponential backoff base)
LLAMACLOUD_MAX_DELAY = 120 # max delay between retries (2 minutes)
LLAMACLOUD_RETRYABLE_EXCEPTIONS = (
ssl.SSLError,
httpx.ConnectError,
httpx.ConnectTimeout,
httpx.ReadError,
httpx.ReadTimeout,
httpx.WriteError,
httpx.WriteTimeout,
httpx.RemoteProtocolError,
httpx.LocalProtocolError,
ConnectionError,
ConnectionResetError,
TimeoutError,
OSError,
)
# ---------------------------------------------------------------------------
# Timeout calculation constants
# ---------------------------------------------------------------------------
UPLOAD_BYTES_PER_SECOND_SLOW = (
100 * 1024
) # 100 KB/s (conservative for slow connections)
MIN_UPLOAD_TIMEOUT = 120 # Minimum 2 minutes for any file
MAX_UPLOAD_TIMEOUT = 1800 # Maximum 30 minutes for very large files
BASE_JOB_TIMEOUT = 600 # 10 minutes base for job processing
PER_PAGE_JOB_TIMEOUT = 60 # 1 minute per page for processing
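One plausible reading of how these constants combine; the actual calculate_upload_timeout and calculate_job_timeout helpers live in ._helpers and are not shown in this diff (the real job-timeout helper also receives the file size), so the two functions below are assumptions for illustration only.

def calculate_upload_timeout_sketch(file_size_bytes: int) -> float:
    # Assume a slow 100 KB/s link, then clamp to the 2-minute / 30-minute bounds.
    estimated = file_size_bytes / UPLOAD_BYTES_PER_SECOND_SLOW
    return max(MIN_UPLOAD_TIMEOUT, min(MAX_UPLOAD_TIMEOUT, estimated))

def calculate_job_timeout_sketch(estimated_pages: int) -> float:
    # 10-minute base plus one minute of processing budget per estimated page.
    return BASE_JOB_TIMEOUT + estimated_pages * PER_PAGE_JOB_TIMEOUT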

View file

@ -4,8 +4,8 @@ Lossless file-to-markdown converters for text-based formats.
These converters handle file types that can be faithfully represented as
markdown without any external ETL/OCR service:
- CSV / TSV → markdown table (stdlib ``csv``)
- HTML / HTM → markdown (``markdownify``)
- CSV / TSV → markdown table (stdlib ``csv``)
- HTML / HTM / XHTML → markdown (``markdownify``)
"""
from __future__ import annotations
@ -73,6 +73,7 @@ _CONVERTER_MAP: dict[str, Callable[..., str]] = {
".tsv": tsv_to_markdown,
".html": html_to_markdown,
".htm": html_to_markdown,
".xhtml": html_to_markdown,
}
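A small dispatch sketch over the map above, assuming each converter accepts a file path (the map's Callable[..., str] typing leaves the exact signature open); the package's real entry point, convert_file_directly, presumably does something equivalent, so treat this as an illustration rather than its actual body.

from pathlib import Path

def convert_sketch(file_path: str, filename: str) -> str:
    # Illustrative lookup only, not the module's actual entry point.
    converter = _CONVERTER_MAP.get(Path(filename).suffix.lower())
    if converter is None:
        raise ValueError(f"No lossless converter registered for: {filename}")
    return converter(file_path)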

View file

@ -1,209 +0,0 @@
"""
ETL parsing strategies for different document processing services.
Provides parse functions for Unstructured, LlamaCloud, and Docling, along with
LlamaCloud retry logic and dynamic timeout calculations.
"""
import asyncio
import logging
import os
import random
import warnings
from logging import ERROR, getLogger
import httpx
from app.config import config as app_config
from app.db import Log
from app.services.task_logging_service import TaskLoggingService
from ._constants import (
LLAMACLOUD_BASE_DELAY,
LLAMACLOUD_MAX_DELAY,
LLAMACLOUD_MAX_RETRIES,
LLAMACLOUD_RETRYABLE_EXCEPTIONS,
PER_PAGE_JOB_TIMEOUT,
)
from ._helpers import calculate_job_timeout, calculate_upload_timeout
# ---------------------------------------------------------------------------
# LlamaCloud parsing with retry
# ---------------------------------------------------------------------------
async def parse_with_llamacloud_retry(
file_path: str,
estimated_pages: int,
task_logger: TaskLoggingService | None = None,
log_entry: Log | None = None,
):
"""
Parse a file with LlamaCloud with retry logic for transient SSL/connection errors.
Uses dynamic timeout calculations based on file size and page count to handle
very large files reliably.
Returns:
LlamaParse result object
Raises:
Exception: If all retries fail
"""
from llama_cloud_services import LlamaParse
from llama_cloud_services.parse.utils import ResultType
file_size_bytes = os.path.getsize(file_path)
file_size_mb = file_size_bytes / (1024 * 1024)
upload_timeout = calculate_upload_timeout(file_size_bytes)
job_timeout = calculate_job_timeout(estimated_pages, file_size_bytes)
custom_timeout = httpx.Timeout(
connect=120.0,
read=upload_timeout,
write=upload_timeout,
pool=120.0,
)
logging.info(
f"LlamaCloud upload configured: file_size={file_size_mb:.1f}MB, "
f"pages={estimated_pages}, upload_timeout={upload_timeout:.0f}s, "
f"job_timeout={job_timeout:.0f}s"
)
last_exception = None
attempt_errors: list[str] = []
for attempt in range(1, LLAMACLOUD_MAX_RETRIES + 1):
try:
async with httpx.AsyncClient(timeout=custom_timeout) as custom_client:
parser = LlamaParse(
api_key=app_config.LLAMA_CLOUD_API_KEY,
num_workers=1,
verbose=True,
language="en",
result_type=ResultType.MD,
max_timeout=int(max(2000, job_timeout + upload_timeout)),
job_timeout_in_seconds=job_timeout,
job_timeout_extra_time_per_page_in_seconds=PER_PAGE_JOB_TIMEOUT,
custom_client=custom_client,
)
result = await parser.aparse(file_path)
if attempt > 1:
logging.info(
f"LlamaCloud upload succeeded on attempt {attempt} after "
f"{len(attempt_errors)} failures"
)
return result
except LLAMACLOUD_RETRYABLE_EXCEPTIONS as e:
last_exception = e
error_type = type(e).__name__
error_msg = str(e)[:200]
attempt_errors.append(f"Attempt {attempt}: {error_type} - {error_msg}")
if attempt < LLAMACLOUD_MAX_RETRIES:
base_delay = min(
LLAMACLOUD_BASE_DELAY * (2 ** (attempt - 1)),
LLAMACLOUD_MAX_DELAY,
)
jitter = base_delay * 0.25 * (2 * random.random() - 1)
delay = base_delay + jitter
if task_logger and log_entry:
await task_logger.log_task_progress(
log_entry,
f"LlamaCloud upload failed "
f"(attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}), "
f"retrying in {delay:.0f}s",
{
"error_type": error_type,
"error_message": error_msg,
"attempt": attempt,
"retry_delay": delay,
"file_size_mb": round(file_size_mb, 1),
"upload_timeout": upload_timeout,
},
)
else:
logging.warning(
f"LlamaCloud upload failed "
f"(attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}): "
f"{error_type}. File: {file_size_mb:.1f}MB. "
f"Retrying in {delay:.0f}s..."
)
await asyncio.sleep(delay)
else:
logging.error(
f"LlamaCloud upload failed after {LLAMACLOUD_MAX_RETRIES} "
f"attempts. File size: {file_size_mb:.1f}MB, "
f"Pages: {estimated_pages}. "
f"Errors: {'; '.join(attempt_errors)}"
)
except Exception:
raise
raise last_exception or RuntimeError(
f"LlamaCloud parsing failed after {LLAMACLOUD_MAX_RETRIES} retries. "
f"File size: {file_size_mb:.1f}MB"
)
# ---------------------------------------------------------------------------
# Per-service parse functions
# ---------------------------------------------------------------------------
async def parse_with_unstructured(file_path: str):
"""
Parse a file using the Unstructured ETL service.
Returns:
List of LangChain Document elements.
"""
from langchain_unstructured import UnstructuredLoader
loader = UnstructuredLoader(
file_path,
mode="elements",
post_processors=[],
languages=["eng"],
include_orig_elements=False,
include_metadata=False,
strategy="auto",
)
return await loader.aload()
async def parse_with_docling(file_path: str, filename: str) -> str:
"""
Parse a file using the Docling ETL service (via the Docling service wrapper).
Returns:
Markdown content string.
"""
from app.services.docling_service import create_docling_service
docling_service = create_docling_service()
pdfminer_logger = getLogger("pdfminer")
original_level = pdfminer_logger.level
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, module="pdfminer")
warnings.filterwarnings(
"ignore", message=".*Cannot set gray non-stroke color.*"
)
warnings.filterwarnings("ignore", message=".*invalid float value.*")
pdfminer_logger.setLevel(ERROR)
try:
result = await docling_service.process_document(file_path, filename)
finally:
pdfminer_logger.setLevel(original_level)
return result["content"]
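To make the retry cadence concrete: with LLAMACLOUD_BASE_DELAY = 10, LLAMACLOUD_MAX_DELAY = 120 and the ±25% jitter used in parse_with_llamacloud_retry above, the nominal sleep before each retry doubles as sketched below (a standalone snippet mirroring the loop, not the module itself):
import random

BASE_DELAY, MAX_DELAY, MAX_RETRIES = 10, 120, 5  # values from _constants above

for attempt in range(1, MAX_RETRIES):  # no sleep after the final attempt
    base = min(BASE_DELAY * (2 ** (attempt - 1)), MAX_DELAY)  # 10, 20, 40, 80 seconds
    jitter = base * 0.25 * (2 * random.random() - 1)          # uniform in +/- 25%
    print(f"attempt {attempt} failed -> sleep {base + jitter:.1f}s (nominal {base}s)")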

View file

@ -11,13 +11,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentStatus, DocumentType
from app.utils.document_converters import generate_unique_identifier_hash
from ._constants import (
BASE_JOB_TIMEOUT,
MAX_UPLOAD_TIMEOUT,
MIN_UPLOAD_TIMEOUT,
PER_PAGE_JOB_TIMEOUT,
UPLOAD_BYTES_PER_SECOND_SLOW,
)
from .base import (
check_document_by_unique_identifier,
check_duplicate_document,
@ -198,21 +191,3 @@ async def update_document_from_connector(
if "connector_id" in connector:
document.connector_id = connector["connector_id"]
await session.commit()
# ---------------------------------------------------------------------------
# Timeout calculations
# ---------------------------------------------------------------------------
def calculate_upload_timeout(file_size_bytes: int) -> float:
"""Calculate upload timeout based on file size (conservative for slow connections)."""
estimated_time = (file_size_bytes / UPLOAD_BYTES_PER_SECOND_SLOW) * 1.5
return max(MIN_UPLOAD_TIMEOUT, min(estimated_time, MAX_UPLOAD_TIMEOUT))
def calculate_job_timeout(estimated_pages: int, file_size_bytes: int) -> float:
"""Calculate job processing timeout based on page count and file size."""
page_based_timeout = BASE_JOB_TIMEOUT + (estimated_pages * PER_PAGE_JOB_TIMEOUT)
size_based_timeout = BASE_JOB_TIMEOUT + (file_size_bytes / (10 * 1024 * 1024)) * 60
return max(page_based_timeout, size_based_timeout)
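A worked example of the two formulas above, for a hypothetical 25 MB file estimated at 50 pages (assumes the functions are imported from this module):
file_size_bytes = 25 * 1024 * 1024   # hypothetical 25 MB upload
estimated_pages = 50

calculate_upload_timeout(file_size_bytes)
# -> 384.0 s: 25 MB at 100 KB/s is 256 s, times the 1.5 safety factor, inside the 120-1800 s clamp
calculate_job_timeout(estimated_pages, file_size_bytes)
# -> 3600.0 s: page-based 600 + 50 * 60 wins over size-based 600 + 2.5 * 60 = 750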

View file

@ -1,14 +1,9 @@
"""
Unified document save/update logic for file processors.
Replaces the three nearly-identical ``add_received_file_document_using_*``
functions with a single ``save_file_document`` function plus thin wrappers
for backward compatibility.
"""
import logging
from langchain_core.documents import Document as LangChainDocument
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
@ -207,79 +202,3 @@ async def save_file_document(
raise RuntimeError(
f"Failed to process file document using {etl_service}: {e!s}"
) from e
# ---------------------------------------------------------------------------
# Backward-compatible wrapper functions
# ---------------------------------------------------------------------------
async def add_received_file_document_using_unstructured(
session: AsyncSession,
file_name: str,
unstructured_processed_elements: list[LangChainDocument],
search_space_id: int,
user_id: str,
connector: dict | None = None,
enable_summary: bool = True,
) -> Document | None:
"""Process and store a file document using the Unstructured service."""
from app.utils.document_converters import convert_document_to_markdown
markdown_content = await convert_document_to_markdown(
unstructured_processed_elements
)
return await save_file_document(
session,
file_name,
markdown_content,
search_space_id,
user_id,
"UNSTRUCTURED",
connector,
enable_summary,
)
async def add_received_file_document_using_llamacloud(
session: AsyncSession,
file_name: str,
llamacloud_markdown_document: str,
search_space_id: int,
user_id: str,
connector: dict | None = None,
enable_summary: bool = True,
) -> Document | None:
"""Process and store document content parsed by LlamaCloud."""
return await save_file_document(
session,
file_name,
llamacloud_markdown_document,
search_space_id,
user_id,
"LLAMACLOUD",
connector,
enable_summary,
)
async def add_received_file_document_using_docling(
session: AsyncSession,
file_name: str,
docling_markdown_document: str,
search_space_id: int,
user_id: str,
connector: dict | None = None,
enable_summary: bool = True,
) -> Document | None:
"""Process and store document content parsed by Docling."""
return await save_file_document(
session,
file_name,
docling_markdown_document,
search_space_id,
user_id,
"DOCLING",
connector,
enable_summary,
)
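For callers that no longer need the legacy names, the positional argument order of save_file_document can be read off the wrappers above; an illustrative direct call (all values hypothetical):
document = await save_file_document(
    session,
    "quarterly-report.pdf",  # file_name (hypothetical)
    docling_markdown,        # markdown content produced by the ETL step
    search_space_id,
    user_id,
    "DOCLING",               # etl_service label recorded on the document
    None,                    # connector
    True,                    # enable_summary
)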

View file

@ -1,14 +1,8 @@
"""
File document processors orchestrating content extraction and indexing.
This module is the public entry point for file processing. It delegates to
specialised sub-modules that each own a single concern:
- ``_constants`` file type classification and configuration constants
- ``_helpers`` document deduplication, migration, connector helpers
- ``_direct_converters`` lossless file-to-markdown for csv/tsv/html
- ``_etl`` ETL parsing strategies (Unstructured, LlamaCloud, Docling)
- ``_save`` unified document creation / update logic
Delegates content extraction to ``app.etl_pipeline.EtlPipelineService`` and
keeps only orchestration concerns (notifications, logging, page limits, saving).
"""
from __future__ import annotations
@ -17,38 +11,19 @@ import contextlib
import logging
import os
from dataclasses import dataclass, field
from logging import ERROR, getLogger
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config as app_config
from app.db import Document, Log, Notification
from app.services.notification_service import NotificationService
from app.services.task_logging_service import TaskLoggingService
from ._constants import FileCategory, classify_file
from ._direct_converters import convert_file_directly
from ._etl import (
parse_with_docling,
parse_with_llamacloud_retry,
parse_with_unstructured,
)
from ._helpers import update_document_from_connector
from ._save import (
add_received_file_document_using_docling,
add_received_file_document_using_llamacloud,
add_received_file_document_using_unstructured,
save_file_document,
)
from ._save import save_file_document
from .markdown_processor import add_received_markdown_file_document
# Re-export public API so existing ``from file_processors import …`` keeps working.
__all__ = [
"add_received_file_document_using_docling",
"add_received_file_document_using_llamacloud",
"add_received_file_document_using_unstructured",
"parse_with_llamacloud_retry",
"process_file_in_background",
"process_file_in_background_with_document",
"save_file_document",
@ -142,35 +117,31 @@ async def _log_page_divergence(
# ===================================================================
async def _process_markdown_upload(ctx: _ProcessingContext) -> Document | None:
"""Read a markdown / text file and create or update a document."""
await _notify(ctx, "parsing", "Reading file")
async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | None:
"""Extract content from a non-document file (plaintext/direct_convert/audio) via the unified ETL pipeline."""
from app.etl_pipeline.etl_document import EtlRequest
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
await _notify(ctx, "parsing", "Processing file")
await ctx.task_logger.log_task_progress(
ctx.log_entry,
f"Processing markdown/text file: {ctx.filename}",
{"file_type": "markdown", "processing_stage": "reading_file"},
f"Processing file: {ctx.filename}",
{"processing_stage": "extracting"},
)
with open(ctx.file_path, encoding="utf-8") as f:
markdown_content = f.read()
etl_result = await EtlPipelineService().extract(
EtlRequest(file_path=ctx.file_path, filename=ctx.filename)
)
with contextlib.suppress(Exception):
os.unlink(ctx.file_path)
await _notify(ctx, "chunking")
await ctx.task_logger.log_task_progress(
ctx.log_entry,
f"Creating document from markdown content: {ctx.filename}",
{
"processing_stage": "creating_document",
"content_length": len(markdown_content),
},
)
result = await add_received_markdown_file_document(
ctx.session,
ctx.filename,
markdown_content,
etl_result.markdown_content,
ctx.search_space_id,
ctx.user_id,
ctx.connector,
@ -181,179 +152,19 @@ async def _process_markdown_upload(ctx: _ProcessingContext) -> Document | None:
if result:
await ctx.task_logger.log_task_success(
ctx.log_entry,
f"Successfully processed markdown file: {ctx.filename}",
f"Successfully processed file: {ctx.filename}",
{
"document_id": result.id,
"content_hash": result.content_hash,
"file_type": "markdown",
"file_type": etl_result.content_type,
"etl_service": etl_result.etl_service,
},
)
else:
await ctx.task_logger.log_task_success(
ctx.log_entry,
f"Markdown file already exists (duplicate): {ctx.filename}",
{"duplicate_detected": True, "file_type": "markdown"},
)
return result
async def _process_direct_convert_upload(ctx: _ProcessingContext) -> Document | None:
"""Convert a text-based file (csv/tsv/html) to markdown without ETL."""
await _notify(ctx, "parsing", "Converting file")
await ctx.task_logger.log_task_progress(
ctx.log_entry,
f"Direct-converting file to markdown: {ctx.filename}",
{"file_type": "direct_convert", "processing_stage": "converting"},
)
markdown_content = convert_file_directly(ctx.file_path, ctx.filename)
with contextlib.suppress(Exception):
os.unlink(ctx.file_path)
await _notify(ctx, "chunking")
await ctx.task_logger.log_task_progress(
ctx.log_entry,
f"Creating document from converted content: {ctx.filename}",
{
"processing_stage": "creating_document",
"content_length": len(markdown_content),
},
)
result = await add_received_markdown_file_document(
ctx.session,
ctx.filename,
markdown_content,
ctx.search_space_id,
ctx.user_id,
ctx.connector,
)
if ctx.connector:
await update_document_from_connector(result, ctx.connector, ctx.session)
if result:
await ctx.task_logger.log_task_success(
ctx.log_entry,
f"Successfully direct-converted file: {ctx.filename}",
{
"document_id": result.id,
"content_hash": result.content_hash,
"file_type": "direct_convert",
},
)
else:
await ctx.task_logger.log_task_success(
ctx.log_entry,
f"Direct-converted file already exists (duplicate): {ctx.filename}",
{"duplicate_detected": True, "file_type": "direct_convert"},
)
return result
async def _process_audio_upload(ctx: _ProcessingContext) -> Document | None:
"""Transcribe an audio file and create or update a document."""
await _notify(ctx, "parsing", "Transcribing audio")
await ctx.task_logger.log_task_progress(
ctx.log_entry,
f"Processing audio file for transcription: {ctx.filename}",
{"file_type": "audio", "processing_stage": "starting_transcription"},
)
stt_service_type = (
"local"
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
else "external"
)
if stt_service_type == "local":
from app.services.stt_service import stt_service
try:
stt_result = stt_service.transcribe_file(ctx.file_path)
transcribed_text = stt_result.get("text", "")
if not transcribed_text:
raise ValueError("Transcription returned empty text")
transcribed_text = (
f"# Transcription of {ctx.filename}\n\n{transcribed_text}"
)
except Exception as e:
raise HTTPException(
status_code=422,
detail=f"Failed to transcribe audio file {ctx.filename}: {e!s}",
) from e
await ctx.task_logger.log_task_progress(
ctx.log_entry,
f"Local STT transcription completed: {ctx.filename}",
{
"processing_stage": "local_transcription_complete",
"language": stt_result.get("language"),
"confidence": stt_result.get("language_probability"),
"duration": stt_result.get("duration"),
},
)
else:
from litellm import atranscription
with open(ctx.file_path, "rb") as audio_file:
transcription_kwargs: dict = {
"model": app_config.STT_SERVICE,
"file": audio_file,
"api_key": app_config.STT_SERVICE_API_KEY,
}
if app_config.STT_SERVICE_API_BASE:
transcription_kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
transcription_response = await atranscription(**transcription_kwargs)
transcribed_text = transcription_response.get("text", "")
if not transcribed_text:
raise ValueError("Transcription returned empty text")
transcribed_text = f"# Transcription of {ctx.filename}\n\n{transcribed_text}"
await ctx.task_logger.log_task_progress(
ctx.log_entry,
f"Transcription completed, creating document: {ctx.filename}",
{
"processing_stage": "transcription_complete",
"transcript_length": len(transcribed_text),
},
)
await _notify(ctx, "chunking")
with contextlib.suppress(Exception):
os.unlink(ctx.file_path)
result = await add_received_markdown_file_document(
ctx.session,
ctx.filename,
transcribed_text,
ctx.search_space_id,
ctx.user_id,
ctx.connector,
)
if ctx.connector:
await update_document_from_connector(result, ctx.connector, ctx.session)
if result:
await ctx.task_logger.log_task_success(
ctx.log_entry,
f"Successfully transcribed and processed audio file: {ctx.filename}",
{
"document_id": result.id,
"content_hash": result.content_hash,
"file_type": "audio",
"transcript_length": len(transcribed_text),
"stt_service": stt_service_type,
},
)
else:
await ctx.task_logger.log_task_success(
ctx.log_entry,
f"Audio file transcript already exists (duplicate): {ctx.filename}",
{"duplicate_detected": True, "file_type": "audio"},
f"File already exists (duplicate): {ctx.filename}",
{"duplicate_detected": True, "file_type": etl_result.content_type},
)
return result
@ -363,279 +174,10 @@ async def _process_audio_upload(ctx: _ProcessingContext) -> Document | None:
# ---------------------------------------------------------------------------
async def _etl_unstructured(
ctx: _ProcessingContext,
page_limit_service,
estimated_pages: int,
) -> Document | None:
"""Parse and save via the Unstructured ETL service."""
await _notify(ctx, "parsing", "Extracting content")
await ctx.task_logger.log_task_progress(
ctx.log_entry,
f"Processing file with Unstructured ETL: {ctx.filename}",
{
"file_type": "document",
"etl_service": "UNSTRUCTURED",
"processing_stage": "loading",
},
)
docs = await parse_with_unstructured(ctx.file_path)
await _notify(ctx, "chunking", chunks_count=len(docs))
await ctx.task_logger.log_task_progress(
ctx.log_entry,
f"Unstructured ETL completed, creating document: {ctx.filename}",
{"processing_stage": "etl_complete", "elements_count": len(docs)},
)
actual_pages = page_limit_service.estimate_pages_from_elements(docs)
final_pages = max(estimated_pages, actual_pages)
await _log_page_divergence(
ctx.task_logger,
ctx.log_entry,
ctx.filename,
estimated_pages,
actual_pages,
final_pages,
)
with contextlib.suppress(Exception):
os.unlink(ctx.file_path)
result = await add_received_file_document_using_unstructured(
ctx.session,
ctx.filename,
docs,
ctx.search_space_id,
ctx.user_id,
ctx.connector,
enable_summary=ctx.enable_summary,
)
if ctx.connector:
await update_document_from_connector(result, ctx.connector, ctx.session)
if result:
await page_limit_service.update_page_usage(
ctx.user_id, final_pages, allow_exceed=True
)
await ctx.task_logger.log_task_success(
ctx.log_entry,
f"Successfully processed file with Unstructured: {ctx.filename}",
{
"document_id": result.id,
"content_hash": result.content_hash,
"file_type": "document",
"etl_service": "UNSTRUCTURED",
"pages_processed": final_pages,
},
)
else:
await ctx.task_logger.log_task_success(
ctx.log_entry,
f"Document already exists (duplicate): {ctx.filename}",
{
"duplicate_detected": True,
"file_type": "document",
"etl_service": "UNSTRUCTURED",
},
)
return result
async def _etl_llamacloud(
ctx: _ProcessingContext,
page_limit_service,
estimated_pages: int,
) -> Document | None:
"""Parse and save via the LlamaCloud ETL service."""
await _notify(ctx, "parsing", "Extracting content")
await ctx.task_logger.log_task_progress(
ctx.log_entry,
f"Processing file with LlamaCloud ETL: {ctx.filename}",
{
"file_type": "document",
"etl_service": "LLAMACLOUD",
"processing_stage": "parsing",
"estimated_pages": estimated_pages,
},
)
raw_result = await parse_with_llamacloud_retry(
file_path=ctx.file_path,
estimated_pages=estimated_pages,
task_logger=ctx.task_logger,
log_entry=ctx.log_entry,
)
with contextlib.suppress(Exception):
os.unlink(ctx.file_path)
markdown_documents = await raw_result.aget_markdown_documents(split_by_page=False)
await _notify(ctx, "chunking", chunks_count=len(markdown_documents))
await ctx.task_logger.log_task_progress(
ctx.log_entry,
f"LlamaCloud parsing completed, creating documents: {ctx.filename}",
{
"processing_stage": "parsing_complete",
"documents_count": len(markdown_documents),
},
)
if not markdown_documents:
await ctx.task_logger.log_task_failure(
ctx.log_entry,
f"LlamaCloud parsing returned no documents: {ctx.filename}",
"ETL service returned empty document list",
{"error_type": "EmptyDocumentList", "etl_service": "LLAMACLOUD"},
)
raise ValueError(f"LlamaCloud parsing returned no documents for {ctx.filename}")
actual_pages = page_limit_service.estimate_pages_from_markdown(markdown_documents)
final_pages = max(estimated_pages, actual_pages)
await _log_page_divergence(
ctx.task_logger,
ctx.log_entry,
ctx.filename,
estimated_pages,
actual_pages,
final_pages,
)
any_created = False
last_doc: Document | None = None
for doc in markdown_documents:
doc_result = await add_received_file_document_using_llamacloud(
ctx.session,
ctx.filename,
llamacloud_markdown_document=doc.text,
search_space_id=ctx.search_space_id,
user_id=ctx.user_id,
connector=ctx.connector,
enable_summary=ctx.enable_summary,
)
if doc_result:
any_created = True
last_doc = doc_result
if any_created:
await page_limit_service.update_page_usage(
ctx.user_id, final_pages, allow_exceed=True
)
if ctx.connector:
await update_document_from_connector(last_doc, ctx.connector, ctx.session)
await ctx.task_logger.log_task_success(
ctx.log_entry,
f"Successfully processed file with LlamaCloud: {ctx.filename}",
{
"document_id": last_doc.id,
"content_hash": last_doc.content_hash,
"file_type": "document",
"etl_service": "LLAMACLOUD",
"pages_processed": final_pages,
"documents_count": len(markdown_documents),
},
)
return last_doc
await ctx.task_logger.log_task_success(
ctx.log_entry,
f"Document already exists (duplicate): {ctx.filename}",
{
"duplicate_detected": True,
"file_type": "document",
"etl_service": "LLAMACLOUD",
"documents_count": len(markdown_documents),
},
)
return None
async def _etl_docling(
ctx: _ProcessingContext,
page_limit_service,
estimated_pages: int,
) -> Document | None:
"""Parse and save via the Docling ETL service."""
await _notify(ctx, "parsing", "Extracting content")
await ctx.task_logger.log_task_progress(
ctx.log_entry,
f"Processing file with Docling ETL: {ctx.filename}",
{
"file_type": "document",
"etl_service": "DOCLING",
"processing_stage": "parsing",
},
)
content = await parse_with_docling(ctx.file_path, ctx.filename)
with contextlib.suppress(Exception):
os.unlink(ctx.file_path)
await ctx.task_logger.log_task_progress(
ctx.log_entry,
f"Docling parsing completed, creating document: {ctx.filename}",
{"processing_stage": "parsing_complete", "content_length": len(content)},
)
actual_pages = page_limit_service.estimate_pages_from_content_length(len(content))
final_pages = max(estimated_pages, actual_pages)
await _log_page_divergence(
ctx.task_logger,
ctx.log_entry,
ctx.filename,
estimated_pages,
actual_pages,
final_pages,
)
await _notify(ctx, "chunking")
result = await add_received_file_document_using_docling(
ctx.session,
ctx.filename,
docling_markdown_document=content,
search_space_id=ctx.search_space_id,
user_id=ctx.user_id,
connector=ctx.connector,
enable_summary=ctx.enable_summary,
)
if result:
await page_limit_service.update_page_usage(
ctx.user_id, final_pages, allow_exceed=True
)
if ctx.connector:
await update_document_from_connector(result, ctx.connector, ctx.session)
await ctx.task_logger.log_task_success(
ctx.log_entry,
f"Successfully processed file with Docling: {ctx.filename}",
{
"document_id": result.id,
"content_hash": result.content_hash,
"file_type": "document",
"etl_service": "DOCLING",
"pages_processed": final_pages,
},
)
else:
await ctx.task_logger.log_task_success(
ctx.log_entry,
f"Document already exists (duplicate): {ctx.filename}",
{
"duplicate_detected": True,
"file_type": "document",
"etl_service": "DOCLING",
},
)
return result
async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
"""Route a document file to the configured ETL service."""
"""Route a document file to the configured ETL service via the unified pipeline."""
from app.etl_pipeline.etl_document import EtlRequest
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
from app.services.page_limit_service import PageLimitExceededError, PageLimitService
page_limit_service = PageLimitService(ctx.session)
@ -665,16 +207,60 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
os.unlink(ctx.file_path)
raise HTTPException(status_code=403, detail=str(e)) from e
etl_dispatch = {
"UNSTRUCTURED": _etl_unstructured,
"LLAMACLOUD": _etl_llamacloud,
"DOCLING": _etl_docling,
}
handler = etl_dispatch.get(app_config.ETL_SERVICE)
if handler is None:
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
await _notify(ctx, "parsing", "Extracting content")
return await handler(ctx, page_limit_service, estimated_pages)
etl_result = await EtlPipelineService().extract(
EtlRequest(
file_path=ctx.file_path,
filename=ctx.filename,
estimated_pages=estimated_pages,
)
)
with contextlib.suppress(Exception):
os.unlink(ctx.file_path)
await _notify(ctx, "chunking")
result = await save_file_document(
ctx.session,
ctx.filename,
etl_result.markdown_content,
ctx.search_space_id,
ctx.user_id,
etl_result.etl_service,
ctx.connector,
enable_summary=ctx.enable_summary,
)
if result:
await page_limit_service.update_page_usage(
ctx.user_id, estimated_pages, allow_exceed=True
)
if ctx.connector:
await update_document_from_connector(result, ctx.connector, ctx.session)
await ctx.task_logger.log_task_success(
ctx.log_entry,
f"Successfully processed file: {ctx.filename}",
{
"document_id": result.id,
"content_hash": result.content_hash,
"file_type": "document",
"etl_service": etl_result.etl_service,
"pages_processed": estimated_pages,
},
)
else:
await ctx.task_logger.log_task_success(
ctx.log_entry,
f"Document already exists (duplicate): {ctx.filename}",
{
"duplicate_detected": True,
"file_type": "document",
"etl_service": etl_result.etl_service,
},
)
return result
# ===================================================================
@ -706,15 +292,16 @@ async def process_file_in_background(
)
try:
category = classify_file(filename)
from app.etl_pipeline.file_classifier import (
FileCategory as EtlFileCategory,
classify_file as etl_classify,
)
if category == FileCategory.MARKDOWN:
return await _process_markdown_upload(ctx)
if category == FileCategory.DIRECT_CONVERT:
return await _process_direct_convert_upload(ctx)
if category == FileCategory.AUDIO:
return await _process_audio_upload(ctx)
return await _process_document_upload(ctx)
category = etl_classify(filename)
if category == EtlFileCategory.DOCUMENT:
return await _process_document_upload(ctx)
return await _process_non_document_upload(ctx)
except Exception as e:
await session.rollback()
@ -758,201 +345,64 @@ async def _extract_file_content(
Returns:
Tuple of (markdown_content, etl_service_name).
"""
category = classify_file(filename)
if category == FileCategory.MARKDOWN:
if notification:
await NotificationService.document_processing.notify_processing_progress(
session,
notification,
stage="parsing",
stage_message="Reading file",
)
await task_logger.log_task_progress(
log_entry,
f"Processing markdown/text file: {filename}",
{"file_type": "markdown", "processing_stage": "reading_file"},
)
with open(file_path, encoding="utf-8") as f:
content = f.read()
with contextlib.suppress(Exception):
os.unlink(file_path)
return content, "MARKDOWN"
if category == FileCategory.DIRECT_CONVERT:
if notification:
await NotificationService.document_processing.notify_processing_progress(
session,
notification,
stage="parsing",
stage_message="Converting file",
)
await task_logger.log_task_progress(
log_entry,
f"Direct-converting file to markdown: {filename}",
{"file_type": "direct_convert", "processing_stage": "converting"},
)
content = convert_file_directly(file_path, filename)
with contextlib.suppress(Exception):
os.unlink(file_path)
return content, "DIRECT_CONVERT"
if category == FileCategory.AUDIO:
if notification:
await NotificationService.document_processing.notify_processing_progress(
session,
notification,
stage="parsing",
stage_message="Transcribing audio",
)
await task_logger.log_task_progress(
log_entry,
f"Processing audio file for transcription: {filename}",
{"file_type": "audio", "processing_stage": "starting_transcription"},
)
transcribed_text = await _transcribe_audio(file_path, filename)
with contextlib.suppress(Exception):
os.unlink(file_path)
return transcribed_text, "AUDIO_TRANSCRIPTION"
# Document file — use ETL service
return await _extract_document_content(
file_path,
filename,
session,
user_id,
task_logger,
log_entry,
notification,
from app.etl_pipeline.etl_document import EtlRequest
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
from app.etl_pipeline.file_classifier import (
FileCategory,
classify_file as etl_classify,
)
async def _transcribe_audio(file_path: str, filename: str) -> str:
"""Transcribe an audio file and return formatted markdown text."""
stt_service_type = (
"local"
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
else "external"
)
if stt_service_type == "local":
from app.services.stt_service import stt_service
result = stt_service.transcribe_file(file_path)
text = result.get("text", "")
if not text:
raise ValueError("Transcription returned empty text")
else:
from litellm import atranscription
with open(file_path, "rb") as audio_file:
kwargs: dict = {
"model": app_config.STT_SERVICE,
"file": audio_file,
"api_key": app_config.STT_SERVICE_API_KEY,
}
if app_config.STT_SERVICE_API_BASE:
kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
response = await atranscription(**kwargs)
text = response.get("text", "")
if not text:
raise ValueError("Transcription returned empty text")
return f"# Transcription of {filename}\n\n{text}"
async def _extract_document_content(
file_path: str,
filename: str,
session: AsyncSession,
user_id: str,
task_logger: TaskLoggingService,
log_entry: Log,
notification: Notification | None,
) -> tuple[str, str]:
"""
Parse a document file via the configured ETL service.
Returns:
Tuple of (markdown_content, etl_service_name).
"""
from app.services.page_limit_service import PageLimitService
page_limit_service = PageLimitService(session)
try:
estimated_pages = page_limit_service.estimate_pages_before_processing(file_path)
except Exception:
file_size = os.path.getsize(file_path)
estimated_pages = max(1, file_size // (80 * 1024))
await page_limit_service.check_page_limit(user_id, estimated_pages)
etl_service = app_config.ETL_SERVICE
markdown_content: str | None = None
category = etl_classify(filename)
estimated_pages = 0
if notification:
stage_messages = {
FileCategory.PLAINTEXT: "Reading file",
FileCategory.DIRECT_CONVERT: "Converting file",
FileCategory.AUDIO: "Transcribing audio",
FileCategory.UNSUPPORTED: "Unsupported file type",
FileCategory.DOCUMENT: "Extracting content",
}
await NotificationService.document_processing.notify_processing_progress(
session,
notification,
stage="parsing",
stage_message="Extracting content",
stage_message=stage_messages.get(category, "Processing"),
)
if etl_service == "UNSTRUCTURED":
from app.utils.document_converters import convert_document_to_markdown
await task_logger.log_task_progress(
log_entry,
f"Processing {category.value} file: {filename}",
{"file_type": category.value, "processing_stage": "extracting"},
)
docs = await parse_with_unstructured(file_path)
markdown_content = await convert_document_to_markdown(docs)
actual_pages = page_limit_service.estimate_pages_from_elements(docs)
final_pages = max(estimated_pages, actual_pages)
await page_limit_service.update_page_usage(
user_id, final_pages, allow_exceed=True
)
if category == FileCategory.DOCUMENT:
from app.services.page_limit_service import PageLimitService
elif etl_service == "LLAMACLOUD":
raw_result = await parse_with_llamacloud_retry(
page_limit_service = PageLimitService(session)
estimated_pages = _estimate_pages_safe(page_limit_service, file_path)
await page_limit_service.check_page_limit(user_id, estimated_pages)
result = await EtlPipelineService().extract(
EtlRequest(
file_path=file_path,
filename=filename,
estimated_pages=estimated_pages,
task_logger=task_logger,
log_entry=log_entry,
)
markdown_documents = await raw_result.aget_markdown_documents(
split_by_page=False
)
if not markdown_documents:
raise RuntimeError(f"LlamaCloud parsing returned no documents: {filename}")
markdown_content = markdown_documents[0].text
)
if category == FileCategory.DOCUMENT:
await page_limit_service.update_page_usage(
user_id, estimated_pages, allow_exceed=True
)
elif etl_service == "DOCLING":
getLogger("docling.pipeline.base_pipeline").setLevel(ERROR)
getLogger("docling.document_converter").setLevel(ERROR)
getLogger("docling_core.transforms.chunker.hierarchical_chunker").setLevel(
ERROR
)
from docling.document_converter import DocumentConverter
converter = DocumentConverter()
result = converter.convert(file_path)
markdown_content = result.document.export_to_markdown()
await page_limit_service.update_page_usage(
user_id, estimated_pages, allow_exceed=True
)
else:
raise RuntimeError(f"Unknown ETL_SERVICE: {etl_service}")
with contextlib.suppress(Exception):
os.unlink(file_path)
if not markdown_content:
if not result.markdown_content:
raise RuntimeError(f"Failed to extract content from file: {filename}")
return markdown_content, etl_service
return result.markdown_content, result.etl_service
async def process_file_in_background_with_document(

View file

@ -0,0 +1,124 @@
"""Per-parser document extension sets for the ETL pipeline.
Every consumer (file_classifier, connector-level skip checks, ETL pipeline
validation) imports from here so there is a single source of truth.
Extensions already covered by PLAINTEXT_EXTENSIONS, AUDIO_EXTENSIONS, or
DIRECT_CONVERT_EXTENSIONS in file_classifier are NOT repeated here -- these
sets are exclusively for the "document" ETL path (Docling / LlamaParse /
Unstructured).
"""
from pathlib import PurePosixPath
# ---------------------------------------------------------------------------
# Per-parser document extension sets (from official documentation)
# ---------------------------------------------------------------------------
DOCLING_DOCUMENT_EXTENSIONS: frozenset[str] = frozenset(
{
".pdf",
".docx",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
".tiff",
".tif",
".bmp",
".webp",
}
)
LLAMAPARSE_DOCUMENT_EXTENSIONS: frozenset[str] = frozenset(
{
".pdf",
".docx",
".doc",
".xlsx",
".xls",
".pptx",
".ppt",
".docm",
".dot",
".dotm",
".pptm",
".pot",
".potx",
".xlsm",
".xlsb",
".xlw",
".rtf",
".epub",
".png",
".jpg",
".jpeg",
".gif",
".bmp",
".tiff",
".tif",
".webp",
".svg",
".odt",
".ods",
".odp",
".hwp",
".hwpx",
}
)
UNSTRUCTURED_DOCUMENT_EXTENSIONS: frozenset[str] = frozenset(
{
".pdf",
".docx",
".doc",
".xlsx",
".xls",
".pptx",
".ppt",
".png",
".jpg",
".jpeg",
".bmp",
".tiff",
".tif",
".heic",
".rtf",
".epub",
".odt",
".eml",
".msg",
".p7s",
}
)
# ---------------------------------------------------------------------------
# Union (used by classify_file for routing) + service lookup
# ---------------------------------------------------------------------------
DOCUMENT_EXTENSIONS: frozenset[str] = (
DOCLING_DOCUMENT_EXTENSIONS
| LLAMAPARSE_DOCUMENT_EXTENSIONS
| UNSTRUCTURED_DOCUMENT_EXTENSIONS
)
_SERVICE_MAP: dict[str, frozenset[str]] = {
"DOCLING": DOCLING_DOCUMENT_EXTENSIONS,
"LLAMACLOUD": LLAMAPARSE_DOCUMENT_EXTENSIONS,
"UNSTRUCTURED": UNSTRUCTURED_DOCUMENT_EXTENSIONS,
}
def get_document_extensions_for_service(etl_service: str | None) -> frozenset[str]:
"""Return the document extensions supported by *etl_service*.
Falls back to the full union when the service is ``None`` or unknown.
"""
return _SERVICE_MAP.get(etl_service or "", DOCUMENT_EXTENSIONS)
def is_supported_document_extension(filename: str) -> bool:
"""Return True if the file's extension is in the supported document set."""
suffix = PurePosixPath(filename).suffix.lower()
return suffix in DOCUMENT_EXTENSIONS
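A quick sketch of how a consumer such as a connector-level skip check might use the helpers above; should_download is a hypothetical wrapper, while the extension facts come from the sets defined in this file:
def should_download(filename: str, etl_service: str | None) -> bool:
    # Hypothetical pre-download check: skip files the configured parser cannot handle.
    suffix = PurePosixPath(filename).suffix.lower()
    return suffix in get_document_extensions_for_service(etl_service)

should_download("slides.hwp", "LLAMACLOUD")  # True  - LlamaParse lists .hwp
should_download("slides.hwp", "DOCLING")     # False - not in the Docling set
should_download("scan.heic", None)           # True  - falls back to the union, which includes .heic via Unstructured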

View file

@ -319,31 +319,23 @@ def _mock_etl_parsing(monkeypatch):
# -- LlamaParse mock (external API) --------------------------------
class _FakeMarkdownDoc:
def __init__(self, text: str):
self.text = text
class _FakeLlamaParseResult:
async def aget_markdown_documents(self, *, split_by_page=False):
return [_FakeMarkdownDoc(_MOCK_ETL_MARKDOWN)]
async def _fake_llamacloud_parse(**kwargs):
_reject_empty(kwargs["file_path"])
return _FakeLlamaParseResult()
async def _fake_llamacloud_parse(file_path: str, estimated_pages: int) -> str:
_reject_empty(file_path)
return _MOCK_ETL_MARKDOWN
monkeypatch.setattr(
"app.tasks.document_processors.file_processors.parse_with_llamacloud_retry",
"app.etl_pipeline.parsers.llamacloud.parse_with_llamacloud",
_fake_llamacloud_parse,
)
# -- Docling mock (heavy library boundary) -------------------------
async def _fake_docling_parse(file_path: str, filename: str):
async def _fake_docling_parse(file_path: str, filename: str) -> str:
_reject_empty(file_path)
return _MOCK_ETL_MARKDOWN
monkeypatch.setattr(
"app.tasks.document_processors.file_processors.parse_with_docling",
"app.etl_pipeline.parsers.docling.parse_with_docling",
_fake_docling_parse,
)

View file

@ -124,7 +124,7 @@ async def test_composio_connector_without_account_id_returns_error(
maker = make_session_factory(async_engine)
async with maker() as session:
count, _skipped, error = await index_google_drive_files(
count, _skipped, error, _unsupported = await index_google_drive_files(
session=session,
connector_id=data["connector_id"],
search_space_id=data["search_space_id"],

View file

@ -0,0 +1,244 @@
"""Tests that each cloud connector's download_and_extract_content correctly
produces markdown from a real file via the unified ETL pipeline.
Only the cloud client is mocked (system boundary). The ETL pipeline runs for
real so we know the full path from "cloud gives us bytes" to "we get markdown
back" actually works.
"""
from unittest.mock import AsyncMock, MagicMock
import pytest
pytestmark = pytest.mark.unit
_TXT_CONTENT = "Hello from the cloud connector test."
_CSV_CONTENT = "name,age\nAlice,30\nBob,25\n"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _write_file(dest_path: str, content: str) -> None:
"""Simulate a cloud client writing downloaded bytes to disk."""
with open(dest_path, "w", encoding="utf-8") as f:
f.write(content)
def _make_download_side_effect(content: str):
"""Return an async side-effect that writes *content* to the dest path
and returns ``None`` (success)."""
async def _side_effect(*args):
dest_path = args[-1]
await _write_file(dest_path, content)
return None
return _side_effect
# ===================================================================
# Google Drive
# ===================================================================
class TestGoogleDriveContentExtraction:
async def test_txt_file_returns_markdown(self):
from app.connectors.google_drive.content_extractor import (
download_and_extract_content,
)
client = MagicMock()
client.download_file_to_disk = AsyncMock(
side_effect=_make_download_side_effect(_TXT_CONTENT),
)
file = {"id": "f1", "name": "notes.txt", "mimeType": "text/plain"}
markdown, metadata, error = await download_and_extract_content(client, file)
assert error is None
assert _TXT_CONTENT in markdown
assert metadata["google_drive_file_id"] == "f1"
assert metadata["google_drive_file_name"] == "notes.txt"
async def test_csv_file_returns_markdown_table(self):
from app.connectors.google_drive.content_extractor import (
download_and_extract_content,
)
client = MagicMock()
client.download_file_to_disk = AsyncMock(
side_effect=_make_download_side_effect(_CSV_CONTENT),
)
file = {"id": "f2", "name": "data.csv", "mimeType": "text/csv"}
markdown, _metadata, error = await download_and_extract_content(client, file)
assert error is None
assert "Alice" in markdown
assert "Bob" in markdown
assert "|" in markdown
async def test_download_error_returns_error_message(self):
from app.connectors.google_drive.content_extractor import (
download_and_extract_content,
)
client = MagicMock()
client.download_file_to_disk = AsyncMock(return_value="Network timeout")
file = {"id": "f3", "name": "doc.txt", "mimeType": "text/plain"}
markdown, _metadata, error = await download_and_extract_content(client, file)
assert markdown is None
assert error == "Network timeout"
# ===================================================================
# OneDrive
# ===================================================================
class TestOneDriveContentExtraction:
async def test_txt_file_returns_markdown(self):
from app.connectors.onedrive.content_extractor import (
download_and_extract_content,
)
client = MagicMock()
client.download_file_to_disk = AsyncMock(
side_effect=_make_download_side_effect(_TXT_CONTENT),
)
file = {
"id": "od-1",
"name": "report.txt",
"file": {"mimeType": "text/plain"},
}
markdown, metadata, error = await download_and_extract_content(client, file)
assert error is None
assert _TXT_CONTENT in markdown
assert metadata["onedrive_file_id"] == "od-1"
assert metadata["onedrive_file_name"] == "report.txt"
async def test_csv_file_returns_markdown_table(self):
from app.connectors.onedrive.content_extractor import (
download_and_extract_content,
)
client = MagicMock()
client.download_file_to_disk = AsyncMock(
side_effect=_make_download_side_effect(_CSV_CONTENT),
)
file = {
"id": "od-2",
"name": "data.csv",
"file": {"mimeType": "text/csv"},
}
markdown, _metadata, error = await download_and_extract_content(client, file)
assert error is None
assert "Alice" in markdown
assert "|" in markdown
async def test_download_error_returns_error_message(self):
from app.connectors.onedrive.content_extractor import (
download_and_extract_content,
)
client = MagicMock()
client.download_file_to_disk = AsyncMock(return_value="403 Forbidden")
file = {
"id": "od-3",
"name": "secret.txt",
"file": {"mimeType": "text/plain"},
}
markdown, _metadata, error = await download_and_extract_content(client, file)
assert markdown is None
assert error == "403 Forbidden"
# ===================================================================
# Dropbox
# ===================================================================
class TestDropboxContentExtraction:
async def test_txt_file_returns_markdown(self):
from app.connectors.dropbox.content_extractor import (
download_and_extract_content,
)
client = MagicMock()
client.download_file_to_disk = AsyncMock(
side_effect=_make_download_side_effect(_TXT_CONTENT),
)
file = {
"id": "dbx-1",
"name": "memo.txt",
".tag": "file",
"path_lower": "/memo.txt",
}
markdown, metadata, error = await download_and_extract_content(client, file)
assert error is None
assert _TXT_CONTENT in markdown
assert metadata["dropbox_file_id"] == "dbx-1"
assert metadata["dropbox_file_name"] == "memo.txt"
async def test_csv_file_returns_markdown_table(self):
from app.connectors.dropbox.content_extractor import (
download_and_extract_content,
)
client = MagicMock()
client.download_file_to_disk = AsyncMock(
side_effect=_make_download_side_effect(_CSV_CONTENT),
)
file = {
"id": "dbx-2",
"name": "data.csv",
".tag": "file",
"path_lower": "/data.csv",
}
markdown, _metadata, error = await download_and_extract_content(client, file)
assert error is None
assert "Alice" in markdown
assert "|" in markdown
async def test_download_error_returns_error_message(self):
from app.connectors.dropbox.content_extractor import (
download_and_extract_content,
)
client = MagicMock()
client.download_file_to_disk = AsyncMock(return_value="Rate limited")
file = {
"id": "dbx-3",
"name": "big.txt",
".tag": "file",
"path_lower": "/big.txt",
}
markdown, _metadata, error = await download_and_extract_content(client, file)
assert markdown is None
assert error == "Rate limited"

View file

@ -8,6 +8,10 @@ import pytest
from app.db import DocumentType
from app.tasks.connector_indexers.dropbox_indexer import (
_download_files_parallel,
_index_full_scan,
_index_selected_files,
_index_with_delta_sync,
index_dropbox_files,
)
pytestmark = pytest.mark.unit
@ -234,3 +238,610 @@ async def test_heartbeat_fires_during_parallel_downloads(
assert len(docs) == 3
assert failed == 0
assert len(heartbeat_calls) >= 1, "Heartbeat should have fired at least once"
# ---------------------------------------------------------------------------
# D1-D2: _index_full_scan tests
# ---------------------------------------------------------------------------
def _folder_dict(name: str) -> dict:
return {".tag": "folder", "name": name}
@pytest.fixture
def full_scan_mocks(mock_dropbox_client, monkeypatch):
"""Wire up mocks for _index_full_scan in isolation."""
import app.tasks.connector_indexers.dropbox_indexer as _mod
mock_session = AsyncMock()
mock_task_logger = MagicMock()
mock_task_logger.log_task_progress = AsyncMock()
mock_log_entry = MagicMock()
skip_results: dict[str, tuple[bool, str | None]] = {}
monkeypatch.setattr("app.config.config.ETL_SERVICE", "LLAMACLOUD")
async def _fake_skip(session, file, search_space_id):
from app.connectors.dropbox.file_types import should_skip_file as _skip
item_skip, unsup_ext = _skip(file)
if item_skip:
if unsup_ext:
return True, f"unsupported:{unsup_ext}"
return True, "folder/non-downloadable"
return skip_results.get(file.get("id", ""), (False, None))
monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip)
download_and_index_mock = AsyncMock(return_value=(0, 0))
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
from app.services.page_limit_service import PageLimitService as _RealPLS
mock_page_limit_instance = MagicMock()
mock_page_limit_instance.get_page_usage = AsyncMock(return_value=(0, 999_999))
mock_page_limit_instance.update_page_usage = AsyncMock()
class _MockPageLimitService:
estimate_pages_from_metadata = staticmethod(
_RealPLS.estimate_pages_from_metadata
)
def __init__(self, session):
self.get_page_usage = mock_page_limit_instance.get_page_usage
self.update_page_usage = mock_page_limit_instance.update_page_usage
monkeypatch.setattr(_mod, "PageLimitService", _MockPageLimitService)
return {
"dropbox_client": mock_dropbox_client,
"session": mock_session,
"task_logger": mock_task_logger,
"log_entry": mock_log_entry,
"skip_results": skip_results,
"download_and_index_mock": download_and_index_mock,
}
async def _run_full_scan(mocks, monkeypatch, page_files, *, max_files=500):
import app.tasks.connector_indexers.dropbox_indexer as _mod
monkeypatch.setattr(
_mod,
"get_files_in_folder",
AsyncMock(return_value=(page_files, None)),
)
return await _index_full_scan(
mocks["dropbox_client"],
mocks["session"],
_CONNECTOR_ID,
_SEARCH_SPACE_ID,
_USER_ID,
"",
"Root",
mocks["task_logger"],
mocks["log_entry"],
max_files,
enable_summary=True,
)
async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch):
"""Skipped files excluded, renames counted as indexed, new files downloaded."""
page_files = [
_folder_dict("SubFolder"),
_make_file_dict("skip1", "unchanged.txt"),
_make_file_dict("rename1", "renamed.txt"),
_make_file_dict("new1", "new1.txt"),
_make_file_dict("new2", "new2.txt"),
]
full_scan_mocks["skip_results"]["skip1"] = (True, "unchanged")
full_scan_mocks["skip_results"]["rename1"] = (
True,
"File renamed: 'old' -> 'renamed.txt'",
)
full_scan_mocks["download_and_index_mock"].return_value = (2, 0)
indexed, skipped, _unsupported = await _run_full_scan(
full_scan_mocks, monkeypatch, page_files
)
assert indexed == 3 # 1 renamed + 2 from batch
assert skipped == 2 # 1 folder + 1 unchanged
call_args = full_scan_mocks["download_and_index_mock"].call_args
call_files = call_args[0][2]
assert len(call_files) == 2
assert {f["id"] for f in call_files} == {"new1", "new2"}
async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
"""Only max_files non-folder items are considered."""
page_files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(10)]
full_scan_mocks["download_and_index_mock"].return_value = (3, 0)
await _run_full_scan(full_scan_mocks, monkeypatch, page_files, max_files=3)
call_files = full_scan_mocks["download_and_index_mock"].call_args[0][2]
assert len(call_files) == 3
# ---------------------------------------------------------------------------
# D3-D5: _index_selected_files tests
# ---------------------------------------------------------------------------
@pytest.fixture
def selected_files_mocks(mock_dropbox_client, monkeypatch):
"""Wire up mocks for _index_selected_files tests."""
import app.tasks.connector_indexers.dropbox_indexer as _mod
mock_session = AsyncMock()
get_file_results: dict[str, tuple[dict | None, str | None]] = {}
async def _fake_get_file(client, path):
return get_file_results.get(path, (None, f"Not configured: {path}"))
monkeypatch.setattr(_mod, "get_file_by_path", _fake_get_file)
skip_results: dict[str, tuple[bool, str | None]] = {}
async def _fake_skip(session, file, search_space_id):
return skip_results.get(file["id"], (False, None))
monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip)
download_and_index_mock = AsyncMock(return_value=(0, 0))
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
from app.services.page_limit_service import PageLimitService as _RealPLS
mock_page_limit_instance = MagicMock()
mock_page_limit_instance.get_page_usage = AsyncMock(return_value=(0, 999_999))
mock_page_limit_instance.update_page_usage = AsyncMock()
class _MockPageLimitService:
estimate_pages_from_metadata = staticmethod(
_RealPLS.estimate_pages_from_metadata
)
def __init__(self, session):
self.get_page_usage = mock_page_limit_instance.get_page_usage
self.update_page_usage = mock_page_limit_instance.update_page_usage
monkeypatch.setattr(_mod, "PageLimitService", _MockPageLimitService)
return {
"dropbox_client": mock_dropbox_client,
"session": mock_session,
"get_file_results": get_file_results,
"skip_results": skip_results,
"download_and_index_mock": download_and_index_mock,
}
async def _run_selected(mocks, file_tuples):
return await _index_selected_files(
mocks["dropbox_client"],
mocks["session"],
file_tuples,
connector_id=_CONNECTOR_ID,
search_space_id=_SEARCH_SPACE_ID,
user_id=_USER_ID,
enable_summary=True,
)
async def test_selected_files_single_file_indexed(selected_files_mocks):
selected_files_mocks["get_file_results"]["/report.pdf"] = (
_make_file_dict("f1", "report.pdf"),
None,
)
selected_files_mocks["download_and_index_mock"].return_value = (1, 0)
indexed, skipped, _unsupported, errors = await _run_selected(
selected_files_mocks,
[("/report.pdf", "report.pdf")],
)
assert indexed == 1
assert skipped == 0
assert errors == []
async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
selected_files_mocks["get_file_results"]["/first.txt"] = (
_make_file_dict("f1", "first.txt"),
None,
)
selected_files_mocks["get_file_results"]["/mid.txt"] = (None, "HTTP 404")
selected_files_mocks["get_file_results"]["/third.txt"] = (
_make_file_dict("f3", "third.txt"),
None,
)
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
indexed, skipped, _unsupported, errors = await _run_selected(
selected_files_mocks,
[
("/first.txt", "first.txt"),
("/mid.txt", "mid.txt"),
("/third.txt", "third.txt"),
],
)
assert indexed == 2
assert skipped == 0
assert len(errors) == 1
assert "mid.txt" in errors[0]
async def test_selected_files_skip_rename_counting(selected_files_mocks):
for path, fid, fname in [
("/unchanged.txt", "s1", "unchanged.txt"),
("/renamed.txt", "r1", "renamed.txt"),
("/new1.txt", "n1", "new1.txt"),
("/new2.txt", "n2", "new2.txt"),
]:
selected_files_mocks["get_file_results"][path] = (
_make_file_dict(fid, fname),
None,
)
selected_files_mocks["skip_results"]["s1"] = (True, "unchanged")
selected_files_mocks["skip_results"]["r1"] = (
True,
"File renamed: 'old' -> 'renamed.txt'",
)
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
indexed, skipped, _unsupported, errors = await _run_selected(
selected_files_mocks,
[
("/unchanged.txt", "unchanged.txt"),
("/renamed.txt", "renamed.txt"),
("/new1.txt", "new1.txt"),
("/new2.txt", "new2.txt"),
],
)
assert indexed == 3 # 1 renamed + 2 batch
assert skipped == 1
assert errors == []
mock = selected_files_mocks["download_and_index_mock"]
call_files = mock.call_args[0][2]
assert len(call_files) == 2
assert {f["id"] for f in call_files} == {"n1", "n2"}
# ---------------------------------------------------------------------------
# E1-E4: _index_with_delta_sync tests
# ---------------------------------------------------------------------------
async def test_delta_sync_deletions_call_remove_document(monkeypatch):
"""E1: deleted entries are processed via _remove_document."""
import app.tasks.connector_indexers.dropbox_indexer as _mod
entries = [
{
".tag": "deleted",
"name": "gone.txt",
"path_lower": "/gone.txt",
"id": "id:del1",
},
{
".tag": "deleted",
"name": "also_gone.pdf",
"path_lower": "/also_gone.pdf",
"id": "id:del2",
},
]
mock_client = MagicMock()
mock_client.get_changes = AsyncMock(return_value=(entries, "new-cursor", None))
remove_calls: list[str] = []
async def _fake_remove(session, file_id, search_space_id):
remove_calls.append(file_id)
monkeypatch.setattr(_mod, "_remove_document", _fake_remove)
monkeypatch.setattr(_mod, "_download_and_index", AsyncMock(return_value=(0, 0)))
mock_task_logger = MagicMock()
mock_task_logger.log_task_progress = AsyncMock()
_indexed, _skipped, _unsupported, cursor = await _index_with_delta_sync(
mock_client,
AsyncMock(),
_CONNECTOR_ID,
_SEARCH_SPACE_ID,
_USER_ID,
"old-cursor",
mock_task_logger,
MagicMock(),
max_files=500,
enable_summary=True,
)
assert sorted(remove_calls) == ["id:del1", "id:del2"]
assert cursor == "new-cursor"
async def test_delta_sync_upserts_filtered_and_downloaded(monkeypatch):
"""E2: modified/new file entries go through skip filter then download+index."""
import app.tasks.connector_indexers.dropbox_indexer as _mod
entries = [
_make_file_dict("mod1", "modified1.txt"),
_make_file_dict("mod2", "modified2.txt"),
]
mock_client = MagicMock()
mock_client.get_changes = AsyncMock(return_value=(entries, "cursor-v2", None))
monkeypatch.setattr(
_mod, "_should_skip_file", AsyncMock(return_value=(False, None))
)
download_mock = AsyncMock(return_value=(2, 0))
monkeypatch.setattr(_mod, "_download_and_index", download_mock)
mock_task_logger = MagicMock()
mock_task_logger.log_task_progress = AsyncMock()
indexed, skipped, _unsupported, cursor = await _index_with_delta_sync(
mock_client,
AsyncMock(),
_CONNECTOR_ID,
_SEARCH_SPACE_ID,
_USER_ID,
"cursor-v1",
mock_task_logger,
MagicMock(),
max_files=500,
enable_summary=True,
)
assert indexed == 2
assert skipped == 0
assert cursor == "cursor-v2"
downloaded_files = download_mock.call_args[0][2]
assert len(downloaded_files) == 2
assert {f["id"] for f in downloaded_files} == {"mod1", "mod2"}
async def test_delta_sync_mix_deletions_and_upserts(monkeypatch):
"""E3: deletions processed, then remaining upserts filtered and indexed."""
import app.tasks.connector_indexers.dropbox_indexer as _mod
entries = [
{
".tag": "deleted",
"name": "removed.txt",
"path_lower": "/removed.txt",
"id": "id:del1",
},
{
".tag": "deleted",
"name": "trashed.pdf",
"path_lower": "/trashed.pdf",
"id": "id:del2",
},
_make_file_dict("mod1", "updated.txt"),
_make_file_dict("new1", "brandnew.docx"),
]
mock_client = MagicMock()
mock_client.get_changes = AsyncMock(return_value=(entries, "final-cursor", None))
remove_calls: list[str] = []
async def _fake_remove(session, file_id, search_space_id):
remove_calls.append(file_id)
monkeypatch.setattr(_mod, "_remove_document", _fake_remove)
monkeypatch.setattr(
_mod, "_should_skip_file", AsyncMock(return_value=(False, None))
)
download_mock = AsyncMock(return_value=(2, 0))
monkeypatch.setattr(_mod, "_download_and_index", download_mock)
mock_task_logger = MagicMock()
mock_task_logger.log_task_progress = AsyncMock()
indexed, skipped, _unsupported, cursor = await _index_with_delta_sync(
mock_client,
AsyncMock(),
_CONNECTOR_ID,
_SEARCH_SPACE_ID,
_USER_ID,
"old-cursor",
mock_task_logger,
MagicMock(),
max_files=500,
enable_summary=True,
)
assert sorted(remove_calls) == ["id:del1", "id:del2"]
assert indexed == 2
assert skipped == 0
assert cursor == "final-cursor"
downloaded_files = download_mock.call_args[0][2]
assert {f["id"] for f in downloaded_files} == {"mod1", "new1"}
async def test_delta_sync_returns_new_cursor(monkeypatch):
"""E4: the new cursor from the API response is returned."""
import app.tasks.connector_indexers.dropbox_indexer as _mod
mock_client = MagicMock()
mock_client.get_changes = AsyncMock(return_value=([], "brand-new-cursor-xyz", None))
monkeypatch.setattr(_mod, "_download_and_index", AsyncMock(return_value=(0, 0)))
mock_task_logger = MagicMock()
mock_task_logger.log_task_progress = AsyncMock()
indexed, skipped, _unsupported, cursor = await _index_with_delta_sync(
mock_client,
AsyncMock(),
_CONNECTOR_ID,
_SEARCH_SPACE_ID,
_USER_ID,
"old-cursor",
mock_task_logger,
MagicMock(),
max_files=500,
enable_summary=True,
)
assert cursor == "brand-new-cursor-xyz"
assert indexed == 0
assert skipped == 0
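Tests E1-E4 pin down the observable contract of _index_with_delta_sync (fetch changes, remove deleted entries, filter and download the rest, return the new cursor) without showing its body. The sketch below is one shape that satisfies those assertions; parameter names, the _should_skip_file argument order, and the meaning of _download_and_index's second return value are assumptions rather than the shipped implementation.

async def _index_with_delta_sync_sketch(
    client, session, connector_id, search_space_id, user_id,
    cursor, task_logger, documents_config, *, max_files, enable_summary,
):
    # Names are guesses; only the call/return shape asserted in E1-E4 is fixed.
    entries, new_cursor, _error = await client.get_changes(cursor)
    deletions = [e for e in entries if e.get(".tag") == "deleted"]
    upserts = [e for e in entries if e.get(".tag") != "deleted"]
    for entry in deletions:  # E1/E3: every deleted entry is removed by id
        await _remove_document(session, entry["id"], search_space_id)
    to_download, skipped, unsupported = [], 0, 0
    for entry in upserts:  # E2/E3: surviving entries pass through the skip filter
        skip, ext = await _should_skip_file(session, entry, search_space_id)
        if skip:
            skipped += 1
            unsupported += 1 if ext else 0
        else:
            to_download.append(entry)
    # E2: the filtered file dicts travel as the third positional argument
    indexed, _failed = await _download_and_index(
        client, session, to_download, connector_id, search_space_id,
        user_id, task_logger, enable_summary=enable_summary,
    )
    return indexed, skipped, unsupported, new_cursor  # E4: propagate the new cursor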
# ---------------------------------------------------------------------------
# F1-F3: index_dropbox_files orchestrator tests
# ---------------------------------------------------------------------------
@pytest.fixture
def orchestrator_mocks(monkeypatch):
"""Wire up mocks for index_dropbox_files orchestrator tests."""
import app.tasks.connector_indexers.dropbox_indexer as _mod
mock_connector = MagicMock()
mock_connector.config = {"_token_encrypted": False}
mock_connector.last_indexed_at = None
mock_connector.enable_summary = True
monkeypatch.setattr(
_mod,
"get_connector_by_id",
AsyncMock(return_value=mock_connector),
)
mock_task_logger = MagicMock()
mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock())
mock_task_logger.log_task_progress = AsyncMock()
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger)
)
monkeypatch.setattr(_mod, "update_connector_last_indexed", AsyncMock())
full_scan_mock = AsyncMock(return_value=(5, 2, 0))
monkeypatch.setattr(_mod, "_index_full_scan", full_scan_mock)
delta_sync_mock = AsyncMock(return_value=(3, 1, 0, "delta-cursor-new"))
monkeypatch.setattr(_mod, "_index_with_delta_sync", delta_sync_mock)
mock_client = MagicMock()
mock_client.get_latest_cursor = AsyncMock(return_value=("latest-cursor-abc", None))
monkeypatch.setattr(_mod, "DropboxClient", MagicMock(return_value=mock_client))
return {
"connector": mock_connector,
"full_scan_mock": full_scan_mock,
"delta_sync_mock": delta_sync_mock,
"mock_client": mock_client,
}
async def test_orchestrator_uses_delta_sync_when_cursor_and_last_indexed(
orchestrator_mocks,
):
"""F1: with cursor + last_indexed_at + use_delta_sync, calls delta sync."""
from datetime import UTC, datetime
connector = orchestrator_mocks["connector"]
connector.config = {
"_token_encrypted": False,
"folder_cursors": {"/docs": "saved-cursor-123"},
}
connector.last_indexed_at = datetime(2026, 1, 1, tzinfo=UTC)
mock_session = AsyncMock()
mock_session.commit = AsyncMock()
_indexed, _skipped, error, _unsupported = await index_dropbox_files(
mock_session,
_CONNECTOR_ID,
_SEARCH_SPACE_ID,
_USER_ID,
{
"folders": [{"path": "/docs", "name": "Docs"}],
"files": [],
"indexing_options": {"use_delta_sync": True},
},
)
assert error is None
orchestrator_mocks["delta_sync_mock"].assert_called_once()
orchestrator_mocks["full_scan_mock"].assert_not_called()
async def test_orchestrator_falls_back_to_full_scan_without_cursor(
orchestrator_mocks,
):
"""F2: without cursor, falls back to full scan."""
connector = orchestrator_mocks["connector"]
connector.config = {"_token_encrypted": False}
connector.last_indexed_at = None
mock_session = AsyncMock()
mock_session.commit = AsyncMock()
_indexed, _skipped, error, _unsupported = await index_dropbox_files(
mock_session,
_CONNECTOR_ID,
_SEARCH_SPACE_ID,
_USER_ID,
{
"folders": [{"path": "/docs", "name": "Docs"}],
"files": [],
"indexing_options": {"use_delta_sync": True},
},
)
assert error is None
orchestrator_mocks["full_scan_mock"].assert_called_once()
orchestrator_mocks["delta_sync_mock"].assert_not_called()
async def test_orchestrator_persists_cursor_after_sync(orchestrator_mocks):
"""F3: after sync, persists new cursor to connector config."""
connector = orchestrator_mocks["connector"]
connector.config = {"_token_encrypted": False}
connector.last_indexed_at = None
mock_session = AsyncMock()
mock_session.commit = AsyncMock()
await index_dropbox_files(
mock_session,
_CONNECTOR_ID,
_SEARCH_SPACE_ID,
_USER_ID,
{
"folders": [{"path": "/docs", "name": "Docs"}],
"files": [],
},
)
assert "folder_cursors" in connector.config
assert connector.config["folder_cursors"]["/docs"] == "latest-cursor-abc"
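F1-F3 only constrain which helper the orchestrator selects and that the resulting cursor is written back into connector.config, so index_dropbox_files plausibly reduces to the routing below. Every helper signature is inferred from how the fixture wires its mocks, not from the shipped code, and elided argument lists are marked with ... .

async def index_dropbox_files_sketch(session, connector_id, search_space_id, user_id, payload):
    connector = await get_connector_by_id(session, connector_id, user_id)  # args are a guess
    task_logger = TaskLoggingService(session)                              # args are a guess
    client = DropboxClient(session, connector_id)
    use_delta = payload.get("indexing_options", {}).get("use_delta_sync", False)
    folder_cursors = dict(connector.config.get("folder_cursors") or {})
    indexed = skipped = unsupported = 0
    for folder in payload.get("folders", []):
        path = folder["path"]
        saved_cursor = folder_cursors.get(path)
        if use_delta and saved_cursor and connector.last_indexed_at:
            # F1: saved cursor + last_indexed_at + use_delta_sync selects the delta path
            indexed, skipped, unsupported, new_cursor = await _index_with_delta_sync(
                client, session, connector_id, search_space_id, user_id,
                saved_cursor, task_logger, ...,
            )
        else:
            # F2: without a saved cursor (or last_indexed_at), fall back to a full scan
            indexed, skipped, unsupported = await _index_full_scan(...)
            new_cursor, _err = await client.get_latest_cursor(path)
        folder_cursors[path] = new_cursor  # F3: the newest cursor is kept per folder
    connector.config = {**connector.config, "folder_cursors": folder_cursors}
    await session.commit()
    return indexed, skipped, None, unsupported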


@ -366,7 +366,7 @@ async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch):
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
full_scan_mocks["batch_mock"].return_value = ([], 2, 0)
indexed, skipped = await _run_full_scan(full_scan_mocks)
indexed, skipped, _unsupported = await _run_full_scan(full_scan_mocks)
assert indexed == 3 # 1 renamed + 2 from batch
assert skipped == 1 # 1 unchanged
@ -497,7 +497,7 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
mock_task_logger = MagicMock()
mock_task_logger.log_task_progress = AsyncMock()
indexed, skipped = await _index_with_delta_sync(
indexed, skipped, _unsupported = await _index_with_delta_sync(
MagicMock(),
mock_session,
MagicMock(),
@ -589,7 +589,7 @@ async def test_selected_files_single_file_indexed(selected_files_mocks):
)
selected_files_mocks["download_and_index_mock"].return_value = (1, 0)
indexed, skipped, errors = await _run_selected(
indexed, skipped, _unsup, errors = await _run_selected(
selected_files_mocks,
[("f1", "report.pdf")],
)
@ -613,7 +613,7 @@ async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
)
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
indexed, skipped, errors = await _run_selected(
indexed, skipped, _unsup, errors = await _run_selected(
selected_files_mocks,
[("f1", "first.txt"), ("f2", "mid.txt"), ("f3", "third.txt")],
)
@ -647,7 +647,7 @@ async def test_selected_files_skip_rename_counting(selected_files_mocks):
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
indexed, skipped, errors = await _run_selected(
indexed, skipped, _unsup, errors = await _run_selected(
selected_files_mocks,
[
("s1", "unchanged.txt"),


@ -198,7 +198,7 @@ async def test_gdrive_files_within_quota_are_downloaded(gdrive_selected_mocks):
)
m["download_and_index_mock"].return_value = (3, 0)
indexed, _skipped, errors = await _run_gdrive_selected(
indexed, _skipped, _unsup, errors = await _run_gdrive_selected(
m, [("f1", "f1.xyz"), ("f2", "f2.xyz"), ("f3", "f3.xyz")]
)
@ -219,7 +219,9 @@ async def test_gdrive_files_exceeding_quota_rejected(gdrive_selected_mocks):
None,
)
indexed, _skipped, errors = await _run_gdrive_selected(m, [("big", "huge.pdf")])
indexed, _skipped, _unsup, errors = await _run_gdrive_selected(
m, [("big", "huge.pdf")]
)
assert indexed == 0
assert len(errors) == 1
@ -239,7 +241,7 @@ async def test_gdrive_quota_mix_partial_indexing(gdrive_selected_mocks):
)
m["download_and_index_mock"].return_value = (2, 0)
indexed, _skipped, errors = await _run_gdrive_selected(
indexed, _skipped, _unsup, errors = await _run_gdrive_selected(
m, [("f1", "f1.xyz"), ("f2", "f2.xyz"), ("f3", "f3.xyz")]
)
@ -299,7 +301,7 @@ async def test_gdrive_zero_quota_rejects_all(gdrive_selected_mocks):
None,
)
indexed, _skipped, errors = await _run_gdrive_selected(
indexed, _skipped, _unsup, errors = await _run_gdrive_selected(
m, [("f1", "f1.xyz"), ("f2", "f2.xyz")]
)
@ -384,7 +386,7 @@ async def test_gdrive_full_scan_skips_over_quota(gdrive_full_scan_mocks, monkeyp
m["download_mock"].return_value = ([], 0)
m["batch_mock"].return_value = ([], 2, 0)
_indexed, skipped = await _run_gdrive_full_scan(m)
_indexed, skipped, _unsup = await _run_gdrive_full_scan(m)
call_files = m["download_mock"].call_args[0][1]
assert len(call_files) == 2
@ -459,7 +461,7 @@ async def test_gdrive_delta_sync_skips_over_quota(monkeypatch):
mock_task_logger = MagicMock()
mock_task_logger.log_task_progress = AsyncMock()
_indexed, skipped = await _mod._index_with_delta_sync(
_indexed, skipped, _unsupported = await _mod._index_with_delta_sync(
MagicMock(),
session,
MagicMock(),
@ -552,7 +554,9 @@ async def test_onedrive_over_quota_rejected(onedrive_selected_mocks):
None,
)
indexed, _skipped, errors = await _run_onedrive_selected(m, [("big", "huge.pdf")])
indexed, _skipped, _unsup, errors = await _run_onedrive_selected(
m, [("big", "huge.pdf")]
)
assert indexed == 0
assert len(errors) == 1
@ -652,7 +656,7 @@ async def test_dropbox_over_quota_rejected(dropbox_selected_mocks):
None,
)
indexed, _skipped, errors = await _run_dropbox_selected(
indexed, _skipped, _unsup, errors = await _run_dropbox_selected(
m, [("/huge.pdf", "huge.pdf")]
)


@ -0,0 +1,123 @@
"""Tests for DropboxClient delta-sync methods (get_latest_cursor, get_changes)."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.connectors.dropbox.client import DropboxClient
pytestmark = pytest.mark.unit
def _make_client() -> DropboxClient:
"""Create a DropboxClient with a mocked DB session so no real DB needed."""
client = DropboxClient.__new__(DropboxClient)
client._session = MagicMock()
client._connector_id = 1
return client
# ---------- C1: get_latest_cursor ----------
async def test_get_latest_cursor_returns_cursor_string(monkeypatch):
client = _make_client()
fake_resp = MagicMock()
fake_resp.status_code = 200
fake_resp.json.return_value = {"cursor": "AAHbKxRZ9enq…"}
monkeypatch.setattr(client, "_request", AsyncMock(return_value=fake_resp))
cursor, error = await client.get_latest_cursor("/my-folder")
assert cursor == "AAHbKxRZ9enq…"
assert error is None
client._request.assert_called_once_with(
"/2/files/list_folder/get_latest_cursor",
{
"path": "/my-folder",
"recursive": False,
"include_non_downloadable_files": True,
},
)
# ---------- C2: get_changes returns entries and new cursor ----------
async def test_get_changes_returns_entries_and_cursor(monkeypatch):
client = _make_client()
fake_resp = MagicMock()
fake_resp.status_code = 200
fake_resp.json.return_value = {
"entries": [
{".tag": "file", "name": "new.txt", "id": "id:abc"},
{".tag": "deleted", "name": "old.txt"},
],
"cursor": "cursor-v2",
"has_more": False,
}
monkeypatch.setattr(client, "_request", AsyncMock(return_value=fake_resp))
entries, new_cursor, error = await client.get_changes("cursor-v1")
assert error is None
assert new_cursor == "cursor-v2"
assert len(entries) == 2
assert entries[0]["name"] == "new.txt"
assert entries[1][".tag"] == "deleted"
# ---------- C3: get_changes handles pagination ----------
async def test_get_changes_handles_pagination(monkeypatch):
client = _make_client()
page1 = MagicMock()
page1.status_code = 200
page1.json.return_value = {
"entries": [{".tag": "file", "name": "a.txt", "id": "id:a"}],
"cursor": "cursor-page2",
"has_more": True,
}
page2 = MagicMock()
page2.status_code = 200
page2.json.return_value = {
"entries": [{".tag": "file", "name": "b.txt", "id": "id:b"}],
"cursor": "cursor-final",
"has_more": False,
}
request_mock = AsyncMock(side_effect=[page1, page2])
monkeypatch.setattr(client, "_request", request_mock)
entries, new_cursor, error = await client.get_changes("cursor-v1")
assert error is None
assert new_cursor == "cursor-final"
assert len(entries) == 2
assert {e["name"] for e in entries} == {"a.txt", "b.txt"}
assert request_mock.call_count == 2
# ---------- C4: get_changes raises on 401 ----------
async def test_get_changes_returns_error_on_401(monkeypatch):
client = _make_client()
fake_resp = MagicMock()
fake_resp.status_code = 401
fake_resp.text = "Unauthorized"
monkeypatch.setattr(client, "_request", AsyncMock(return_value=fake_resp))
entries, new_cursor, error = await client.get_changes("old-cursor")
assert error is not None
assert "401" in error
assert entries == []
assert new_cursor is None
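C2-C4 imply that get_changes keeps calling the Dropbox API until has_more is false, accumulating entries, keeping the last cursor, and turning any non-200 response into an error string. A rough sketch of such a loop follows; the continue endpoint and request body are assumptions, since these tests never inspect the arguments passed to _request.

async def get_changes_sketch(self, cursor):
    entries, current = [], cursor
    while True:
        resp = await self._request("/2/files/list_folder/continue", {"cursor": current})
        if resp.status_code != 200:
            # C4: a non-200 response surfaces the status code in the error string
            return [], None, f"list_folder/continue failed with {resp.status_code}: {resp.text}"
        data = resp.json()
        entries.extend(data.get("entries", []))
        current = data.get("cursor")
        if not data.get("has_more"):
            return entries, current, None  # C2/C3: the final page's cursor is returned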


@ -0,0 +1,173 @@
"""Tests for Dropbox file type filtering (should_skip_file)."""
import pytest
from app.connectors.dropbox.file_types import should_skip_file
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Structural skips (independent of ETL service)
# ---------------------------------------------------------------------------
def test_folder_item_is_skipped():
item = {".tag": "folder", "name": "My Folder"}
skip, ext = should_skip_file(item)
assert skip is True
assert ext is None
def test_paper_file_is_not_skipped():
item = {".tag": "file", "name": "notes.paper", "is_downloadable": False}
skip, ext = should_skip_file(item)
assert skip is False
assert ext is None
def test_non_downloadable_item_is_skipped():
item = {".tag": "file", "name": "locked.gdoc", "is_downloadable": False}
skip, ext = should_skip_file(item)
assert skip is True
assert ext is None
# ---------------------------------------------------------------------------
# Extension-based skips (require ETL service context)
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"filename",
[
"archive.zip",
"backup.tar",
"data.gz",
"stuff.rar",
"pack.7z",
"program.exe",
"lib.dll",
"module.so",
"image.dmg",
"disk.iso",
"movie.mov",
"clip.avi",
"video.mkv",
"film.wmv",
"stream.flv",
"favicon.ico",
"raw.cr2",
"photo.nef",
"image.arw",
"pic.dng",
"design.psd",
"vector.ai",
"mockup.sketch",
"proto.fig",
"font.ttf",
"font.otf",
"font.woff",
"font.woff2",
"model.stl",
"scene.fbx",
"mesh.blend",
"local.db",
"data.sqlite",
"access.mdb",
],
)
def test_non_parseable_extensions_are_skipped(filename, mocker):
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
item = {".tag": "file", "name": filename}
skip, ext = should_skip_file(item)
assert skip is True, f"{filename} should be skipped"
assert ext is not None
@pytest.mark.parametrize(
"filename",
[
"report.pdf",
"document.docx",
"sheet.xlsx",
"slides.pptx",
"readme.txt",
"data.csv",
"page.html",
"notes.md",
"config.json",
"feed.xml",
],
)
def test_parseable_documents_are_not_skipped(filename, mocker):
"""Files in plaintext/direct_convert/universal document sets are never skipped."""
for service in ("DOCLING", "LLAMACLOUD", "UNSTRUCTURED"):
mocker.patch("app.config.config.ETL_SERVICE", service)
item = {".tag": "file", "name": filename}
skip, ext = should_skip_file(item)
assert skip is False, f"{filename} should NOT be skipped with {service}"
assert ext is None
@pytest.mark.parametrize(
"filename",
["photo.jpg", "image.jpeg", "screenshot.png", "scan.bmp", "page.tiff", "doc.tif"],
)
def test_universal_images_are_not_skipped(filename, mocker):
"""Images supported by all parsers are never skipped."""
for service in ("DOCLING", "LLAMACLOUD", "UNSTRUCTURED"):
mocker.patch("app.config.config.ETL_SERVICE", service)
item = {".tag": "file", "name": filename}
skip, ext = should_skip_file(item)
assert skip is False, f"{filename} should NOT be skipped with {service}"
assert ext is None
@pytest.mark.parametrize(
"filename,service,expected_skip",
[
("old.doc", "DOCLING", True),
("old.doc", "LLAMACLOUD", False),
("old.doc", "UNSTRUCTURED", False),
("legacy.xls", "DOCLING", True),
("legacy.xls", "LLAMACLOUD", False),
("legacy.xls", "UNSTRUCTURED", False),
("deck.ppt", "DOCLING", True),
("deck.ppt", "LLAMACLOUD", False),
("deck.ppt", "UNSTRUCTURED", False),
("icon.svg", "DOCLING", True),
("icon.svg", "LLAMACLOUD", False),
("anim.gif", "DOCLING", True),
("anim.gif", "LLAMACLOUD", False),
("photo.webp", "DOCLING", False),
("photo.webp", "LLAMACLOUD", False),
("photo.webp", "UNSTRUCTURED", True),
("live.heic", "DOCLING", True),
("live.heic", "UNSTRUCTURED", False),
("macro.docm", "DOCLING", True),
("macro.docm", "LLAMACLOUD", False),
("mail.eml", "DOCLING", True),
("mail.eml", "UNSTRUCTURED", False),
],
)
def test_parser_specific_extensions(filename, service, expected_skip, mocker):
mocker.patch("app.config.config.ETL_SERVICE", service)
item = {".tag": "file", "name": filename}
skip, ext = should_skip_file(item)
assert skip is expected_skip, (
f"{filename} with {service}: expected skip={expected_skip}"
)
if expected_skip:
assert ext is not None
else:
assert ext is None
def test_returns_unsupported_extension(mocker):
"""When a file is skipped due to unsupported extension, the ext string is returned."""
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
item = {".tag": "file", "name": "old.doc"}
skip, ext = should_skip_file(item)
assert skip is True
assert ext == ".doc"
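These tests fully describe should_skip_file's observable behaviour: structural skips first (folders and non-downloadable items, with Dropbox Paper exempted), then an extension check against whatever the configured ETL service can parse. One way to satisfy them, sketched under the assumption that plaintext and direct-convert extensions live in sets of their own (the set names below are placeholders, not known module globals):

def should_skip_file_sketch(item):
    from pathlib import PurePosixPath
    from app.config import config
    from app.utils.file_extensions import get_document_extensions_for_service
    if item.get(".tag") == "folder":
        return True, None
    name = item.get("name", "")
    if name.endswith(".paper"):              # Dropbox Paper docs are exported, never skipped
        return False, None
    if item.get("is_downloadable") is False:
        return True, None
    ext = PurePosixPath(name).suffix.lower()
    parseable = (
        get_document_extensions_for_service(config.ETL_SERVICE)
        | PLAINTEXT_EXTENSIONS               # .txt, .md, .json, ... (assumed set)
        | DIRECT_CONVERT_EXTENSIONS          # .csv, .html, .xml, ... (assumed set)
    )
    if ext and ext not in parseable:
        return True, ext                     # the unsupported extension is reported back
    return False, None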


@ -0,0 +1,43 @@
"""Test that Dropbox re-auth preserves folder_cursors in connector config."""
import pytest
pytestmark = pytest.mark.unit
def test_reauth_preserves_folder_cursors():
"""G1: re-authentication preserves folder_cursors alongside cursor."""
old_config = {
"access_token": "old-token-enc",
"refresh_token": "old-refresh-enc",
"cursor": "old-cursor-abc",
"folder_cursors": {"/docs": "cursor-docs-123", "/photos": "cursor-photos-456"},
"_token_encrypted": True,
"auth_expired": True,
}
new_connector_config = {
"access_token": "new-token-enc",
"refresh_token": "new-refresh-enc",
"token_type": "bearer",
"expires_in": 14400,
"expires_at": "2026-04-06T16:00:00+00:00",
"_token_encrypted": True,
}
existing_cursor = old_config.get("cursor")
existing_folder_cursors = old_config.get("folder_cursors")
merged_config = {
**new_connector_config,
"cursor": existing_cursor,
"folder_cursors": existing_folder_cursors,
"auth_expired": False,
}
assert merged_config["access_token"] == "new-token-enc"
assert merged_config["cursor"] == "old-cursor-abc"
assert merged_config["folder_cursors"] == {
"/docs": "cursor-docs-123",
"/photos": "cursor-photos-456",
}
assert merged_config["auth_expired"] is False


@ -0,0 +1,80 @@
"""Tests for Google Drive file type filtering."""
import pytest
from app.connectors.google_drive.file_types import should_skip_by_extension
pytestmark = pytest.mark.unit
@pytest.mark.parametrize(
"filename",
[
"malware.exe",
"archive.zip",
"video.mov",
"font.woff2",
"model.blend",
],
)
def test_unsupported_extensions_are_skipped_regardless_of_service(filename, mocker):
"""Truly unsupported files are skipped no matter which ETL service is configured."""
for service in ("DOCLING", "LLAMACLOUD", "UNSTRUCTURED"):
mocker.patch("app.config.config.ETL_SERVICE", service)
skip, _ext = should_skip_by_extension(filename)
assert skip is True
@pytest.mark.parametrize(
"filename",
[
"report.pdf",
"doc.docx",
"sheet.xlsx",
"slides.pptx",
"readme.txt",
"data.csv",
"photo.png",
"notes.md",
],
)
def test_universal_extensions_are_not_skipped(filename, mocker):
"""Files supported by all parsers (or handled by plaintext/direct_convert) are never skipped."""
for service in ("DOCLING", "LLAMACLOUD", "UNSTRUCTURED"):
mocker.patch("app.config.config.ETL_SERVICE", service)
skip, ext = should_skip_by_extension(filename)
assert skip is False, f"{filename} should NOT be skipped with {service}"
assert ext is None
@pytest.mark.parametrize(
"filename,service,expected_skip",
[
("macro.docm", "DOCLING", True),
("macro.docm", "LLAMACLOUD", False),
("mail.eml", "DOCLING", True),
("mail.eml", "UNSTRUCTURED", False),
("photo.gif", "DOCLING", True),
("photo.gif", "LLAMACLOUD", False),
("photo.heic", "UNSTRUCTURED", False),
("photo.heic", "DOCLING", True),
],
)
def test_parser_specific_extensions(filename, service, expected_skip, mocker):
mocker.patch("app.config.config.ETL_SERVICE", service)
skip, ext = should_skip_by_extension(filename)
assert skip is expected_skip, (
f"{filename} with {service}: expected skip={expected_skip}"
)
if expected_skip:
assert ext is not None, "unsupported extension should be returned"
else:
assert ext is None
def test_returns_unsupported_extension(mocker):
"""When a file is skipped, the unsupported extension string is returned."""
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
skip, ext = should_skip_by_extension("macro.docm")
assert skip is True
assert ext == ".docm"


@ -0,0 +1,118 @@
"""Tests for OneDrive file type filtering."""
import pytest
from app.connectors.onedrive.file_types import should_skip_file
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Structural skips (independent of ETL service)
# ---------------------------------------------------------------------------
def test_folder_is_skipped():
item = {"folder": {}, "name": "My Folder"}
skip, ext = should_skip_file(item)
assert skip is True
assert ext is None
def test_remote_item_is_skipped():
item = {"remoteItem": {}, "name": "shared.docx"}
skip, ext = should_skip_file(item)
assert skip is True
assert ext is None
def test_package_is_skipped():
item = {"package": {}, "name": "notebook"}
skip, ext = should_skip_file(item)
assert skip is True
assert ext is None
def test_onenote_is_skipped():
item = {"name": "notes", "file": {"mimeType": "application/msonenote"}}
skip, ext = should_skip_file(item)
assert skip is True
assert ext is None
# ---------------------------------------------------------------------------
# Extension-based skips (require ETL service context)
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"filename",
[
"malware.exe",
"archive.zip",
"video.mov",
"font.woff2",
"model.blend",
],
)
def test_unsupported_extensions_are_skipped(filename, mocker):
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
item = {"name": filename, "file": {"mimeType": "application/octet-stream"}}
skip, ext = should_skip_file(item)
assert skip is True, f"{filename} should be skipped"
assert ext is not None
@pytest.mark.parametrize(
"filename",
[
"report.pdf",
"doc.docx",
"sheet.xlsx",
"slides.pptx",
"readme.txt",
"data.csv",
"photo.png",
"notes.md",
],
)
def test_universal_files_are_not_skipped(filename, mocker):
for service in ("DOCLING", "LLAMACLOUD", "UNSTRUCTURED"):
mocker.patch("app.config.config.ETL_SERVICE", service)
item = {"name": filename, "file": {"mimeType": "application/octet-stream"}}
skip, ext = should_skip_file(item)
assert skip is False, f"{filename} should NOT be skipped with {service}"
assert ext is None
@pytest.mark.parametrize(
"filename,service,expected_skip",
[
("macro.docm", "DOCLING", True),
("macro.docm", "LLAMACLOUD", False),
("mail.eml", "DOCLING", True),
("mail.eml", "UNSTRUCTURED", False),
("photo.heic", "UNSTRUCTURED", False),
("photo.heic", "DOCLING", True),
],
)
def test_parser_specific_extensions(filename, service, expected_skip, mocker):
mocker.patch("app.config.config.ETL_SERVICE", service)
item = {"name": filename, "file": {"mimeType": "application/octet-stream"}}
skip, ext = should_skip_file(item)
assert skip is expected_skip, (
f"{filename} with {service}: expected skip={expected_skip}"
)
if expected_skip:
assert ext is not None
else:
assert ext is None
def test_returns_unsupported_extension(mocker):
"""When a file is skipped due to unsupported extension, the ext string is returned."""
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
item = {"name": "mail.eml", "file": {"mimeType": "application/octet-stream"}}
skip, ext = should_skip_file(item)
assert skip is True
assert ext == ".eml"


@ -0,0 +1,27 @@
"""Pre-register the etl_pipeline package to avoid circular imports during unit tests."""
import sys
import types
from pathlib import Path
_BACKEND = Path(__file__).resolve().parents[3]
def _stub_package(dotted: str, fs_dir: Path) -> None:
if dotted not in sys.modules:
mod = types.ModuleType(dotted)
mod.__path__ = [str(fs_dir)]
mod.__package__ = dotted
sys.modules[dotted] = mod
parts = dotted.split(".")
if len(parts) > 1:
parent_dotted = ".".join(parts[:-1])
parent = sys.modules.get(parent_dotted)
if parent is not None:
setattr(parent, parts[-1], sys.modules[dotted])
_stub_package("app", _BACKEND / "app")
_stub_package("app.etl_pipeline", _BACKEND / "app" / "etl_pipeline")
_stub_package("app.etl_pipeline.parsers", _BACKEND / "app" / "etl_pipeline" / "parsers")


@ -0,0 +1,461 @@
"""Tests for EtlPipelineService -- the unified ETL pipeline public interface."""
import pytest
from app.etl_pipeline.etl_document import EtlRequest
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
pytestmark = pytest.mark.unit
async def test_extract_txt_file_returns_markdown(tmp_path):
"""Tracer bullet: a .txt file is read and returned as-is in an EtlResult."""
txt_file = tmp_path / "hello.txt"
txt_file.write_text("Hello, world!", encoding="utf-8")
service = EtlPipelineService()
result = await service.extract(
EtlRequest(file_path=str(txt_file), filename="hello.txt")
)
assert result.markdown_content == "Hello, world!"
assert result.etl_service == "PLAINTEXT"
assert result.content_type == "plaintext"
async def test_extract_md_file(tmp_path):
"""A .md file is classified as PLAINTEXT and extracted."""
md_file = tmp_path / "readme.md"
md_file.write_text("# Title\n\nBody text.", encoding="utf-8")
result = await EtlPipelineService().extract(
EtlRequest(file_path=str(md_file), filename="readme.md")
)
assert result.markdown_content == "# Title\n\nBody text."
assert result.etl_service == "PLAINTEXT"
assert result.content_type == "plaintext"
async def test_extract_markdown_file(tmp_path):
"""A .markdown file is classified as PLAINTEXT and extracted."""
md_file = tmp_path / "notes.markdown"
md_file.write_text("Some notes.", encoding="utf-8")
result = await EtlPipelineService().extract(
EtlRequest(file_path=str(md_file), filename="notes.markdown")
)
assert result.markdown_content == "Some notes."
assert result.etl_service == "PLAINTEXT"
async def test_extract_python_file(tmp_path):
"""A .py source code file is classified as PLAINTEXT."""
py_file = tmp_path / "script.py"
py_file.write_text("print('hello')", encoding="utf-8")
result = await EtlPipelineService().extract(
EtlRequest(file_path=str(py_file), filename="script.py")
)
assert result.markdown_content == "print('hello')"
assert result.etl_service == "PLAINTEXT"
assert result.content_type == "plaintext"
async def test_extract_js_file(tmp_path):
"""A .js source code file is classified as PLAINTEXT."""
js_file = tmp_path / "app.js"
js_file.write_text("console.log('hi');", encoding="utf-8")
result = await EtlPipelineService().extract(
EtlRequest(file_path=str(js_file), filename="app.js")
)
assert result.markdown_content == "console.log('hi');"
assert result.etl_service == "PLAINTEXT"
async def test_extract_csv_returns_markdown_table(tmp_path):
"""A .csv file is converted to a markdown table."""
csv_file = tmp_path / "data.csv"
csv_file.write_text("name,age\nAlice,30\nBob,25\n", encoding="utf-8")
result = await EtlPipelineService().extract(
EtlRequest(file_path=str(csv_file), filename="data.csv")
)
assert "| name | age |" in result.markdown_content
assert "| Alice | 30 |" in result.markdown_content
assert result.etl_service == "DIRECT_CONVERT"
assert result.content_type == "direct_convert"
async def test_extract_tsv_returns_markdown_table(tmp_path):
"""A .tsv file is converted to a markdown table."""
tsv_file = tmp_path / "data.tsv"
tsv_file.write_text("x\ty\n1\t2\n", encoding="utf-8")
result = await EtlPipelineService().extract(
EtlRequest(file_path=str(tsv_file), filename="data.tsv")
)
assert "| x | y |" in result.markdown_content
assert result.etl_service == "DIRECT_CONVERT"
async def test_extract_html_returns_markdown(tmp_path):
"""An .html file is converted to markdown."""
html_file = tmp_path / "page.html"
html_file.write_text("<h1>Title</h1><p>Body</p>", encoding="utf-8")
result = await EtlPipelineService().extract(
EtlRequest(file_path=str(html_file), filename="page.html")
)
assert "Title" in result.markdown_content
assert "Body" in result.markdown_content
assert result.etl_service == "DIRECT_CONVERT"
async def test_extract_mp3_returns_transcription(tmp_path, mocker):
"""An .mp3 audio file is transcribed via litellm.atranscription."""
audio_file = tmp_path / "recording.mp3"
audio_file.write_bytes(b"\x00" * 100)
mocker.patch("app.config.config.STT_SERVICE", "openai/whisper-1")
mocker.patch("app.config.config.STT_SERVICE_API_KEY", "fake-key")
mocker.patch("app.config.config.STT_SERVICE_API_BASE", None)
mock_transcription = mocker.patch(
"app.etl_pipeline.parsers.audio.atranscription",
return_value={"text": "Hello from audio"},
)
result = await EtlPipelineService().extract(
EtlRequest(file_path=str(audio_file), filename="recording.mp3")
)
assert "Hello from audio" in result.markdown_content
assert result.etl_service == "AUDIO"
assert result.content_type == "audio"
mock_transcription.assert_called_once()
# ---------------------------------------------------------------------------
# Slice 7 - DOCLING document parsing
# ---------------------------------------------------------------------------
async def test_extract_pdf_with_docling(tmp_path, mocker):
"""A .pdf file with ETL_SERVICE=DOCLING returns parsed markdown."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake")
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
fake_docling = mocker.AsyncMock()
fake_docling.process_document.return_value = {"content": "# Parsed PDF"}
mocker.patch(
"app.services.docling_service.create_docling_service",
return_value=fake_docling,
)
result = await EtlPipelineService().extract(
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
)
assert result.markdown_content == "# Parsed PDF"
assert result.etl_service == "DOCLING"
assert result.content_type == "document"
# ---------------------------------------------------------------------------
# Slice 8 - UNSTRUCTURED document parsing
# ---------------------------------------------------------------------------
async def test_extract_pdf_with_unstructured(tmp_path, mocker):
"""A .pdf file with ETL_SERVICE=UNSTRUCTURED returns parsed markdown."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake")
mocker.patch("app.config.config.ETL_SERVICE", "UNSTRUCTURED")
class FakeDoc:
def __init__(self, text):
self.page_content = text
fake_loader_instance = mocker.AsyncMock()
fake_loader_instance.aload.return_value = [
FakeDoc("Page 1 content"),
FakeDoc("Page 2 content"),
]
mocker.patch(
"langchain_unstructured.UnstructuredLoader",
return_value=fake_loader_instance,
)
result = await EtlPipelineService().extract(
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
)
assert "Page 1 content" in result.markdown_content
assert "Page 2 content" in result.markdown_content
assert result.etl_service == "UNSTRUCTURED"
assert result.content_type == "document"
# ---------------------------------------------------------------------------
# Slice 9 - LLAMACLOUD document parsing
# ---------------------------------------------------------------------------
async def test_extract_pdf_with_llamacloud(tmp_path, mocker):
"""A .pdf file with ETL_SERVICE=LLAMACLOUD returns parsed markdown."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake content " * 10)
mocker.patch("app.config.config.ETL_SERVICE", "LLAMACLOUD")
mocker.patch("app.config.config.LLAMA_CLOUD_API_KEY", "fake-key", create=True)
class FakeDoc:
text = "# LlamaCloud parsed"
class FakeJobResult:
pages = []
def get_markdown_documents(self, split_by_page=True):
return [FakeDoc()]
fake_parser = mocker.AsyncMock()
fake_parser.aparse.return_value = FakeJobResult()
mocker.patch(
"llama_cloud_services.LlamaParse",
return_value=fake_parser,
)
mocker.patch(
"llama_cloud_services.parse.utils.ResultType",
mocker.MagicMock(MD="md"),
)
result = await EtlPipelineService().extract(
EtlRequest(file_path=str(pdf_file), filename="report.pdf", estimated_pages=5)
)
assert result.markdown_content == "# LlamaCloud parsed"
assert result.etl_service == "LLAMACLOUD"
assert result.content_type == "document"
# ---------------------------------------------------------------------------
# Slice 10 - unknown extension falls through to document ETL
# ---------------------------------------------------------------------------
async def test_unknown_extension_uses_document_etl(tmp_path, mocker):
"""An allowlisted document extension (.docx) routes to the document ETL path."""
docx_file = tmp_path / "doc.docx"
docx_file.write_bytes(b"PK fake docx")
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
fake_docling = mocker.AsyncMock()
fake_docling.process_document.return_value = {"content": "Docx content"}
mocker.patch(
"app.services.docling_service.create_docling_service",
return_value=fake_docling,
)
result = await EtlPipelineService().extract(
EtlRequest(file_path=str(docx_file), filename="doc.docx")
)
assert result.markdown_content == "Docx content"
assert result.content_type == "document"
# ---------------------------------------------------------------------------
# Slice 11 - EtlRequest validation
# ---------------------------------------------------------------------------
def test_etl_request_requires_filename():
"""EtlRequest rejects missing filename."""
with pytest.raises(ValueError, match="filename must not be empty"):
EtlRequest(file_path="/tmp/some.txt", filename="")
# ---------------------------------------------------------------------------
# Slice 12 - unknown ETL_SERVICE raises EtlServiceUnavailableError
# ---------------------------------------------------------------------------
async def test_unknown_etl_service_raises(tmp_path, mocker):
"""An unknown ETL_SERVICE raises EtlServiceUnavailableError."""
from app.etl_pipeline.exceptions import EtlServiceUnavailableError
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF fake")
mocker.patch("app.config.config.ETL_SERVICE", "NONEXISTENT")
with pytest.raises(EtlServiceUnavailableError, match="Unknown ETL_SERVICE"):
await EtlPipelineService().extract(
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
)
# ---------------------------------------------------------------------------
# Slice 13 - unsupported file types are rejected before reaching any parser
# ---------------------------------------------------------------------------
def test_unknown_extension_classified_as_unsupported():
"""An unknown extension defaults to UNSUPPORTED (allowlist behaviour)."""
from app.etl_pipeline.file_classifier import FileCategory, classify_file
assert classify_file("random.xyz") == FileCategory.UNSUPPORTED
@pytest.mark.parametrize(
"filename",
[
"malware.exe",
"archive.zip",
"video.mov",
"font.woff2",
"model.blend",
"data.parquet",
"package.deb",
"firmware.bin",
],
)
def test_unsupported_extensions_classified_correctly(filename):
"""Extensions not in any allowlist are classified as UNSUPPORTED."""
from app.etl_pipeline.file_classifier import FileCategory, classify_file
assert classify_file(filename) == FileCategory.UNSUPPORTED
@pytest.mark.parametrize(
"filename,expected",
[
("report.pdf", "document"),
("doc.docx", "document"),
("slides.pptx", "document"),
("sheet.xlsx", "document"),
("photo.png", "document"),
("photo.jpg", "document"),
("book.epub", "document"),
("letter.odt", "document"),
("readme.md", "plaintext"),
("data.csv", "direct_convert"),
],
)
def test_parseable_extensions_classified_correctly(filename, expected):
"""Parseable files are classified into their correct category."""
from app.etl_pipeline.file_classifier import FileCategory, classify_file
result = classify_file(filename)
assert result != FileCategory.UNSUPPORTED
assert result.value == expected
async def test_extract_unsupported_file_raises_error(tmp_path):
"""EtlPipelineService.extract() raises EtlUnsupportedFileError for .exe files."""
from app.etl_pipeline.exceptions import EtlUnsupportedFileError
exe_file = tmp_path / "program.exe"
exe_file.write_bytes(b"\x00" * 10)
with pytest.raises(EtlUnsupportedFileError, match="not supported"):
await EtlPipelineService().extract(
EtlRequest(file_path=str(exe_file), filename="program.exe")
)
async def test_extract_zip_raises_unsupported_error(tmp_path):
"""EtlPipelineService.extract() raises EtlUnsupportedFileError for .zip archives."""
from app.etl_pipeline.exceptions import EtlUnsupportedFileError
zip_file = tmp_path / "archive.zip"
zip_file.write_bytes(b"PK\x03\x04")
with pytest.raises(EtlUnsupportedFileError, match="not supported"):
await EtlPipelineService().extract(
EtlRequest(file_path=str(zip_file), filename="archive.zip")
)
# ---------------------------------------------------------------------------
# Slice 14 - should_skip_for_service (per-parser document filtering)
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"filename,etl_service,expected_skip",
[
("file.eml", "DOCLING", True),
("file.eml", "UNSTRUCTURED", False),
("file.docm", "LLAMACLOUD", False),
("file.docm", "DOCLING", True),
("file.txt", "DOCLING", False),
("file.csv", "LLAMACLOUD", False),
("file.mp3", "UNSTRUCTURED", False),
("file.exe", "LLAMACLOUD", True),
("file.pdf", "DOCLING", False),
("file.webp", "DOCLING", False),
("file.webp", "UNSTRUCTURED", True),
("file.gif", "LLAMACLOUD", False),
("file.gif", "DOCLING", True),
("file.heic", "UNSTRUCTURED", False),
("file.heic", "DOCLING", True),
("file.svg", "LLAMACLOUD", False),
("file.svg", "DOCLING", True),
("file.p7s", "UNSTRUCTURED", False),
("file.p7s", "LLAMACLOUD", True),
],
)
def test_should_skip_for_service(filename, etl_service, expected_skip):
from app.etl_pipeline.file_classifier import should_skip_for_service
assert should_skip_for_service(filename, etl_service) is expected_skip, (
f"{filename} with {etl_service}: expected skip={expected_skip}"
)
# ---------------------------------------------------------------------------
# Slice 14b - ETL pipeline rejects per-parser incompatible documents
# ---------------------------------------------------------------------------
async def test_extract_docm_with_docling_raises_unsupported(tmp_path, mocker):
"""Docling cannot parse .docm -- pipeline should reject before dispatching."""
from app.etl_pipeline.exceptions import EtlUnsupportedFileError
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
docm_file = tmp_path / "macro.docm"
docm_file.write_bytes(b"\x00" * 10)
with pytest.raises(EtlUnsupportedFileError, match="not supported by DOCLING"):
await EtlPipelineService().extract(
EtlRequest(file_path=str(docm_file), filename="macro.docm")
)
async def test_extract_eml_with_docling_raises_unsupported(tmp_path, mocker):
"""Docling cannot parse .eml -- pipeline should reject before dispatching."""
from app.etl_pipeline.exceptions import EtlUnsupportedFileError
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
eml_file = tmp_path / "mail.eml"
eml_file.write_bytes(b"From: test@example.com")
with pytest.raises(EtlUnsupportedFileError, match="not supported by DOCLING"):
await EtlPipelineService().extract(
EtlRequest(file_path=str(eml_file), filename="mail.eml")
)
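Taken together, these tests describe a classify-then-dispatch pipeline: classify_file buckets the filename, unsupported files are rejected up front, plaintext and CSV/HTML-style files are converted directly, audio is transcribed, and documents go to whichever parser ETL_SERVICE selects, after a per-service compatibility check. A sketch of that flow follows; the private helpers are placeholders, and only FileCategory.UNSUPPORTED is a member name confirmed by the tests, so the other branches compare the enum values the classifier tests assert.

async def extract_sketch(request):
    from app.config import config
    from app.etl_pipeline.exceptions import EtlServiceUnavailableError, EtlUnsupportedFileError
    from app.etl_pipeline.file_classifier import FileCategory, classify_file, should_skip_for_service
    category = classify_file(request.filename)
    if category == FileCategory.UNSUPPORTED:
        raise EtlUnsupportedFileError(f"{request.filename} is not supported")
    if category.value == "plaintext":
        return _read_plaintext(request)          # .txt/.md/.py returned verbatim
    if category.value == "direct_convert":
        return _convert_directly(request)        # .csv/.tsv/.html converted to markdown
    if category.value == "document":
        if config.ETL_SERVICE not in ("DOCLING", "UNSTRUCTURED", "LLAMACLOUD"):
            raise EtlServiceUnavailableError(f"Unknown ETL_SERVICE: {config.ETL_SERVICE}")
        if should_skip_for_service(request.filename, config.ETL_SERVICE):
            raise EtlUnsupportedFileError(
                f"{request.filename} is not supported by {config.ETL_SERVICE}"
            )
        return await _parse_document(request, config.ETL_SERVICE)
    return await _transcribe_audio(request)      # remaining category: audio (e.g. .mp3)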


@ -0,0 +1,70 @@
"""Test that DoclingService does NOT restrict allowed_formats, letting Docling
accept all its supported formats (PDF, DOCX, PPTX, XLSX, IMAGE, etc.)."""
from enum import Enum
from unittest.mock import MagicMock, patch
import pytest
pytestmark = pytest.mark.unit
class _FakeInputFormat(Enum):
PDF = "pdf"
IMAGE = "image"
DOCX = "docx"
PPTX = "pptx"
XLSX = "xlsx"
def test_docling_service_does_not_restrict_allowed_formats():
"""DoclingService should NOT pass allowed_formats to DocumentConverter,
so Docling defaults to accepting every InputFormat it supports."""
mock_converter_cls = MagicMock()
mock_backend = MagicMock()
fake_pipeline_options_cls = MagicMock()
fake_pipeline_options = MagicMock()
fake_pipeline_options_cls.return_value = fake_pipeline_options
fake_pdf_format_option_cls = MagicMock()
with patch.dict(
"sys.modules",
{
"docling": MagicMock(),
"docling.backend": MagicMock(),
"docling.backend.pypdfium2_backend": MagicMock(
PyPdfiumDocumentBackend=mock_backend
),
"docling.datamodel": MagicMock(),
"docling.datamodel.base_models": MagicMock(InputFormat=_FakeInputFormat),
"docling.datamodel.pipeline_options": MagicMock(
PdfPipelineOptions=fake_pipeline_options_cls
),
"docling.document_converter": MagicMock(
DocumentConverter=mock_converter_cls,
PdfFormatOption=fake_pdf_format_option_cls,
),
},
):
from importlib import reload
import app.services.docling_service as mod
reload(mod)
mod.DoclingService()
call_kwargs = mock_converter_cls.call_args
assert call_kwargs is not None, "DocumentConverter was never called"
_, kwargs = call_kwargs
assert "allowed_formats" not in kwargs, (
f"allowed_formats should not be passed — let Docling accept all formats. "
f"Got: {kwargs.get('allowed_formats')}"
)
assert _FakeInputFormat.PDF in kwargs.get("format_options", {}), (
"format_options should still configure PDF pipeline options"
)


@ -0,0 +1,154 @@
"""Tests for the DOCUMENT_EXTENSIONS allowlist module."""
import pytest
pytestmark = pytest.mark.unit
def test_pdf_is_supported_document():
from app.utils.file_extensions import is_supported_document_extension
assert is_supported_document_extension("report.pdf") is True
def test_exe_is_not_supported_document():
from app.utils.file_extensions import is_supported_document_extension
assert is_supported_document_extension("malware.exe") is False
@pytest.mark.parametrize(
"filename",
[
"report.pdf",
"doc.docx",
"old.doc",
"sheet.xlsx",
"legacy.xls",
"slides.pptx",
"deck.ppt",
"macro.docm",
"macro.xlsm",
"macro.pptm",
"photo.png",
"photo.jpg",
"photo.jpeg",
"scan.bmp",
"scan.tiff",
"scan.tif",
"photo.webp",
"anim.gif",
"iphone.heic",
"manual.rtf",
"book.epub",
"letter.odt",
"data.ods",
"presentation.odp",
"inbox.eml",
"outlook.msg",
"korean.hwpx",
"korean.hwp",
"template.dot",
"template.dotm",
"template.pot",
"template.potx",
"binary.xlsb",
"workspace.xlw",
"vector.svg",
"signature.p7s",
],
)
def test_document_extensions_are_supported(filename):
from app.utils.file_extensions import is_supported_document_extension
assert is_supported_document_extension(filename) is True, (
f"{filename} should be supported"
)
@pytest.mark.parametrize(
"filename",
[
"malware.exe",
"archive.zip",
"video.mov",
"font.woff2",
"model.blend",
"random.xyz",
"data.parquet",
"package.deb",
],
)
def test_non_document_extensions_are_not_supported(filename):
from app.utils.file_extensions import is_supported_document_extension
assert is_supported_document_extension(filename) is False, (
f"{filename} should NOT be supported"
)
# ---------------------------------------------------------------------------
# Per-parser extension sets
# ---------------------------------------------------------------------------
def test_union_equals_all_three_sets():
from app.utils.file_extensions import (
DOCLING_DOCUMENT_EXTENSIONS,
DOCUMENT_EXTENSIONS,
LLAMAPARSE_DOCUMENT_EXTENSIONS,
UNSTRUCTURED_DOCUMENT_EXTENSIONS,
)
expected = (
DOCLING_DOCUMENT_EXTENSIONS
| LLAMAPARSE_DOCUMENT_EXTENSIONS
| UNSTRUCTURED_DOCUMENT_EXTENSIONS
)
assert expected == DOCUMENT_EXTENSIONS
def test_get_extensions_for_docling():
from app.utils.file_extensions import get_document_extensions_for_service
exts = get_document_extensions_for_service("DOCLING")
assert ".pdf" in exts
assert ".webp" in exts
assert ".docx" in exts
assert ".eml" not in exts
assert ".docm" not in exts
assert ".gif" not in exts
assert ".heic" not in exts
def test_get_extensions_for_llamacloud():
from app.utils.file_extensions import get_document_extensions_for_service
exts = get_document_extensions_for_service("LLAMACLOUD")
assert ".docm" in exts
assert ".gif" in exts
assert ".svg" in exts
assert ".hwp" in exts
assert ".eml" not in exts
assert ".heic" not in exts
def test_get_extensions_for_unstructured():
from app.utils.file_extensions import get_document_extensions_for_service
exts = get_document_extensions_for_service("UNSTRUCTURED")
assert ".eml" in exts
assert ".heic" in exts
assert ".p7s" in exts
assert ".docm" not in exts
assert ".gif" not in exts
assert ".svg" not in exts
def test_get_extensions_for_none_returns_union():
from app.utils.file_extensions import (
DOCUMENT_EXTENSIONS,
get_document_extensions_for_service,
)
assert get_document_extensions_for_service(None) == DOCUMENT_EXTENSIONS
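The per-parser assertions above pin DOCLING, LLAMACLOUD and UNSTRUCTURED to their respective extension sets and None to the union of all three, so the lookup presumably reduces to a small dict. A sketch; the fallback for an unrecognised service string is a guess, since the tests only cover None.

def get_document_extensions_for_service_sketch(service):
    from app.utils.file_extensions import (
        DOCLING_DOCUMENT_EXTENSIONS,
        DOCUMENT_EXTENSIONS,
        LLAMAPARSE_DOCUMENT_EXTENSIONS,
        UNSTRUCTURED_DOCUMENT_EXTENSIONS,
    )
    per_service = {
        "DOCLING": DOCLING_DOCUMENT_EXTENSIONS,
        "LLAMACLOUD": LLAMAPARSE_DOCUMENT_EXTENSIONS,
        "UNSTRUCTURED": UNSTRUCTURED_DOCUMENT_EXTENSIONS,
    }
    return per_service.get(service, DOCUMENT_EXTENSIONS)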


@ -19,6 +19,9 @@ files:
- "!scripts"
- "!release"
extraResources:
- from: assets/
to: assets/
filter: ["*.ico", "*.png", "*.icns"]
- from: ../surfsense_web/.next/standalone/surfsense_web/
to: standalone/
filter:
@ -58,7 +61,7 @@ win:
icon: assets/icon.ico
target:
- target: nsis
arch: [x64, arm64]
arch: [x64]
nsis:
oneClick: false
perMachine: false


@ -4,7 +4,7 @@
"description": "SurfSense Desktop App",
"main": "dist/main.js",
"scripts": {
"dev": "concurrently -k \"pnpm --dir ../surfsense_web dev\" \"wait-on http://localhost:3000 && electron .\"",
"dev": "pnpm build && concurrently -k \"pnpm --dir ../surfsense_web dev\" \"wait-on http://localhost:3000 && electron .\"",
"build": "node scripts/build-electron.mjs",
"pack:dir": "pnpm build && electron-builder --dir --config electron-builder.yml",
"dist": "pnpm build && electron-builder --config electron-builder.yml",


@ -32,4 +32,13 @@ export const IPC_CHANNELS = {
FOLDER_SYNC_ACK_EVENTS: 'folder-sync:ack-events',
BROWSE_FILES: 'browse:files',
READ_LOCAL_FILES: 'browse:read-local-files',
// Auth token sync across windows
GET_AUTH_TOKENS: 'auth:get-tokens',
SET_AUTH_TOKENS: 'auth:set-tokens',
// Keyboard shortcut configuration
GET_SHORTCUTS: 'shortcuts:get',
SET_SHORTCUTS: 'shortcuts:set',
// Active search space
GET_ACTIVE_SEARCH_SPACE: 'search-space:get-active',
SET_ACTIVE_SEARCH_SPACE: 'search-space:set-active',
} as const;

View file

@ -20,6 +20,13 @@ import {
browseFiles,
readLocalFiles,
} from '../modules/folder-watcher';
import { getShortcuts, setShortcuts, type ShortcutConfig } from '../modules/shortcuts';
import { getActiveSearchSpaceId, setActiveSearchSpaceId } from '../modules/active-search-space';
import { reregisterQuickAsk } from '../modules/quick-ask';
import { reregisterAutocomplete } from '../modules/autocomplete';
import { reregisterGeneralAssist } from '../modules/tray';
let authTokens: { bearer: string; refresh: string } | null = null;
export function registerIpcHandlers(): void {
ipcMain.on(IPC_CHANNELS.OPEN_EXTERNAL, (_event, url: string) => {
@ -89,4 +96,28 @@ export function registerIpcHandlers(): void {
ipcMain.handle(IPC_CHANNELS.READ_LOCAL_FILES, (_event, paths: string[]) =>
readLocalFiles(paths)
);
ipcMain.handle(IPC_CHANNELS.SET_AUTH_TOKENS, (_event, tokens: { bearer: string; refresh: string }) => {
authTokens = tokens;
});
ipcMain.handle(IPC_CHANNELS.GET_AUTH_TOKENS, () => {
return authTokens;
});
ipcMain.handle(IPC_CHANNELS.GET_SHORTCUTS, () => getShortcuts());
ipcMain.handle(IPC_CHANNELS.GET_ACTIVE_SEARCH_SPACE, () => getActiveSearchSpaceId());
ipcMain.handle(IPC_CHANNELS.SET_ACTIVE_SEARCH_SPACE, (_event, id: string) =>
setActiveSearchSpaceId(id)
);
ipcMain.handle(IPC_CHANNELS.SET_SHORTCUTS, async (_event, config: Partial<ShortcutConfig>) => {
const updated = await setShortcuts(config);
if (config.generalAssist) await reregisterGeneralAssist();
if (config.quickAsk) await reregisterQuickAsk();
if (config.autocomplete) await reregisterAutocomplete();
return updated;
});
}


@ -1,7 +1,9 @@
import { app, BrowserWindow } from 'electron';
let isQuitting = false;
import { registerGlobalErrorHandlers, showErrorDialog } from './modules/errors';
import { startNextServer } from './modules/server';
import { createMainWindow } from './modules/window';
import { createMainWindow, getMainWindow } from './modules/window';
import { setupDeepLinks, handlePendingDeepLink } from './modules/deep-links';
import { setupAutoUpdater } from './modules/auto-updater';
import { setupMenu } from './modules/menu';
@ -9,6 +11,7 @@ import { registerQuickAsk, unregisterQuickAsk } from './modules/quick-ask';
import { registerAutocomplete, unregisterAutocomplete } from './modules/autocomplete';
import { registerFolderWatcher, unregisterFolderWatcher } from './modules/folder-watcher';
import { registerIpcHandlers } from './ipc/handlers';
import { createTray, destroyTray } from './modules/tray';
registerGlobalErrorHandlers();
@ -28,29 +31,48 @@ app.whenReady().then(async () => {
return;
}
createMainWindow('/dashboard');
registerQuickAsk();
registerAutocomplete();
await createTray();
const win = createMainWindow('/dashboard');
// Minimize to tray instead of closing the app
win.on('close', (e) => {
if (!isQuitting) {
e.preventDefault();
win.hide();
}
});
await registerQuickAsk();
await registerAutocomplete();
registerFolderWatcher();
setupAutoUpdater();
handlePendingDeepLink();
app.on('activate', () => {
if (BrowserWindow.getAllWindows().length === 0) {
const mw = getMainWindow();
if (!mw || mw.isDestroyed()) {
createMainWindow('/dashboard');
} else {
mw.show();
mw.focus();
}
});
});
// Keep running in the background — the tray "Quit" calls app.exit()
app.on('window-all-closed', () => {
if (process.platform !== 'darwin') {
app.quit();
}
// Do nothing: the app stays alive in the tray
});
app.on('before-quit', () => {
isQuitting = true;
});
app.on('will-quit', () => {
unregisterQuickAsk();
unregisterAutocomplete();
unregisterFolderWatcher();
destroyTray();
});


@ -0,0 +1,24 @@
const STORE_KEY = 'activeSearchSpaceId';
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let store: any = null;
async function getStore() {
if (!store) {
const { default: Store } = await import('electron-store');
store = new Store({
name: 'active-search-space',
defaults: { [STORE_KEY]: null as string | null },
});
}
return store;
}
export async function getActiveSearchSpaceId(): Promise<string | null> {
const s = await getStore();
return (s.get(STORE_KEY) as string | null) ?? null;
}
export async function setActiveSearchSpaceId(id: string): Promise<void> {
const s = await getStore();
s.set(STORE_KEY, id);
}


@ -2,16 +2,15 @@ import { clipboard, globalShortcut, ipcMain, screen } from 'electron';
import { IPC_CHANNELS } from '../../ipc/channels';
import { getFrontmostApp, getWindowTitle, hasAccessibilityPermission, simulatePaste } from '../platform';
import { hasScreenRecordingPermission, requestAccessibility, requestScreenRecording } from '../permissions';
import { getMainWindow } from '../window';
import { captureScreen } from './screenshot';
import { createSuggestionWindow, destroySuggestion, getSuggestionWindow } from './suggestion-window';
import { getShortcuts } from '../shortcuts';
import { getActiveSearchSpaceId } from '../active-search-space';
const SHORTCUT = 'CommandOrControl+Shift+Space';
let currentShortcut = '';
let autocompleteEnabled = true;
let savedClipboard = '';
let sourceApp = '';
let lastSearchSpaceId: string | null = null;
function isSurfSenseWindow(): boolean {
const app = getFrontmostApp();
@ -37,21 +36,11 @@ async function triggerAutocomplete(): Promise<void> {
return;
}
const mainWin = getMainWindow();
if (mainWin && !mainWin.isDestroyed()) {
const mainUrl = mainWin.webContents.getURL();
const match = mainUrl.match(/\/dashboard\/(\d+)/);
if (match) {
lastSearchSpaceId = match[1];
}
}
if (!lastSearchSpaceId) {
console.warn('[autocomplete] No active search space. Open a search space first.');
const searchSpaceId = await getActiveSearchSpaceId();
if (!searchSpaceId) {
console.warn('[autocomplete] No active search space. Select a search space first.');
return;
}
const searchSpaceId = lastSearchSpaceId;
const cursor = screen.getCursorScreenPoint();
const win = createSuggestionWindow(cursor.x, cursor.y);
@ -91,7 +80,12 @@ async function acceptAndInject(text: string): Promise<void> {
}
}
let ipcRegistered = false;
function registerIpcHandlers(): void {
if (ipcRegistered) return;
ipcRegistered = true;
ipcMain.handle(IPC_CHANNELS.ACCEPT_SUGGESTION, async (_event, text: string) => {
await acceptAndInject(text);
});
@ -107,26 +101,39 @@ function registerIpcHandlers(): void {
ipcMain.handle(IPC_CHANNELS.GET_AUTOCOMPLETE_ENABLED, () => autocompleteEnabled);
}
export function registerAutocomplete(): void {
registerIpcHandlers();
function autocompleteHandler(): void {
const sw = getSuggestionWindow();
if (sw && !sw.isDestroyed()) {
destroySuggestion();
return;
}
triggerAutocomplete();
}
const ok = globalShortcut.register(SHORTCUT, () => {
const sw = getSuggestionWindow();
if (sw && !sw.isDestroyed()) {
destroySuggestion();
return;
}
triggerAutocomplete();
});
async function registerShortcut(): Promise<void> {
const shortcuts = await getShortcuts();
currentShortcut = shortcuts.autocomplete;
const ok = globalShortcut.register(currentShortcut, autocompleteHandler);
if (!ok) {
console.error(`[autocomplete] Failed to register shortcut ${SHORTCUT}`);
console.error(`[autocomplete] Failed to register shortcut ${currentShortcut}`);
} else {
console.log(`[autocomplete] Registered shortcut ${SHORTCUT}`);
console.log(`[autocomplete] Registered shortcut ${currentShortcut}`);
}
}
export async function registerAutocomplete(): Promise<void> {
registerIpcHandlers();
await registerShortcut();
}
export function unregisterAutocomplete(): void {
globalShortcut.unregister(SHORTCUT);
if (currentShortcut) globalShortcut.unregister(currentShortcut);
destroySuggestion();
}
export async function reregisterAutocomplete(): Promise<void> {
unregisterAutocomplete();
await registerShortcut();
}


@ -1,16 +1,20 @@
import { execSync } from 'child_process';
import { systemPreferences } from 'electron';
const EXEC_OPTS = { windowsHide: true } as const;
export function getFrontmostApp(): string {
try {
if (process.platform === 'darwin') {
return execSync(
'osascript -e \'tell application "System Events" to get name of first application process whose frontmost is true\''
'osascript -e \'tell application "System Events" to get name of first application process whose frontmost is true\'',
EXEC_OPTS,
).toString().trim();
}
if (process.platform === 'win32') {
return execSync(
'powershell -command "Add-Type \'using System; using System.Runtime.InteropServices; public class W { [DllImport(\\\"user32.dll\\\")] public static extern IntPtr GetForegroundWindow(); }\'; (Get-Process | Where-Object { $_.MainWindowHandle -eq [W]::GetForegroundWindow() }).ProcessName"'
'powershell -NoProfile -NonInteractive -command "Add-Type \'using System; using System.Runtime.InteropServices; public class W { [DllImport(\\\"user32.dll\\\")] public static extern IntPtr GetForegroundWindow(); }\'; (Get-Process | Where-Object { $_.MainWindowHandle -eq [W]::GetForegroundWindow() }).ProcessName"',
EXEC_OPTS,
).toString().trim();
}
} catch {
@ -21,9 +25,23 @@ export function getFrontmostApp(): string {
export function simulatePaste(): void {
if (process.platform === 'darwin') {
execSync('osascript -e \'tell application "System Events" to keystroke "v" using command down\'');
execSync('osascript -e \'tell application "System Events" to keystroke "v" using command down\'', EXEC_OPTS);
} else if (process.platform === 'win32') {
execSync('powershell -command "Add-Type -AssemblyName System.Windows.Forms; [System.Windows.Forms.SendKeys]::SendWait(\'^v\')"');
execSync('powershell -NoProfile -NonInteractive -command "Add-Type -AssemblyName System.Windows.Forms; [System.Windows.Forms.SendKeys]::SendWait(\'^v\')"', EXEC_OPTS);
}
}
export function simulateCopy(): boolean {
try {
if (process.platform === 'darwin') {
execSync('osascript -e \'tell application "System Events" to keystroke "c" using command down\'', EXEC_OPTS);
} else if (process.platform === 'win32') {
execSync('powershell -NoProfile -NonInteractive -command "Add-Type -AssemblyName System.Windows.Forms; [System.Windows.Forms.SendKeys]::SendWait(\'^c\')"', EXEC_OPTS);
}
return true;
} catch (err) {
console.error('[simulateCopy] Failed:', err);
return false;
}
}
@ -36,12 +54,14 @@ export function getWindowTitle(): string {
try {
if (process.platform === 'darwin') {
return execSync(
'osascript -e \'tell application "System Events" to get title of front window of first application process whose frontmost is true\''
'osascript -e \'tell application "System Events" to get title of front window of first application process whose frontmost is true\'',
EXEC_OPTS,
).toString().trim();
}
if (process.platform === 'win32') {
return execSync(
'powershell -command "(Get-Process | Where-Object { $_.MainWindowHandle -eq (Add-Type -MemberDefinition \'[DllImport(\\\"user32.dll\\\")] public static extern IntPtr GetForegroundWindow();\' -Name W -PassThru)::GetForegroundWindow() }).MainWindowTitle"'
'powershell -NoProfile -NonInteractive -command "(Get-Process | Where-Object { $_.MainWindowHandle -eq (Add-Type -MemberDefinition \'[DllImport(\\\"user32.dll\\\")] public static extern IntPtr GetForegroundWindow();\' -Name W -PassThru)::GetForegroundWindow() }).MainWindowTitle"',
EXEC_OPTS,
).toString().trim();
}
} catch {
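Every execSync call in this file now receives the shared EXEC_OPTS object. Node's child_process option windowsHide suppresses the console window that would otherwise flash on Windows for each osascript/PowerShell invocation. A minimal standalone sketch of the same pattern (illustrative only, not part of this changeset):

import { execSync } from 'child_process';

// Shared options: hide the transient console window spawned on Windows.
const EXEC_OPTS = { windowsHide: true } as const;

export function runHidden(command: string): string {
  // Returns trimmed stdout; throws if the command exits non-zero.
  return execSync(command, EXEC_OPTS).toString().trim();
}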

View file

@ -1,13 +1,16 @@
import { BrowserWindow, clipboard, globalShortcut, ipcMain, screen, shell } from 'electron';
import path from 'path';
import { IPC_CHANNELS } from '../ipc/channels';
import { checkAccessibilityPermission, getFrontmostApp, simulatePaste } from './platform';
import { checkAccessibilityPermission, getFrontmostApp, simulateCopy, simulatePaste } from './platform';
import { getServerPort } from './server';
import { getShortcuts } from './shortcuts';
import { getActiveSearchSpaceId } from './active-search-space';
const SHORTCUT = 'CommandOrControl+Option+S';
let currentShortcut = '';
let quickAskWindow: BrowserWindow | null = null;
let pendingText = '';
let pendingMode = '';
let pendingSearchSpaceId: string | null = null;
let sourceApp = '';
let savedClipboard = '';
@ -52,7 +55,9 @@ function createQuickAskWindow(x: number, y: number): BrowserWindow {
skipTaskbar: true,
});
quickAskWindow.loadURL(`http://localhost:${getServerPort()}/dashboard`);
const spaceId = pendingSearchSpaceId;
const route = spaceId ? `/dashboard/${spaceId}/new-chat` : '/dashboard';
quickAskWindow.loadURL(`http://localhost:${getServerPort()}${route}`);
quickAskWindow.once('ready-to-show', () => {
quickAskWindow?.show();
@ -77,29 +82,53 @@ function createQuickAskWindow(x: number, y: number): BrowserWindow {
return quickAskWindow;
}
export function registerQuickAsk(): void {
const ok = globalShortcut.register(SHORTCUT, () => {
if (quickAskWindow && !quickAskWindow.isDestroyed()) {
destroyQuickAsk();
return;
}
async function openQuickAsk(text: string): Promise<void> {
pendingText = text;
pendingSearchSpaceId = await getActiveSearchSpaceId();
const cursor = screen.getCursorScreenPoint();
const pos = clampToScreen(cursor.x, cursor.y, 450, 750);
createQuickAskWindow(pos.x, pos.y);
}
sourceApp = getFrontmostApp();
savedClipboard = clipboard.readText();
async function quickAskHandler(): Promise<void> {
console.log('[quick-ask] Handler triggered');
const text = savedClipboard.trim();
if (!text) return;
pendingText = text;
const cursor = screen.getCursorScreenPoint();
const pos = clampToScreen(cursor.x, cursor.y, 450, 750);
createQuickAskWindow(pos.x, pos.y);
});
if (!ok) {
console.log(`Quick-ask: failed to register ${SHORTCUT}`);
if (quickAskWindow && !quickAskWindow.isDestroyed()) {
console.log('[quick-ask] Window already open, closing');
destroyQuickAsk();
return;
}
if (!checkAccessibilityPermission()) {
console.log('[quick-ask] Accessibility permission denied');
return;
}
savedClipboard = clipboard.readText();
console.log('[quick-ask] Saved clipboard length:', savedClipboard.length);
const copyOk = simulateCopy();
console.log('[quick-ask] simulateCopy result:', copyOk);
await new Promise((r) => setTimeout(r, 300));
const afterCopy = clipboard.readText();
const selected = afterCopy.trim();
console.log('[quick-ask] Clipboard after copy length:', afterCopy.length, 'changed:', afterCopy !== savedClipboard);
const text = selected || savedClipboard.trim();
sourceApp = getFrontmostApp();
console.log('[quick-ask] Source app:', sourceApp, '| Opening Quick Assist with', text.length, 'chars', selected ? '(selected)' : text ? '(clipboard fallback)' : '(empty)');
openQuickAsk(text);
}
let ipcRegistered = false;
function registerIpcHandlers(): void {
if (ipcRegistered) return;
ipcRegistered = true;
ipcMain.handle(IPC_CHANNELS.QUICK_ASK_TEXT, () => {
const text = pendingText;
pendingText = '';
@ -136,6 +165,24 @@ export function registerQuickAsk(): void {
});
}
export function unregisterQuickAsk(): void {
globalShortcut.unregister(SHORTCUT);
async function registerShortcut(): Promise<void> {
const shortcuts = await getShortcuts();
currentShortcut = shortcuts.quickAsk;
const ok = globalShortcut.register(currentShortcut, () => { quickAskHandler(); });
console.log(`[quick-ask] Register ${currentShortcut}: ${ok ? 'OK' : 'FAILED'}`);
}
export async function registerQuickAsk(): Promise<void> {
registerIpcHandlers();
await registerShortcut();
}
export function unregisterQuickAsk(): void {
if (currentShortcut) globalShortcut.unregister(currentShortcut);
}
export async function reregisterQuickAsk(): Promise<void> {
unregisterQuickAsk();
await registerShortcut();
}
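The new handler saves the current clipboard, simulates a copy in the frontmost app, waits briefly for the OS to update the clipboard, and falls back to the saved clipboard text when nothing was selected. A condensed sketch of just that capture sequence (illustrative; the real handler also checks permissions, logs each step, and opens the Quick Ask window):

import { clipboard } from 'electron';
import { simulateCopy } from './platform';

// Capture the user's current selection, falling back to the existing clipboard text.
async function captureSelection(): Promise<string> {
  const saved = clipboard.readText();            // preserve whatever was on the clipboard
  simulateCopy();                                // Cmd/Ctrl+C in the frontmost app
  await new Promise((r) => setTimeout(r, 300));  // give the OS time to update the clipboard
  const after = clipboard.readText().trim();
  return after || saved.trim();                  // selection if copied, otherwise old clipboard
}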

View file

@ -0,0 +1,44 @@
export interface ShortcutConfig {
generalAssist: string;
quickAsk: string;
autocomplete: string;
}
const DEFAULTS: ShortcutConfig = {
generalAssist: 'CommandOrControl+Shift+S',
quickAsk: 'CommandOrControl+Alt+S',
autocomplete: 'CommandOrControl+Shift+Space',
};
const STORE_KEY = 'shortcuts';
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- lazily imported ESM module; matches folder-watcher.ts pattern
let store: any = null;
async function getStore() {
if (!store) {
const { default: Store } = await import('electron-store');
store = new Store({
name: 'keyboard-shortcuts',
defaults: { [STORE_KEY]: DEFAULTS },
});
}
return store;
}
export async function getShortcuts(): Promise<ShortcutConfig> {
const s = await getStore();
const stored = s.get(STORE_KEY) as Partial<ShortcutConfig> | undefined;
return { ...DEFAULTS, ...stored };
}
export async function setShortcuts(config: Partial<ShortcutConfig>): Promise<ShortcutConfig> {
const s = await getStore();
const current = (s.get(STORE_KEY) as ShortcutConfig) ?? DEFAULTS;
const merged = { ...current, ...config };
s.set(STORE_KEY, merged);
return merged;
}
export function getDefaults(): ShortcutConfig {
return { ...DEFAULTS };
}
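Because electron-store is imported lazily as an ESM module, both accessors are async and partial updates are merged over the stored config. A usage sketch from the main process (illustrative only):

import { getShortcuts, setShortcuts, getDefaults } from './shortcuts';

async function example(): Promise<void> {
  const current = await getShortcuts();          // merged with DEFAULTS for any missing keys
  console.log('Quick Assist shortcut:', current.quickAsk);

  // Partial updates are merged into the stored config and the merged result is returned.
  const updated = await setShortcuts({ quickAsk: 'CommandOrControl+Shift+Q' });
  console.log('New config:', updated, 'defaults:', getDefaults());
}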

View file

@ -0,0 +1,77 @@
import { app, globalShortcut, Menu, nativeImage, Tray } from 'electron';
import path from 'path';
import { getMainWindow, createMainWindow } from './window';
import { getShortcuts } from './shortcuts';
let tray: Tray | null = null;
let currentShortcut: string | null = null;
function getTrayIcon(): Electron.NativeImage {
const iconName = process.platform === 'win32' ? 'icon.ico' : 'icon.png';
const iconPath = app.isPackaged
? path.join(process.resourcesPath, 'assets', iconName)
: path.join(__dirname, '..', 'assets', iconName);
const img = nativeImage.createFromPath(iconPath);
return img.resize({ width: 16, height: 16 });
}
function showMainWindow(): void {
let win = getMainWindow();
if (!win || win.isDestroyed()) {
win = createMainWindow('/dashboard');
} else {
win.show();
win.focus();
}
}
function registerShortcut(accelerator: string): void {
if (currentShortcut) {
globalShortcut.unregister(currentShortcut);
currentShortcut = null;
}
if (!accelerator) return;
try {
const ok = globalShortcut.register(accelerator, showMainWindow);
if (ok) {
currentShortcut = accelerator;
} else {
console.warn(`[tray] Failed to register General Assist shortcut: ${accelerator}`);
}
} catch (err) {
console.error(`[tray] Error registering General Assist shortcut:`, err);
}
}
export async function createTray(): Promise<void> {
if (tray) return;
tray = new Tray(getTrayIcon());
tray.setToolTip('SurfSense');
const contextMenu = Menu.buildFromTemplate([
{ label: 'Open SurfSense', click: showMainWindow },
{ type: 'separator' },
{ label: 'Quit', click: () => { app.exit(0); } },
]);
tray.setContextMenu(contextMenu);
tray.on('double-click', showMainWindow);
const shortcuts = await getShortcuts();
registerShortcut(shortcuts.generalAssist);
}
export async function reregisterGeneralAssist(): Promise<void> {
const shortcuts = await getShortcuts();
registerShortcut(shortcuts.generalAssist);
}
export function destroyTray(): void {
if (currentShortcut) {
globalShortcut.unregister(currentShortcut);
currentShortcut = null;
}
tray?.destroy();
tray = null;
}
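The SET_SHORTCUTS IPC handler itself is not shown in this diff. A plausible wiring, assuming it lives in the main process, persists the partial config, and then rebinds all three accelerators via the reregister helpers introduced above (module paths are assumptions):

import { ipcMain } from 'electron';
import { IPC_CHANNELS } from '../ipc/channels';
import { setShortcuts, type ShortcutConfig } from './shortcuts';
import { reregisterGeneralAssist } from './tray';
import { reregisterQuickAsk } from './quick-ask';
import { reregisterAutocomplete } from './autocomplete';

// Hypothetical handler: persist the partial config, then rebind the global shortcuts.
ipcMain.handle(IPC_CHANNELS.SET_SHORTCUTS, async (_event, config: Partial<ShortcutConfig>) => {
  const merged = await setShortcuts(config);
  await Promise.all([reregisterGeneralAssist(), reregisterQuickAsk(), reregisterAutocomplete()]);
  return merged;
});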

View file

@ -2,6 +2,7 @@ import { app, BrowserWindow, shell, session } from 'electron';
import path from 'path';
import { showErrorDialog } from './errors';
import { getServerPort } from './server';
import { setActiveSearchSpaceId } from './active-search-space';
const isDev = !app.isPackaged;
const HOSTED_FRONTEND_URL = process.env.HOSTED_FRONTEND_URL as string;
@ -55,6 +56,16 @@ export function createMainWindow(initialPath = '/dashboard'): BrowserWindow {
showErrorDialog('Page failed to load', new Error(`${errorDescription} (${errorCode})\n${validatedURL}`));
});
// Auto-sync active search space from URL navigation
const syncSearchSpace = (url: string) => {
const match = url.match(/\/dashboard\/(\d+)/);
if (match) {
setActiveSearchSpaceId(match[1]);
}
};
mainWindow.webContents.on('did-navigate', (_event, url) => syncSearchSpace(url));
mainWindow.webContents.on('did-navigate-in-page', (_event, url) => syncSearchSpace(url));
if (isDev) {
mainWindow.webContents.openDevTools();
}
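The sync only fires for numeric space IDs, since the regex requires digits after /dashboard/. For illustration:

const syncPattern = /\/dashboard\/(\d+)/;
console.log('http://localhost:3000/dashboard/42/documents'.match(syncPattern)?.[1]); // "42"
console.log('http://localhost:3000/dashboard'.match(syncPattern));                   // null (no sync)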

View file

@ -68,4 +68,19 @@ contextBridge.exposeInMainWorld('electronAPI', {
// Browse files via native dialog
browseFiles: () => ipcRenderer.invoke(IPC_CHANNELS.BROWSE_FILES),
readLocalFiles: (paths: string[]) => ipcRenderer.invoke(IPC_CHANNELS.READ_LOCAL_FILES, paths),
// Auth token sync across windows
getAuthTokens: () => ipcRenderer.invoke(IPC_CHANNELS.GET_AUTH_TOKENS),
setAuthTokens: (bearer: string, refresh: string) =>
ipcRenderer.invoke(IPC_CHANNELS.SET_AUTH_TOKENS, { bearer, refresh }),
// Keyboard shortcut configuration
getShortcuts: () => ipcRenderer.invoke(IPC_CHANNELS.GET_SHORTCUTS),
setShortcuts: (config: Record<string, string>) =>
ipcRenderer.invoke(IPC_CHANNELS.SET_SHORTCUTS, config),
// Active search space
getActiveSearchSpace: () => ipcRenderer.invoke(IPC_CHANNELS.GET_ACTIVE_SEARCH_SPACE),
setActiveSearchSpace: (id: string) =>
ipcRenderer.invoke(IPC_CHANNELS.SET_ACTIVE_SEARCH_SPACE, id),
});
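From the renderer these map to two awaited IPC calls. A sketch of typical usage, assuming the app's global ElectronAPI typings are in scope and getActiveSearchSpace resolves to the stored id string or null (illustrative only):

// Renderer-side sketch: read the stored space, set a default if none is stored yet.
async function ensureDefaultSpace(defaultId: string): Promise<string> {
  const stored = await window.electronAPI?.getActiveSearchSpace?.();
  if (stored) return stored;
  await window.electronAPI?.setActiveSearchSpace?.(defaultId);
  return defaultId;
}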

View file

@ -19,6 +19,7 @@ import { OnboardingTour } from "@/components/onboarding-tour";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { useFolderSync } from "@/hooks/use-folder-sync";
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
import { useElectronAPI } from "@/hooks/use-platform";
export function DashboardClientLayout({
children,
@ -139,6 +140,8 @@ export function DashboardClientLayout({
refetchPreferences,
]);
const electronAPI = useElectronAPI();
useEffect(() => {
const activeSeacrhSpaceId =
typeof search_space_id === "string"
@ -148,7 +151,16 @@ export function DashboardClientLayout({
: "";
if (!activeSeacrhSpaceId) return;
setActiveSearchSpaceIdState(activeSeacrhSpaceId);
}, [search_space_id, setActiveSearchSpaceIdState]);
// Sync to Electron store if stored value is null (first navigation)
if (electronAPI?.setActiveSearchSpace) {
electronAPI.getActiveSearchSpace?.().then((stored) => {
if (!stored) {
electronAPI.setActiveSearchSpace!(activeSeacrhSpaceId);
}
}).catch(() => {});
}
}, [search_space_id, setActiveSearchSpaceIdState, electronAPI]);
// Determine if we should show loading
const shouldShowLoading =

View file

@ -8,6 +8,7 @@ import { Button } from "@/components/ui/button";
import { Checkbox } from "@/components/ui/checkbox";
import { Input } from "@/components/ui/input";
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
import { ToggleGroup, ToggleGroupItem } from "@/components/ui/toggle-group";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import type { DocumentTypeEnum } from "@/contracts/types/document.types";
import { getDocumentTypeIcon, getDocumentTypeLabel } from "./DocumentTypeIcon";
@ -63,109 +64,113 @@ export function DocumentsFilters({
return (
<div className="flex select-none">
<div className="flex items-center gap-2 w-full">
{/* Type Filter */}
<Popover>
<PopoverTrigger asChild>
<Button
variant="outline"
size="icon"
className="h-9 w-9 shrink-0 border-dashed border-sidebar-border text-sidebar-foreground/60 hover:text-sidebar-foreground hover:border-sidebar-border bg-sidebar"
>
<ListFilter size={14} />
{activeTypes.length > 0 && (
<span className="absolute -top-1 -right-1 flex h-4 w-4 items-center justify-center rounded-full bg-primary text-[9px] font-medium text-primary-foreground">
{activeTypes.length}
</span>
)}
</Button>
</PopoverTrigger>
<PopoverContent className="w-56 md:w-52 !p-0 overflow-hidden" align="end">
<div>
{/* Search input */}
<div className="p-2">
<div className="relative">
<Search className="absolute left-0.5 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground" />
<Input
placeholder="Search types"
value={typeSearchQuery}
onChange={(e) => setTypeSearchQuery(e.target.value)}
className="h-6 pl-6 text-sm bg-transparent border-0 shadow-none"
/>
</div>
</div>
{/* Filter + New Folder Toggle Group */}
<ToggleGroup type="multiple" variant="outline" value={[]} className="overflow-visible">
{onCreateFolder && (
<Tooltip>
<TooltipTrigger asChild>
<ToggleGroupItem
value="folder"
className="h-9 w-9 shrink-0 border-sidebar-border text-sidebar-foreground/60 hover:text-sidebar-foreground hover:border-sidebar-border bg-sidebar"
onClick={(e) => {
e.preventDefault();
onCreateFolder();
}}
>
<FolderPlus size={14} />
</ToggleGroupItem>
</TooltipTrigger>
<TooltipContent>New folder</TooltipContent>
</Tooltip>
)}
<div
className="max-h-[300px] overflow-y-auto overflow-x-hidden py-1.5 px-1.5"
onScroll={handleScroll}
style={{
maskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`,
WebkitMaskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`,
}}
>
{filteredTypes.length === 0 ? (
<div className="py-6 text-center text-sm text-muted-foreground">
No types found
</div>
) : (
filteredTypes.map((value: DocumentTypeEnum, i) => (
<div
role="option"
aria-selected={activeTypes.includes(value)}
tabIndex={0}
key={value}
className="flex w-full items-center gap-2.5 py-2 px-3 rounded-md hover:bg-neutral-200 dark:hover:bg-neutral-700 transition-colors cursor-pointer text-left"
onClick={() => onToggleType(value, !activeTypes.includes(value))}
onKeyDown={(e) => {
if (e.key === "Enter" || e.key === " ") {
e.preventDefault();
onToggleType(value, !activeTypes.includes(value));
}
}}
>
{/* Icon */}
<div className="flex h-7 w-7 shrink-0 items-center justify-center rounded-md bg-muted/50 text-foreground/80">
{getDocumentTypeIcon(value, "h-4 w-4")}
</div>
{/* Text content */}
<div className="flex flex-col min-w-0 flex-1 gap-0.5">
<span className="text-[13px] font-medium text-foreground truncate leading-tight">
{getDocumentTypeLabel(value)}
</span>
<span className="text-[11px] text-muted-foreground leading-tight">
{typeCounts.get(value)} document
{(typeCounts.get(value) ?? 0) !== 1 ? "s" : ""}
</span>
</div>
{/* Checkbox */}
<Checkbox
id={`${id}-${i}`}
checked={activeTypes.includes(value)}
onCheckedChange={(checked: boolean) => onToggleType(value, !!checked)}
className="h-4 w-4 shrink-0 rounded border-muted-foreground/30 data-[state=checked]:bg-primary data-[state=checked]:border-primary"
/>
</div>
))
)}
</div>
{activeTypes.length > 0 && (
<div className="px-3 pt-1.5 pb-1.5 border-t border-border dark:border-neutral-700">
<Button
variant="ghost"
size="sm"
className="w-full h-7 text-[11px] text-muted-foreground hover:text-foreground hover:bg-neutral-200 dark:hover:bg-neutral-700"
onClick={() => {
activeTypes.forEach((t) => {
onToggleType(t, false);
});
}}
<Popover>
<Tooltip>
<TooltipTrigger asChild>
<PopoverTrigger asChild>
<ToggleGroupItem
value="filter"
className="relative h-9 w-9 shrink-0 border-sidebar-border text-sidebar-foreground/60 hover:text-sidebar-foreground hover:border-sidebar-border bg-sidebar overflow-visible"
>
Clear filters
</Button>
<ListFilter size={14} />
{activeTypes.length > 0 && (
<span className="absolute -top-1 -right-1 flex h-4 w-4 items-center justify-center rounded-full bg-sidebar-border text-[9px] font-medium text-sidebar-foreground">
{activeTypes.length}
</span>
)}
</ToggleGroupItem>
</PopoverTrigger>
</TooltipTrigger>
<TooltipContent>Filter by type</TooltipContent>
</Tooltip>
<PopoverContent className="w-56 md:w-52 !p-0 overflow-hidden" align="start">
<div>
<div className="p-2">
<div className="relative">
<Search className="absolute left-0.5 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground" />
<Input
placeholder="Search types"
value={typeSearchQuery}
onChange={(e) => setTypeSearchQuery(e.target.value)}
className="h-6 pl-6 text-sm bg-transparent border-0 shadow-none"
/>
</div>
</div>
)}
</div>
</PopoverContent>
</Popover>
<div
className="max-h-[300px] overflow-y-auto overflow-x-hidden py-1.5 px-1.5"
onScroll={handleScroll}
style={{
maskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`,
WebkitMaskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`,
}}
>
{filteredTypes.length === 0 ? (
<div className="py-6 text-center text-sm text-muted-foreground">
No types found
</div>
) : (
filteredTypes.map((value: DocumentTypeEnum, i) => (
<div
role="option"
aria-selected={activeTypes.includes(value)}
tabIndex={0}
key={value}
className="flex w-full items-center gap-2.5 py-2 px-3 rounded-md hover:bg-neutral-200 dark:hover:bg-neutral-700 transition-colors cursor-pointer text-left"
onClick={() => onToggleType(value, !activeTypes.includes(value))}
onKeyDown={(e) => {
if (e.key === "Enter" || e.key === " ") {
e.preventDefault();
onToggleType(value, !activeTypes.includes(value));
}
}}
>
<div className="flex h-7 w-7 shrink-0 items-center justify-center rounded-md bg-muted/50 text-foreground/80">
{getDocumentTypeIcon(value, "h-4 w-4")}
</div>
<div className="flex flex-col min-w-0 flex-1 gap-0.5">
<span className="text-[13px] font-medium text-foreground truncate leading-tight">
{getDocumentTypeLabel(value)}
</span>
<span className="text-[11px] text-muted-foreground leading-tight">
{typeCounts.get(value)} document
{(typeCounts.get(value) ?? 0) !== 1 ? "s" : ""}
</span>
</div>
<Checkbox
id={`${id}-${i}`}
checked={activeTypes.includes(value)}
onCheckedChange={(checked: boolean) => onToggleType(value, !!checked)}
className="h-4 w-4 shrink-0 rounded border-muted-foreground/30 data-[state=checked]:bg-primary data-[state=checked]:border-primary"
/>
</div>
))
)}
</div>
</div>
</PopoverContent>
</Popover>
</ToggleGroup>
{/* Search Input */}
<div className="relative flex-1 min-w-0">
@ -197,23 +202,6 @@ export function DocumentsFilters({
)}
</div>
{/* New Folder Button */}
{onCreateFolder && (
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="outline"
size="icon"
className="h-9 w-9 shrink-0 border-dashed border-sidebar-border text-sidebar-foreground/60 hover:text-sidebar-foreground hover:border-sidebar-border bg-sidebar"
onClick={onCreateFolder}
>
<FolderPlus size={14} />
</Button>
</TooltipTrigger>
<TooltipContent>New folder</TooltipContent>
</Tooltip>
)}
{/* Upload Button */}
<Button
data-joyride="upload-button"

View file

@ -1,30 +1,65 @@
"use client";
import { BrainCog, Rocket, Zap } from "lucide-react";
import { useEffect, useState } from "react";
import { toast } from "sonner";
import { DEFAULT_SHORTCUTS, ShortcutRecorder } from "@/components/desktop/shortcut-recorder";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { Label } from "@/components/ui/label";
import { Switch } from "@/components/ui/switch";
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select";
import { Spinner } from "@/components/ui/spinner";
import { Switch } from "@/components/ui/switch";
import { useElectronAPI } from "@/hooks/use-platform";
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
import type { SearchSpace } from "@/contracts/types/search-space.types";
export function DesktopContent() {
const [isElectron, setIsElectron] = useState(false);
const api = useElectronAPI();
const [loading, setLoading] = useState(true);
const [enabled, setEnabled] = useState(true);
const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS);
const [shortcutsLoaded, setShortcutsLoaded] = useState(false);
const [searchSpaces, setSearchSpaces] = useState<SearchSpace[]>([]);
const [activeSpaceId, setActiveSpaceId] = useState<string | null>(null);
useEffect(() => {
if (!window.electronAPI) {
if (!api) {
setLoading(false);
setShortcutsLoaded(true);
return;
}
setIsElectron(true);
window.electronAPI.getAutocompleteEnabled().then((val) => {
setEnabled(val);
setLoading(false);
});
}, []);
let mounted = true;
if (!isElectron) {
Promise.all([
api.getAutocompleteEnabled(),
api.getShortcuts?.() ?? Promise.resolve(null),
api.getActiveSearchSpace?.() ?? Promise.resolve(null),
searchSpacesApiService.getSearchSpaces(),
])
.then(([autoEnabled, config, spaceId, spaces]) => {
if (!mounted) return;
setEnabled(autoEnabled);
if (config) setShortcuts(config);
setActiveSpaceId(spaceId);
if (spaces) setSearchSpaces(spaces);
setLoading(false);
setShortcutsLoaded(true);
})
.catch(() => {
if (!mounted) return;
setLoading(false);
setShortcutsLoaded(true);
});
return () => {
mounted = false;
};
}, [api]);
if (!api) {
return (
<div className="flex flex-col items-center justify-center py-12 text-center">
<p className="text-sm text-muted-foreground">
@ -44,14 +79,114 @@ export function DesktopContent() {
const handleToggle = async (checked: boolean) => {
setEnabled(checked);
await window.electronAPI!.setAutocompleteEnabled(checked);
await api.setAutocompleteEnabled(checked);
};
const updateShortcut = (key: "generalAssist" | "quickAsk" | "autocomplete", accelerator: string) => {
setShortcuts((prev) => {
const updated = { ...prev, [key]: accelerator };
api.setShortcuts?.({ [key]: accelerator }).catch(() => {
toast.error("Failed to update shortcut");
});
return updated;
});
toast.success("Shortcut updated");
};
const resetShortcut = (key: "generalAssist" | "quickAsk" | "autocomplete") => {
updateShortcut(key, DEFAULT_SHORTCUTS[key]);
};
const handleSearchSpaceChange = (value: string) => {
setActiveSpaceId(value);
api.setActiveSearchSpace?.(value);
toast.success("Default search space updated");
};
return (
<div className="space-y-4 md:space-y-6">
{/* Default Search Space */}
<Card>
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
<CardTitle className="text-base md:text-lg">Autocomplete</CardTitle>
<CardTitle className="text-base md:text-lg">Default Search Space</CardTitle>
<CardDescription className="text-xs md:text-sm">
Choose which search space General Assist, Quick Assist, and Extreme Assist operate against.
</CardDescription>
</CardHeader>
<CardContent className="px-3 md:px-6 pb-3 md:pb-6">
{searchSpaces.length > 0 ? (
<Select value={activeSpaceId ?? undefined} onValueChange={handleSearchSpaceChange}>
<SelectTrigger className="w-full">
<SelectValue placeholder="Select a search space" />
</SelectTrigger>
<SelectContent>
{searchSpaces.map((space) => (
<SelectItem key={space.id} value={String(space.id)}>
{space.name}
</SelectItem>
))}
</SelectContent>
</Select>
) : (
<p className="text-sm text-muted-foreground">No search spaces found. Create one first.</p>
)}
</CardContent>
</Card>
{/* Keyboard Shortcuts */}
<Card>
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
<CardTitle className="text-base md:text-lg">Keyboard Shortcuts</CardTitle>
<CardDescription className="text-xs md:text-sm">
Customize the global keyboard shortcuts for desktop features.
</CardDescription>
</CardHeader>
<CardContent className="px-3 md:px-6 pb-3 md:pb-6">
{shortcutsLoaded ? (
<div className="flex flex-col gap-3">
<ShortcutRecorder
value={shortcuts.generalAssist}
onChange={(accel) => updateShortcut("generalAssist", accel)}
onReset={() => resetShortcut("generalAssist")}
defaultValue={DEFAULT_SHORTCUTS.generalAssist}
label="General Assist"
description="Launch SurfSense instantly from any application"
icon={Rocket}
/>
<ShortcutRecorder
value={shortcuts.quickAsk}
onChange={(accel) => updateShortcut("quickAsk", accel)}
onReset={() => resetShortcut("quickAsk")}
defaultValue={DEFAULT_SHORTCUTS.quickAsk}
label="Quick Assist"
description="Select text anywhere, then ask AI to explain, rewrite, or act on it"
icon={Zap}
/>
<ShortcutRecorder
value={shortcuts.autocomplete}
onChange={(accel) => updateShortcut("autocomplete", accel)}
onReset={() => resetShortcut("autocomplete")}
defaultValue={DEFAULT_SHORTCUTS.autocomplete}
label="Extreme Assist"
description="AI drafts text using your screen context and knowledge base"
icon={BrainCog}
/>
<p className="text-[11px] text-muted-foreground">
Click a shortcut and press a new key combination to change it.
</p>
</div>
) : (
<div className="flex justify-center py-4">
<Spinner size="sm" />
</div>
)}
</CardContent>
</Card>
{/* Extreme Assist Toggle */}
<Card>
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
<CardTitle className="text-base md:text-lg">Extreme Assist</CardTitle>
<CardDescription className="text-xs md:text-sm">
Get inline writing suggestions powered by your knowledge base as you type in any app.
</CardDescription>
@ -60,17 +195,13 @@ export function DesktopContent() {
<div className="flex items-center justify-between rounded-lg border p-4">
<div className="space-y-0.5">
<Label htmlFor="autocomplete-toggle" className="text-sm font-medium cursor-pointer">
Enable autocomplete
Enable Extreme Assist
</Label>
<p className="text-xs text-muted-foreground">
Show suggestions while typing in other applications.
</p>
</div>
<Switch
id="autocomplete-toggle"
checked={enabled}
onCheckedChange={handleToggle}
/>
<Switch id="autocomplete-toggle" checked={enabled} onCheckedChange={handleToggle} />
</div>
</CardContent>
</Card>

View file

@ -3,7 +3,7 @@
import { useEffect, useState } from "react";
import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms";
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
import { getBearerToken, redirectToLogin } from "@/lib/auth-utils";
import { ensureTokensFromElectron, getBearerToken, redirectToLogin } from "@/lib/auth-utils";
import { queryClient } from "@/lib/query-client/client";
interface DashboardLayoutProps {
@ -17,15 +17,20 @@ export default function DashboardLayout({ children }: DashboardLayoutProps) {
useGlobalLoadingEffect(isCheckingAuth);
useEffect(() => {
// Check if user is authenticated
const token = getBearerToken();
if (!token) {
// Save current path and redirect to login
redirectToLogin();
return;
async function checkAuth() {
let token = getBearerToken();
if (!token) {
const synced = await ensureTokensFromElectron();
if (synced) token = getBearerToken();
}
if (!token) {
redirectToLogin();
return;
}
queryClient.invalidateQueries({ queryKey: [...USER_QUERY_KEY] });
setIsCheckingAuth(false);
}
queryClient.invalidateQueries({ queryKey: [...USER_QUERY_KEY] });
setIsCheckingAuth(false);
checkAuth();
}, []);
// Return null while loading - the global provider handles the loading UI
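ensureTokensFromElectron is not shown in this diff; presumably it pulls tokens from the main process over the GET_AUTH_TOKENS bridge exposed in the preload script and writes them into local storage so the next getBearerToken call succeeds. A sketch under that assumption (the return shape of getAuthTokens is also assumed):

import { setBearerToken, setRefreshToken } from "@/lib/auth-utils";

// Hypothetical helper: copy auth tokens from the Electron main process into this window.
export async function ensureTokensFromElectron(): Promise<boolean> {
  const tokens = await window.electronAPI?.getAuthTokens?.(); // assumed shape: { bearer, refresh }
  if (!tokens?.bearer) return false;
  setBearerToken(tokens.bearer);
  if (tokens.refresh) setRefreshToken(tokens.refresh);
  return true;
}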

View file

@ -0,0 +1,284 @@
"use client";
import { IconBrandGoogleFilled } from "@tabler/icons-react";
import { useAtom } from "jotai";
import { BrainCog, Eye, EyeOff, Rocket, Zap } from "lucide-react";
import Image from "next/image";
import { useRouter } from "next/navigation";
import { useCallback, useEffect, useState } from "react";
import { toast } from "sonner";
import { loginMutationAtom } from "@/atoms/auth/auth-mutation.atoms";
import { DEFAULT_SHORTCUTS, ShortcutRecorder } from "@/components/desktop/shortcut-recorder";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { Separator } from "@/components/ui/separator";
import { Spinner } from "@/components/ui/spinner";
import { useElectronAPI } from "@/hooks/use-platform";
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
import { setBearerToken } from "@/lib/auth-utils";
import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config";
const isGoogleAuth = AUTH_TYPE === "GOOGLE";
export default function DesktopLoginPage() {
const router = useRouter();
const api = useElectronAPI();
const [{ mutateAsync: login, isPending: isLoggingIn }] = useAtom(loginMutationAtom);
const [email, setEmail] = useState("");
const [password, setPassword] = useState("");
const [showPassword, setShowPassword] = useState(false);
const [loginError, setLoginError] = useState<string | null>(null);
const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS);
const [shortcutsLoaded, setShortcutsLoaded] = useState(false);
useEffect(() => {
if (!api?.getShortcuts) {
setShortcutsLoaded(true);
return;
}
api
.getShortcuts()
.then((config) => {
if (config) setShortcuts(config);
setShortcutsLoaded(true);
})
.catch(() => setShortcutsLoaded(true));
}, [api]);
const updateShortcut = useCallback(
(key: "generalAssist" | "quickAsk" | "autocomplete", accelerator: string) => {
setShortcuts((prev) => {
const updated = { ...prev, [key]: accelerator };
api?.setShortcuts?.({ [key]: accelerator }).catch(() => {
toast.error("Failed to update shortcut");
});
return updated;
});
toast.success("Shortcut updated");
},
[api]
);
const resetShortcut = useCallback(
(key: "generalAssist" | "quickAsk" | "autocomplete") => {
updateShortcut(key, DEFAULT_SHORTCUTS[key]);
},
[updateShortcut]
);
const handleGoogleLogin = () => {
window.location.href = `${BACKEND_URL}/auth/google/authorize-redirect`;
};
const autoSetSearchSpace = async () => {
try {
const stored = await api?.getActiveSearchSpace?.();
if (stored) return;
const spaces = await searchSpacesApiService.getSearchSpaces();
if (spaces?.length) {
await api?.setActiveSearchSpace?.(String(spaces[0].id));
}
} catch {
// non-critical — dashboard-sync will catch it later
}
};
const handleLocalLogin = async (e: React.FormEvent) => {
e.preventDefault();
setLoginError(null);
try {
const data = await login({
username: email,
password,
grant_type: "password",
});
if (typeof window !== "undefined") {
sessionStorage.setItem("login_success_tracked", "true");
}
setBearerToken(data.access_token);
await autoSetSearchSpace();
setTimeout(() => {
router.push(`/auth/callback?token=${data.access_token}`);
}, 300);
} catch (err) {
if (err instanceof Error) {
setLoginError(err.message);
} else {
setLoginError("Login failed. Please check your credentials.");
}
}
};
return (
<div className="relative flex min-h-svh items-center justify-center bg-background p-4 sm:p-6">
{/* Subtle radial glow */}
<div className="pointer-events-none fixed inset-0 overflow-hidden">
<div
className="absolute -top-1/2 left-1/2 size-[800px] -translate-x-1/2 rounded-full opacity-[0.03]"
style={{
background: "radial-gradient(circle, hsl(var(--primary)) 0%, transparent 70%)",
}}
/>
</div>
<div className="relative flex w-full max-w-md flex-col overflow-hidden rounded-xl border bg-card shadow-lg">
{/* Header */}
<div className="flex flex-col items-center px-6 pt-6 pb-2 text-center">
<Image
src="/icon-128.svg"
className="select-none dark:invert size-12 rounded-lg mb-3"
alt="SurfSense"
width={48}
height={48}
priority
/>
<h1 className="text-lg font-semibold tracking-tight">
Welcome to SurfSense Desktop
</h1>
<p className="mt-1 text-sm text-muted-foreground">
Configure shortcuts, then sign in to get started.
</p>
</div>
{/* Scrollable content */}
<div className="flex-1 overflow-y-auto px-6 py-4">
<div className="flex flex-col gap-5">
{/* ---- Shortcuts ---- */}
{shortcutsLoaded ? (
<div className="flex flex-col gap-2">
<p className="text-xs font-medium uppercase tracking-wider text-muted-foreground">
Keyboard Shortcuts
</p>
<div className="flex flex-col gap-1.5">
<ShortcutRecorder
value={shortcuts.generalAssist}
onChange={(accel) => updateShortcut("generalAssist", accel)}
onReset={() => resetShortcut("generalAssist")}
defaultValue={DEFAULT_SHORTCUTS.generalAssist}
label="General Assist"
description="Launch SurfSense instantly from any application"
icon={Rocket}
/>
<ShortcutRecorder
value={shortcuts.quickAsk}
onChange={(accel) => updateShortcut("quickAsk", accel)}
onReset={() => resetShortcut("quickAsk")}
defaultValue={DEFAULT_SHORTCUTS.quickAsk}
label="Quick Assist"
description="Select text anywhere, then ask AI to explain, rewrite, or act on it"
icon={Zap}
/>
<ShortcutRecorder
value={shortcuts.autocomplete}
onChange={(accel) => updateShortcut("autocomplete", accel)}
onReset={() => resetShortcut("autocomplete")}
defaultValue={DEFAULT_SHORTCUTS.autocomplete}
label="Extreme Assist"
description="AI drafts text using your screen context and knowledge base"
icon={BrainCog}
/>
</div>
<p className="text-[11px] text-muted-foreground text-center mt-1">
Click a shortcut and press a new key combination to change it.
</p>
</div>
) : (
<div className="flex justify-center py-6">
<Spinner size="sm" />
</div>
)}
<Separator />
{/* ---- Auth ---- */}
<div className="flex flex-col gap-3">
<p className="text-xs font-medium uppercase tracking-wider text-muted-foreground">
Sign In
</p>
{isGoogleAuth ? (
<Button variant="outline" className="w-full gap-2 h-10" onClick={handleGoogleLogin}>
<IconBrandGoogleFilled className="size-4" />
Continue with Google
</Button>
) : (
<form onSubmit={handleLocalLogin} className="flex flex-col gap-3">
{loginError && (
<div className="rounded-md border border-destructive/20 bg-destructive/10 px-3 py-2 text-sm text-destructive">
{loginError}
</div>
)}
<div className="flex flex-col gap-1.5">
<Label htmlFor="email" className="text-xs">
Email
</Label>
<Input
id="email"
type="email"
placeholder="you@example.com"
required
value={email}
onChange={(e) => setEmail(e.target.value)}
disabled={isLoggingIn}
autoFocus
className="h-9"
/>
</div>
<div className="flex flex-col gap-1.5">
<Label htmlFor="password" className="text-xs">
Password
</Label>
<div className="relative">
<Input
id="password"
type={showPassword ? "text" : "password"}
placeholder="Enter your password"
required
value={password}
onChange={(e) => setPassword(e.target.value)}
disabled={isLoggingIn}
className="h-9 pr-9"
/>
<button
type="button"
onClick={() => setShowPassword((v) => !v)}
className="absolute inset-y-0 right-0 flex items-center pr-2.5 text-muted-foreground hover:text-foreground"
tabIndex={-1}
>
{showPassword ? (
<EyeOff className="size-3.5" />
) : (
<Eye className="size-3.5" />
)}
</button>
</div>
</div>
<Button type="submit" disabled={isLoggingIn} className="h-9 mt-1">
{isLoggingIn ? (
<>
<Spinner size="sm" className="text-primary-foreground" />
Signing in
</>
) : (
"Sign in"
)}
</Button>
</form>
)}
</div>
</div>
</div>
</div>
</div>
);
}

View file

@ -1,10 +1,11 @@
"use client";
import { useEffect, useState } from "react";
import { useRouter } from "next/navigation";
import { useEffect, useState } from "react";
import { Logo } from "@/components/Logo";
import { Button } from "@/components/ui/button";
import { Spinner } from "@/components/ui/spinner";
import { useElectronAPI } from "@/hooks/use-platform";
type PermissionStatus = "authorized" | "denied" | "not determined" | "restricted" | "limited";
@ -17,7 +18,8 @@ const STEPS = [
{
id: "screen-recording",
title: "Screen Recording",
description: "Lets SurfSense capture your screen to understand context and provide smart writing suggestions.",
description:
"Lets SurfSense capture your screen to understand context and provide smart writing suggestions.",
action: "requestScreenRecording",
field: "screenRecording" as const,
},
@ -57,19 +59,18 @@ function StatusBadge({ status }: { status: PermissionStatus }) {
export default function DesktopPermissionsPage() {
const router = useRouter();
const api = useElectronAPI();
const [permissions, setPermissions] = useState<PermissionsStatus | null>(null);
const [isElectron, setIsElectron] = useState(false);
useEffect(() => {
if (!window.electronAPI) return;
setIsElectron(true);
if (!api) return;
let interval: ReturnType<typeof setInterval> | null = null;
const isResolved = (s: string) => s === "authorized" || s === "restricted";
const poll = async () => {
const status = await window.electronAPI!.getPermissionsStatus();
const status = await api.getPermissionsStatus();
setPermissions(status);
if (isResolved(status.accessibility) && isResolved(status.screenRecording)) {
@ -79,10 +80,12 @@ export default function DesktopPermissionsPage() {
poll();
interval = setInterval(poll, 2000);
return () => { if (interval) clearInterval(interval); };
}, []);
return () => {
if (interval) clearInterval(interval);
};
}, [api]);
if (!isElectron) {
if (!api) {
return (
<div className="h-screen flex items-center justify-center bg-background">
<p className="text-muted-foreground">This page is only available in the desktop app.</p>
@ -98,19 +101,20 @@ export default function DesktopPermissionsPage() {
);
}
const allGranted = permissions.accessibility === "authorized" && permissions.screenRecording === "authorized";
const allGranted =
permissions.accessibility === "authorized" && permissions.screenRecording === "authorized";
const handleRequest = async (action: string) => {
if (action === "requestScreenRecording") {
await window.electronAPI!.requestScreenRecording();
await api.requestScreenRecording();
} else if (action === "requestAccessibility") {
await window.electronAPI!.requestAccessibility();
await api.requestAccessibility();
}
};
const handleContinue = () => {
if (allGranted) {
window.electronAPI!.restartApp();
api.restartApp();
}
};
@ -175,7 +179,8 @@ export default function DesktopPermissionsPage() {
</p>
)}
<p className="text-xs text-muted-foreground">
If SurfSense doesn&apos;t appear in the list, click <strong>+</strong> and select it from Applications.
If SurfSense doesn&apos;t appear in the list, click <strong>+</strong> and
select it from Applications.
</p>
</div>
)}
@ -201,6 +206,7 @@ export default function DesktopPermissionsPage() {
Grant permissions to continue
</Button>
<button
type="button"
onClick={handleSkip}
className="block mx-auto text-xs text-muted-foreground hover:text-foreground transition-colors"
>

View file

@ -4,10 +4,6 @@ export const metadata = {
title: "SurfSense Suggestion",
};
export default function SuggestionLayout({
children,
}: {
children: React.ReactNode;
}) {
export default function SuggestionLayout({ children }: { children: React.ReactNode }) {
return <div className="suggestion-body">{children}</div>;
}

View file

@ -1,7 +1,8 @@
"use client";
import { useCallback, useEffect, useRef, useState } from "react";
import { getBearerToken } from "@/lib/auth-utils";
import { useElectronAPI } from "@/hooks/use-platform";
import { ensureTokensFromElectron, getBearerToken } from "@/lib/auth-utils";
type SSEEvent =
| { type: "text-delta"; id: string; delta: string }
@ -9,7 +10,18 @@ type SSEEvent =
| { type: "text-end"; id: string }
| { type: "start"; messageId: string }
| { type: "finish" }
| { type: "error"; errorText: string };
| { type: "error"; errorText: string }
| {
type: "data-thinking-step";
data: { id: string; title: string; status: string; items: string[] };
};
interface AgentStep {
id: string;
title: string;
status: string;
items: string[];
}
function friendlyError(raw: string | number): string {
if (typeof raw === "number") {
@ -33,27 +45,52 @@ function friendlyError(raw: string | number): string {
const AUTO_DISMISS_MS = 3000;
function StepIcon({ status }: { status: string }) {
if (status === "complete") {
return (
<svg
className="step-icon step-icon-done"
viewBox="0 0 16 16"
fill="none"
aria-label="Step complete"
>
<circle cx="8" cy="8" r="7" stroke="#4ade80" strokeWidth="1.5" />
<path
d="M5 8.5l2 2 4-4.5"
stroke="#4ade80"
strokeWidth="1.5"
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
}
return <span className="step-spinner" />;
}
export default function SuggestionPage() {
const api = useElectronAPI();
const [suggestion, setSuggestion] = useState("");
const [isLoading, setIsLoading] = useState(true);
const [isDesktop, setIsDesktop] = useState(true);
const [error, setError] = useState<string | null>(null);
const [steps, setSteps] = useState<AgentStep[]>([]);
const abortRef = useRef<AbortController | null>(null);
const isDesktop = !!api?.onAutocompleteContext;
useEffect(() => {
if (!window.electronAPI?.onAutocompleteContext) {
setIsDesktop(false);
if (!api?.onAutocompleteContext) {
setIsLoading(false);
}
}, []);
}, [api]);
useEffect(() => {
if (!error) return;
const timer = setTimeout(() => {
window.electronAPI?.dismissSuggestion?.();
api?.dismissSuggestion?.();
}, AUTO_DISMISS_MS);
return () => clearTimeout(timer);
}, [error]);
}, [error, api]);
const fetchSuggestion = useCallback(
async (screenshot: string, searchSpaceId: string, appName?: string, windowTitle?: string) => {
@ -64,35 +101,36 @@ export default function SuggestionPage() {
setIsLoading(true);
setSuggestion("");
setError(null);
setSteps([]);
const token = getBearerToken();
let token = getBearerToken();
if (!token) {
await ensureTokensFromElectron();
token = getBearerToken();
}
if (!token) {
setError(friendlyError("not authenticated"));
setIsLoading(false);
return;
}
const backendUrl =
process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
try {
const response = await fetch(
`${backendUrl}/api/v1/autocomplete/vision/stream`,
{
method: "POST",
headers: {
Authorization: `Bearer ${token}`,
"Content-Type": "application/json",
},
body: JSON.stringify({
screenshot,
search_space_id: parseInt(searchSpaceId, 10),
app_name: appName || "",
window_title: windowTitle || "",
}),
signal: controller.signal,
const response = await fetch(`${backendUrl}/api/v1/autocomplete/vision/stream`, {
method: "POST",
headers: {
Authorization: `Bearer ${token}`,
"Content-Type": "application/json",
},
);
body: JSON.stringify({
screenshot,
search_space_id: parseInt(searchSpaceId, 10),
app_name: appName || "",
window_title: windowTitle || "",
}),
signal: controller.signal,
});
if (!response.ok) {
setError(friendlyError(response.status));
@ -131,10 +169,19 @@ export default function SuggestionPage() {
setSuggestion((prev) => prev + parsed.delta);
} else if (parsed.type === "error") {
setError(friendlyError(parsed.errorText));
} else if (parsed.type === "data-thinking-step") {
const { id, title, status, items } = parsed.data;
setSteps((prev) => {
const existing = prev.findIndex((s) => s.id === id);
if (existing >= 0) {
const updated = [...prev];
updated[existing] = { id, title, status, items };
return updated;
}
return [...prev, { id, title, status, items }];
});
}
} catch {
continue;
}
} catch {}
}
}
}
@ -145,13 +192,13 @@ export default function SuggestionPage() {
setIsLoading(false);
}
},
[],
[]
);
useEffect(() => {
if (!window.electronAPI?.onAutocompleteContext) return;
if (!api?.onAutocompleteContext) return;
const cleanup = window.electronAPI.onAutocompleteContext((data) => {
const cleanup = api.onAutocompleteContext((data) => {
const searchSpaceId = data.searchSpaceId || "1";
if (data.screenshot) {
fetchSuggestion(data.screenshot, searchSpaceId, data.appName, data.windowTitle);
@ -159,7 +206,7 @@ export default function SuggestionPage() {
});
return cleanup;
}, [fetchSuggestion]);
}, [fetchSuggestion, api]);
if (!isDesktop) {
return (
@ -179,13 +226,33 @@ export default function SuggestionPage() {
);
}
if (isLoading && !suggestion) {
const showLoading = isLoading && !suggestion;
if (showLoading) {
return (
<div className="suggestion-tooltip">
<div className="suggestion-loading">
<span className="suggestion-dot" />
<span className="suggestion-dot" />
<span className="suggestion-dot" />
<div className="agent-activity">
{steps.length === 0 && (
<div className="activity-initial">
<span className="step-spinner" />
<span className="activity-label">Preparing</span>
</div>
)}
{steps.length > 0 && (
<div className="activity-steps">
{steps.map((step) => (
<div key={step.id} className="activity-step">
<StepIcon status={step.status} />
<span className="step-label">
{step.title}
{step.items.length > 0 && (
<span className="step-detail"> · {step.items[0]}</span>
)}
</span>
</div>
))}
</div>
)}
</div>
</div>
);
@ -193,12 +260,12 @@ export default function SuggestionPage() {
const handleAccept = () => {
if (suggestion) {
window.electronAPI?.acceptSuggestion?.(suggestion);
api?.acceptSuggestion?.(suggestion);
}
};
const handleDismiss = () => {
window.electronAPI?.dismissSuggestion?.();
api?.dismissSuggestion?.();
};
if (!suggestion) return null;
@ -207,10 +274,18 @@ export default function SuggestionPage() {
<div className="suggestion-tooltip">
<p className="suggestion-text">{suggestion}</p>
<div className="suggestion-actions">
<button className="suggestion-btn suggestion-btn-accept" onClick={handleAccept}>
<button
type="button"
className="suggestion-btn suggestion-btn-accept"
onClick={handleAccept}
>
Accept
</button>
<button className="suggestion-btn suggestion-btn-dismiss" onClick={handleDismiss}>
<button
type="button"
className="suggestion-btn suggestion-btn-dismiss"
onClick={handleDismiss}
>
Dismiss
</button>
</div>
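Each data-thinking-step event is upserted by id, so a step can move from in_progress to complete in place while the loading panel renders. A representative payload matching the SSEEvent type above (values are illustrative, not taken from the backend):

const sampleEvent: SSEEvent = {
  type: "data-thinking-step",
  data: {
    id: "kb-search",                  // stable id, so later updates replace this entry
    title: "Searching knowledge base",
    status: "in_progress",            // re-sent later with status: "complete"
    items: ["window title: Q3 planning doc"],
  },
};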

View file

@ -1,121 +1,193 @@
html:has(.suggestion-body),
body:has(.suggestion-body) {
margin: 0 !important;
padding: 0 !important;
background: transparent !important;
overflow: hidden !important;
height: auto !important;
width: 100% !important;
margin: 0 !important;
padding: 0 !important;
background: transparent !important;
overflow: hidden !important;
height: auto !important;
width: 100% !important;
}
.suggestion-body {
margin: 0;
padding: 0;
background: transparent;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
-webkit-font-smoothing: antialiased;
user-select: none;
-webkit-app-region: no-drag;
margin: 0;
padding: 0;
background: transparent;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
-webkit-font-smoothing: antialiased;
user-select: none;
-webkit-app-region: no-drag;
}
.suggestion-tooltip {
background: #1e1e1e;
border: 1px solid #3c3c3c;
border-radius: 8px;
padding: 8px 12px;
margin: 4px;
max-width: 400px;
box-shadow: 0 4px 16px rgba(0, 0, 0, 0.5);
box-sizing: border-box;
background: #1e1e1e;
border: 1px solid #3c3c3c;
border-radius: 8px;
padding: 8px 12px;
margin: 4px;
max-width: 400px;
/* MAX_HEIGHT in suggestion-window.ts is 400px. Subtract 8px for margin
(4px * 2) so the tooltip + margin fits within the Electron window.
box-sizing: border-box ensures padding + border are included. */
max-height: 392px;
box-shadow: 0 4px 16px rgba(0, 0, 0, 0.5);
display: flex;
flex-direction: column;
overflow: hidden;
}
.suggestion-text {
color: #d4d4d4;
font-size: 13px;
line-height: 1.45;
margin: 0 0 6px 0;
word-wrap: break-word;
white-space: pre-wrap;
color: #d4d4d4;
font-size: 13px;
line-height: 1.45;
margin: 0 0 6px 0;
word-wrap: break-word;
white-space: pre-wrap;
overflow-y: auto;
flex: 1 1 auto;
min-height: 0;
}
.suggestion-text::-webkit-scrollbar {
width: 5px;
}
.suggestion-text::-webkit-scrollbar-track {
background: transparent;
}
.suggestion-text::-webkit-scrollbar-thumb {
background: #555;
border-radius: 3px;
}
.suggestion-text::-webkit-scrollbar-thumb:hover {
background: #777;
}
.suggestion-actions {
display: flex;
justify-content: flex-end;
gap: 4px;
border-top: 1px solid #2a2a2a;
padding-top: 6px;
display: flex;
justify-content: flex-end;
gap: 4px;
border-top: 1px solid #2a2a2a;
padding-top: 6px;
flex-shrink: 0;
}
.suggestion-btn {
padding: 2px 8px;
border-radius: 3px;
border: 1px solid #3c3c3c;
font-family: inherit;
font-size: 10px;
font-weight: 500;
cursor: pointer;
line-height: 16px;
transition: background 0.15s, border-color 0.15s;
padding: 2px 8px;
border-radius: 3px;
border: 1px solid #3c3c3c;
font-family: inherit;
font-size: 10px;
font-weight: 500;
cursor: pointer;
line-height: 16px;
transition:
background 0.15s,
border-color 0.15s;
}
.suggestion-btn-accept {
background: #2563eb;
border-color: #3b82f6;
color: #fff;
background: #2563eb;
border-color: #3b82f6;
color: #fff;
}
.suggestion-btn-accept:hover {
background: #1d4ed8;
background: #1d4ed8;
}
.suggestion-btn-dismiss {
background: #2a2a2a;
color: #999;
background: #2a2a2a;
color: #999;
}
.suggestion-btn-dismiss:hover {
background: #333;
color: #ccc;
background: #333;
color: #ccc;
}
.suggestion-error {
border-color: #5c2626;
border-color: #5c2626;
}
.suggestion-error-text {
color: #f48771;
font-size: 12px;
color: #f48771;
font-size: 12px;
}
.suggestion-loading {
display: flex;
gap: 5px;
padding: 2px 0;
justify-content: center;
/* --- Agent activity indicator --- */
.agent-activity {
display: flex;
flex-direction: column;
gap: 4px;
overflow-y: auto;
max-height: 340px;
}
.suggestion-dot {
width: 4px;
height: 4px;
border-radius: 50%;
background: #666;
animation: suggestion-pulse 1.2s infinite ease-in-out;
.activity-initial {
display: flex;
align-items: center;
gap: 8px;
padding: 2px 0;
}
.suggestion-dot:nth-child(2) {
animation-delay: 0.15s;
.activity-label {
color: #a1a1aa;
font-size: 12px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.suggestion-dot:nth-child(3) {
animation-delay: 0.3s;
.activity-steps {
display: flex;
flex-direction: column;
gap: 3px;
}
@keyframes suggestion-pulse {
0%, 80%, 100% {
opacity: 0.3;
transform: scale(0.8);
}
40% {
opacity: 1;
transform: scale(1.1);
}
.activity-step {
display: flex;
align-items: center;
gap: 6px;
min-height: 18px;
}
.step-label {
color: #d4d4d4;
font-size: 12px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.step-detail {
color: #71717a;
font-size: 11px;
}
/* Spinner (in_progress) */
.step-spinner {
width: 14px;
height: 14px;
flex-shrink: 0;
border: 1.5px solid #3f3f46;
border-top-color: #a78bfa;
border-radius: 50%;
animation: step-spin 0.7s linear infinite;
}
/* Checkmark icon (complete) */
.step-icon {
width: 14px;
height: 14px;
flex-shrink: 0;
}
@keyframes step-spin {
to {
transform: rotate(360deg);
}
}

View file

@ -10,6 +10,7 @@ import { ZeroProvider } from "@/components/providers/ZeroProvider";
import { ThemeProvider } from "@/components/theme/theme-provider";
import { Toaster } from "@/components/ui/sonner";
import { LocaleProvider } from "@/contexts/LocaleContext";
import { PlatformProvider } from "@/contexts/platform-context";
import { ReactQueryClientProvider } from "@/lib/query-client/query-client.provider";
import { cn } from "@/lib/utils";
@ -139,15 +140,17 @@ export default function RootLayout({
disableTransitionOnChange
defaultTheme="system"
>
<RootProvider>
<ReactQueryClientProvider>
<ZeroProvider>
<GlobalLoadingProvider>{children}</GlobalLoadingProvider>
</ZeroProvider>
</ReactQueryClientProvider>
<Toaster />
<AnnouncementToastProvider />
</RootProvider>
<PlatformProvider>
<RootProvider>
<ReactQueryClientProvider>
<ZeroProvider>
<GlobalLoadingProvider>{children}</GlobalLoadingProvider>
</ZeroProvider>
</ReactQueryClientProvider>
<Toaster />
<AnnouncementToastProvider />
</RootProvider>
</PlatformProvider>
</ThemeProvider>
</I18nProvider>
</LocaleProvider>

View file

@ -3,6 +3,7 @@
import { useEffect } from "react";
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
import { getAndClearRedirectPath, setBearerToken, setRefreshToken } from "@/lib/auth-utils";
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
import { trackLoginSuccess } from "@/lib/posthog/events";
interface TokenHandlerProps {
@ -29,52 +30,54 @@ const TokenHandler = ({
useGlobalLoadingEffect(true);
useEffect(() => {
// Only run on client-side
if (typeof window === "undefined") return;
// Read tokens from URL at mount time — no subscription needed.
// TokenHandler only runs once after an auth redirect, so a stale read
// is impossible and useSearchParams() would add a pointless subscription.
// (Vercel Best Practice: rerender-defer-reads 5.2)
const params = new URLSearchParams(window.location.search);
const token = params.get(tokenParamName);
const refreshToken = params.get("refresh_token");
const run = async () => {
const params = new URLSearchParams(window.location.search);
const token = params.get(tokenParamName);
const refreshToken = params.get("refresh_token");
if (token) {
try {
// Track login success for OAuth flows (e.g., Google)
// Local login already tracks success before redirecting here
const alreadyTracked = sessionStorage.getItem("login_success_tracked");
if (!alreadyTracked) {
// This is an OAuth flow (Google login) - track success
trackLoginSuccess("google");
if (token) {
try {
const alreadyTracked = sessionStorage.getItem("login_success_tracked");
if (!alreadyTracked) {
trackLoginSuccess("google");
}
sessionStorage.removeItem("login_success_tracked");
localStorage.setItem(storageKey, token);
setBearerToken(token);
if (refreshToken) {
setRefreshToken(refreshToken);
}
// Auto-set active search space in desktop if not already set
if (window.electronAPI?.getActiveSearchSpace) {
try {
const stored = await window.electronAPI.getActiveSearchSpace();
if (!stored) {
const spaces = await searchSpacesApiService.getSearchSpaces();
if (spaces?.length) {
await window.electronAPI.setActiveSearchSpace?.(String(spaces[0].id));
}
}
} catch {
// non-critical
}
}
const savedRedirectPath = getAndClearRedirectPath();
const finalRedirectPath = savedRedirectPath || redirectPath;
window.location.href = finalRedirectPath;
} catch (error) {
console.error("Error storing token in localStorage:", error);
window.location.href = redirectPath;
}
// Clear the flag for future logins
sessionStorage.removeItem("login_success_tracked");
// Store access token in localStorage using both methods for compatibility
localStorage.setItem(storageKey, token);
setBearerToken(token);
// Store refresh token if provided
if (refreshToken) {
setRefreshToken(refreshToken);
}
// Check if there's a saved redirect path from before the auth flow
const savedRedirectPath = getAndClearRedirectPath();
// Use the saved path if available, otherwise use the default redirectPath
const finalRedirectPath = savedRedirectPath || redirectPath;
// Redirect to the appropriate path
window.location.href = finalRedirectPath;
} catch (error) {
console.error("Error storing token in localStorage:", error);
// Even if there's an error, try to redirect to the default path
window.location.href = redirectPath;
}
}
};
run();
}, [tokenParamName, storageKey, redirectPath]);
// Return null - the global provider handles the loading UI

View file

@ -15,7 +15,7 @@ import {
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { Spinner } from "@/components/ui/spinner";
import { logout } from "@/lib/auth-utils";
import { getLoginPath, logout } from "@/lib/auth-utils";
import { resetUser, trackLogout } from "@/lib/posthog/events";
export function UserDropdown({
@ -33,22 +33,19 @@ export function UserDropdown({
if (isLoggingOut) return;
setIsLoggingOut(true);
try {
// Track logout event and reset PostHog identity
trackLogout();
resetUser();
// Revoke refresh token on server and clear all tokens from localStorage
await logout();
if (typeof window !== "undefined") {
window.location.href = "/";
window.location.href = getLoginPath();
}
} catch (error) {
console.error("Error during logout:", error);
// Even if there's an error, try to clear tokens and redirect
await logout();
if (typeof window !== "undefined") {
window.location.href = "/";
window.location.href = getLoginPath();
}
}
};
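getLoginPath is not part of this diff; presumably it sends desktop users to the new desktop login page and everyone else to the regular login route. A sketch under that assumption (both route names are guesses):

// Hypothetical implementation: pick the login route based on the runtime.
export function getLoginPath(): string {
  const isElectron = typeof window !== "undefined" && !!window.electronAPI;
  return isElectron ? "/desktop-login" : "/login";
}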

View file

@ -87,6 +87,7 @@ import {
} from "@/components/ui/drawer";
import { useComments } from "@/hooks/use-comments";
import { useMediaQuery } from "@/hooks/use-media-query";
import { useElectronAPI } from "@/hooks/use-platform";
import { cn } from "@/lib/utils";
// Dynamically import video presentation tool to avoid loading Babel and Remotion in main bundle
@ -463,16 +464,15 @@ export const AssistantMessage: FC = () => {
const AssistantActionBar: FC = () => {
const isLast = useAuiState((s) => s.message.isLast);
const aui = useAui();
const [quickAskMode, setQuickAskMode] = useState("");
const api = useElectronAPI();
const [isQuickAssist, setIsQuickAssist] = useState(false);
useEffect(() => {
if (!isLast || !window.electronAPI?.getQuickAskMode) return;
window.electronAPI.getQuickAskMode().then((mode) => {
if (mode) setQuickAskMode(mode);
if (!api?.getQuickAskMode) return;
api.getQuickAskMode().then((mode) => {
if (mode) setIsQuickAssist(true);
});
}, [isLast]);
const isTransform = isLast && !!window.electronAPI?.replaceText && quickAskMode === "transform";
}, [api]);
return (
<ActionBarPrimitive.Root
@ -482,7 +482,7 @@ const AssistantActionBar: FC = () => {
className="aui-assistant-action-bar-root -ml-1 col-start-3 row-start-2 flex gap-1 text-muted-foreground md:data-floating:absolute md:data-floating:rounded-md md:data-floating:p-1 [&>button]:opacity-100 md:[&>button]:opacity-[var(--aui-button-opacity,1)]"
>
<ActionBarPrimitive.Copy asChild>
<TooltipIconButton tooltip="Copy">
<TooltipIconButton tooltip="Copy to clipboard">
<AuiIf condition={({ message }) => message.isCopied}>
<CheckIcon />
</AuiIf>
@ -492,29 +492,27 @@ const AssistantActionBar: FC = () => {
</TooltipIconButton>
</ActionBarPrimitive.Copy>
<ActionBarPrimitive.ExportMarkdown asChild>
<TooltipIconButton tooltip="Download">
<TooltipIconButton tooltip="Download as Markdown">
<DownloadIcon />
</TooltipIconButton>
</ActionBarPrimitive.ExportMarkdown>
{isLast && (
<ActionBarPrimitive.Reload asChild>
<TooltipIconButton tooltip="Refresh">
<TooltipIconButton tooltip="Regenerate response">
<RefreshCwIcon />
</TooltipIconButton>
</ActionBarPrimitive.Reload>
)}
{isTransform && (
<button
type="button"
{isQuickAssist && (
<TooltipIconButton
tooltip="Paste back into source app"
onClick={() => {
const text = aui.message().getCopyText();
window.electronAPI?.replaceText(text);
api?.replaceText(text);
}}
className="ml-1 inline-flex items-center gap-1.5 rounded-md bg-primary px-3 py-1.5 text-xs font-medium text-primary-foreground transition-colors hover:bg-primary/90"
>
<ClipboardPaste className="size-3.5" />
Paste back
</button>
<ClipboardPaste />
</TooltipIconButton>
)}
</ActionBarPrimitive.Root>
);
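The action bar above decides whether to show the paste-back button by asking the Electron preload layer for the current quick-ask mode. A rough sketch of the preload surface it assumes — the method names match the diff, but the exact types are guesses:

// electron-api.d.ts — hedged sketch of the bridge exposed to the renderer
export type QuickAskMode = "transform" | "explore";

export interface ElectronAPI {
  // Resolves with the mode of the active quick-ask session, or null outside one.
  getQuickAskMode: () => Promise<QuickAskMode | null>;
  // Pastes the assistant's answer back into the app the selection came from.
  replaceText: (text: string) => void;
}

declare global {
  interface Window {
    electronAPI?: ElectronAPI;
  }
}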

View file

@ -216,7 +216,7 @@ export const ConnectorIndicator = forwardRef<ConnectorIndicatorHandle, Connector
onPointerDownOutside={(e) => {
if (pickerOpen) e.preventDefault();
}}
className="max-w-3xl w-[95vw] sm:w-full h-[75vh] sm:h-[85vh] flex flex-col p-0 gap-0 overflow-hidden border border-border ring-0 dark:ring-0 bg-muted dark:bg-muted text-foreground [&>button]:right-4 sm:[&>button]:right-12 [&>button]:top-6 sm:[&>button]:top-10 [&>button]:opacity-80 hover:[&>button]:opacity-100 [&>button_svg]:size-5 select-none"
className="max-w-3xl w-[95vw] sm:w-full h-[75vh] sm:h-[85vh] flex flex-col p-0 gap-0 overflow-hidden border border-border ring-0 dark:ring-0 bg-muted dark:bg-muted text-foreground [&>button]:right-4 sm:[&>button]:right-12 [&>button]:top-6 sm:[&>button]:top-10 [&>button]:opacity-80 [&>button]:hover:opacity-100 [&>button]:hover:bg-foreground/10 [&>button>svg]:size-5 select-none"
>
<DialogTitle className="sr-only">Manage Connectors</DialogTitle>
{/* YouTube Crawler View - shown when adding YouTube videos */}

View file

@ -144,18 +144,14 @@ export const ConnectorConnectView: FC<ConnectorConnectViewProps> = ({
type="button"
onClick={handleFormSubmit}
disabled={isSubmitting}
className="text-xs sm:text-sm min-w-[140px] disabled:opacity-50 disabled:cursor-not-allowed disabled:pointer-events-none"
className="relative text-xs sm:text-sm min-w-[140px] disabled:opacity-50 disabled:cursor-not-allowed disabled:pointer-events-none"
>
{isSubmitting ? (
<>
<Spinner size="sm" className="mr-2" />
Connecting
</>
) : connectorType === "MCP_CONNECTOR" ? (
"Connect"
) : (
`Connect ${getConnectorTypeDisplay(connectorType)}`
)}
<span className={isSubmitting ? "opacity-0" : ""}>
{connectorType === "MCP_CONNECTOR"
? "Connect"
: `Connect ${getConnectorTypeDisplay(connectorType)}`}
</span>
{isSubmitting && <Spinner size="sm" className="absolute" />}
</Button>
</div>
</div>
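The submit button above switches from swapping its label for a spinner to overlaying the spinner on an invisible label, so the button keeps its width while a request is in flight; the same overlay pattern recurs in the edit view and YouTube crawler further down. The pattern in isolation — a sketch only, reusing the project's Button and Spinner components:

// Hedged sketch of the fixed-width busy button used across these views.
import type { FC } from "react";
import { Button } from "@/components/ui/button";
import { Spinner } from "@/components/ui/spinner";

export const BusyButton: FC<{ busy: boolean; label: string; onClick: () => void }> = ({
  busy,
  label,
  onClick,
}) => (
  // `relative` anchors the absolutely positioned spinner over the hidden label,
  // so the rendered width never changes while the request is in flight.
  <Button onClick={onClick} disabled={busy} className="relative min-w-[140px]">
    <span className={busy ? "opacity-0" : ""}>{label}</span>
    {busy && <Spinner size="sm" className="absolute" />}
  </Button>
);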

View file

@ -369,16 +369,10 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
size="sm"
onClick={handleDisconnectConfirm}
disabled={isDisconnecting}
className="text-xs sm:text-sm flex-1 sm:flex-initial h-10 sm:h-auto py-2 sm:py-2"
className="relative text-xs sm:text-sm flex-1 sm:flex-initial h-10 sm:h-auto py-2 sm:py-2"
>
{isDisconnecting ? (
<>
<Spinner size="sm" className="mr-2" />
Disconnecting
</>
) : (
"Confirm Disconnect"
)}
<span className={isDisconnecting ? "opacity-0" : ""}>Confirm Disconnect</span>
{isDisconnecting && <Spinner size="sm" className="absolute" />}
</Button>
<Button
variant="ghost"
@ -415,16 +409,10 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
<Button
onClick={onSave}
disabled={isSaving || isDisconnecting}
className="text-xs sm:text-sm flex-1 sm:flex-initial h-12 sm:h-auto py-3 sm:py-2"
className="relative text-xs sm:text-sm flex-1 sm:flex-initial h-12 sm:h-auto py-3 sm:py-2"
>
{isSaving ? (
<>
<Spinner size="sm" className="mr-2" />
Saving
</>
) : (
"Save Changes"
)}
<span className={isSaving ? "opacity-0" : ""}>Save Changes</span>
{isSaving && <Spinner size="sm" className="absolute" />}
</Button>
)}
</div>

View file

@ -1,6 +1,6 @@
"use client";
import { Cable } from "lucide-react";
import { Search, Unplug } from "lucide-react";
import type { FC } from "react";
import { getDocumentTypeLabel } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon";
import { Button } from "@/components/ui/button";
@ -134,9 +134,17 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
const hasActiveConnectors =
filteredOAuthConnectorTypes.length > 0 || filteredNonOAuthConnectors.length > 0;
const hasFilteredResults = hasActiveConnectors || standaloneDocuments.length > 0;
return (
<TabsContent value="active" className="m-0">
{hasSources ? (
{hasSources && !hasFilteredResults && searchQuery ? (
<div className="flex flex-col items-center justify-center py-20 text-center">
<Search className="size-8 text-muted-foreground mb-3" />
<p className="text-sm text-muted-foreground">No connectors found</p>
<p className="text-xs text-muted-foreground/60 mt-1">Try a different search term</p>
</div>
) : hasSources ? (
<div className="space-y-6">
{/* Active Connectors Section */}
{hasActiveConnectors && (
@ -302,7 +310,7 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
) : (
<div className="flex flex-col items-center justify-center py-20 text-center">
<div className="flex h-16 w-16 items-center justify-center rounded-full bg-muted mb-4">
<Cable className="size-8 text-muted-foreground" />
<Unplug className="size-8 text-muted-foreground" />
</div>
<h4 className="text-lg font-semibold">No active sources</h4>
<p className="text-sm text-muted-foreground mt-1 max-w-[280px]">

View file

@ -1,8 +1,10 @@
"use client";
import { Search } from "lucide-react";
import type { FC } from "react";
import { EnumConnectorName } from "@/contracts/enums/connector";
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
import { usePlatform } from "@/hooks/use-platform";
import { isSelfHosted } from "@/lib/env-config";
import { ConnectorCard } from "../components/connector-card";
import {
@ -74,9 +76,8 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
onManage,
onViewAccountsList,
}) => {
// Check if self-hosted mode (for showing self-hosted only connectors)
const selfHosted = isSelfHosted();
const isDesktop = typeof window !== "undefined" && !!window.electronAPI;
const { isDesktop } = usePlatform();
const matchesSearch = (title: string, description: string) =>
title.toLowerCase().includes(searchQuery.toLowerCase()) ||
@ -287,6 +288,18 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
moreIntegrationsOther.length > 0 ||
moreIntegrationsCrawlers.length > 0;
const hasAnyResults = hasDocumentFileConnectors || hasMoreIntegrations;
if (!hasAnyResults && searchQuery) {
return (
<div className="flex flex-col items-center justify-center py-20 text-center">
<Search className="size-8 text-muted-foreground mb-3" />
<p className="text-sm text-muted-foreground">No connectors found</p>
<p className="text-xs text-muted-foreground/60 mt-1">Try a different search term</p>
</div>
);
}
return (
<div className="space-y-8">
{/* Document/Files Connectors */}
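AllConnectorsTab now reads isDesktop from usePlatform instead of probing window.electronAPI inline. The hook itself is not part of this diff; a minimal sketch of what it could look like, assuming detection is based on the preload bridge being present:

// use-platform.ts — hedged sketch; detection logic and return shape are assumptions
import { useEffect, useState } from "react";

export function usePlatform(): { isDesktop: boolean } {
  // Start false so the server render and the first client render agree,
  // then upgrade once the Electron preload bridge is observable.
  const [isDesktop, setIsDesktop] = useState(false);

  useEffect(() => {
    setIsDesktop(!!(window as { electronAPI?: unknown }).electronAPI);
  }, []);

  return { isDesktop };
}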

View file

@ -173,9 +173,7 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
<Plus className="size-3 text-primary" />
)}
</div>
<span className="text-xs sm:text-sm font-medium">
{isConnecting ? "Connecting" : buttonText}
</span>
<span className="text-xs sm:text-sm font-medium">{buttonText}</span>
</button>
</div>
</div>

View file

@ -335,16 +335,10 @@ export const YouTubeCrawlerView: FC<YouTubeCrawlerViewProps> = ({ searchSpaceId,
<Button
onClick={handleSubmit}
disabled={isSubmitting || isFetchingPlaylist || videoTags.length === 0}
className="text-xs sm:text-sm min-w-[140px] disabled:opacity-50 disabled:cursor-not-allowed disabled:pointer-events-none"
className="relative text-xs sm:text-sm min-w-[140px] disabled:opacity-50 disabled:cursor-not-allowed disabled:pointer-events-none"
>
{isSubmitting ? (
<>
<Spinner size="sm" className="mr-2" />
{t("processing")}
</>
) : (
t("submit")
)}
<span className={isSubmitting ? "opacity-0" : ""}>{t("submit")}</span>
{isSubmitting && <Spinner size="sm" className="absolute" />}
</Button>
</div>
</div>

View file

@ -125,18 +125,16 @@ const DocumentUploadPopupContent: FC<{
onPointerDownOutside={(e) => e.preventDefault()}
onInteractOutside={(e) => e.preventDefault()}
onEscapeKeyDown={(e) => e.preventDefault()}
className="select-none max-w-2xl w-[95vw] sm:w-[640px] h-[min(440px,75dvh)] sm:h-[min(500px,80vh)] flex flex-col p-0 gap-0 overflow-hidden border border-border ring-0 bg-muted dark:bg-muted text-foreground [&>button]:right-3 sm:[&>button]:right-6 [&>button]:top-3 sm:[&>button]:top-5 [&>button]:opacity-80 hover:[&>button]:opacity-100 [&>button]:z-[100] [&>button_svg]:size-4 sm:[&>button_svg]:size-5"
className="select-none max-w-2xl w-[95vw] sm:w-[640px] h-[min(440px,75dvh)] sm:h-[min(520px,80vh)] flex flex-col p-0 gap-0 overflow-hidden border border-border ring-0 bg-muted dark:bg-muted text-foreground [&>button]:right-3 sm:[&>button]:right-6 [&>button]:top-5 sm:[&>button]:top-8 [&>button]:opacity-80 [&>button]:hover:opacity-100 [&>button]:hover:bg-foreground/10 [&>button]:z-[100] [&>button>svg]:size-4 sm:[&>button>svg]:size-5"
>
<DialogTitle className="sr-only">Upload Document</DialogTitle>
<div className="flex-1 min-h-0 overflow-y-auto overscroll-contain">
<div className="sticky top-0 z-20 bg-muted px-4 sm:px-6 pt-4 sm:pt-5 pb-10">
<div className="sticky top-0 z-20 bg-muted px-4 sm:px-6 pt-6 sm:pt-8 pb-10">
<div className="flex items-center gap-2 mb-1 pr-8 sm:pr-0">
<h2 className="text-base sm:text-lg font-semibold tracking-tight">
Upload Documents
</h2>
<h2 className="text-xl sm:text-3xl font-semibold tracking-tight">Upload Documents</h2>
</div>
<p className="text-xs sm:text-sm text-muted-foreground line-clamp-1">
<p className="text-xs sm:text-base text-muted-foreground/80 line-clamp-1">
Upload and sync your documents to your search space
</p>
</div>

View file

@ -3,10 +3,10 @@
import type { ImageMessagePartComponent } from "@assistant-ui/react";
import { cva, type VariantProps } from "class-variance-authority";
import { ImageIcon, ImageOffIcon } from "lucide-react";
import NextImage from "next/image";
import { memo, type PropsWithChildren, useEffect, useRef, useState } from "react";
import { createPortal } from "react-dom";
import { cn } from "@/lib/utils";
import NextImage from 'next/image';
const imageVariants = cva("aui-image-root relative overflow-hidden rounded-lg", {
variants: {
@ -88,23 +88,23 @@ function ImagePreview({
<ImageOffIcon className="size-8 text-muted-foreground" />
</div>
) : isDataOrBlobUrl(src) ? (
// biome-ignore lint/performance/noImgElement: data/blob URLs need plain img
<img
ref={imgRef}
src={src}
alt={alt}
className={cn("block h-auto w-full object-contain", !loaded && "invisible", className)}
onLoad={(e) => {
if (typeof src === "string") setLoadedSrc(src);
onLoad?.(e);
}}
onError={(e) => {
if (typeof src === "string") setErrorSrc(src);
onError?.(e);
}}
{...props}
/>
) : (
// biome-ignore lint/performance/noImgElement: data/blob URLs need plain img
<img
ref={imgRef}
src={src}
alt={alt}
className={cn("block h-auto w-full object-contain", !loaded && "invisible", className)}
onLoad={(e) => {
if (typeof src === "string") setLoadedSrc(src);
onLoad?.(e);
}}
onError={(e) => {
if (typeof src === "string") setErrorSrc(src);
onError?.(e);
}}
{...props}
/>
) : (
// biome-ignore lint/performance/noImgElement: intentional for dynamic external URLs
// <img
// ref={imgRef}
@ -122,22 +122,22 @@ function ImagePreview({
// {...props}
// />
<NextImage
fill
src={src || ""}
alt={alt}
sizes="(max-width: 768px) 100vw, (max-width: 1200px) 80vw, 60vw"
className={cn("block object-contain", !loaded && "invisible", className)}
onLoad={() => {
if (typeof src === "string") setLoadedSrc(src);
onLoad?.();
}}
onError={() => {
if (typeof src === "string") setErrorSrc(src);
onError?.();
}}
unoptimized={false}
{...props}
/>
fill
src={src || ""}
alt={alt}
sizes="(max-width: 768px) 100vw, (max-width: 1200px) 80vw, 60vw"
className={cn("block object-contain", !loaded && "invisible", className)}
onLoad={() => {
if (typeof src === "string") setLoadedSrc(src);
onLoad?.();
}}
onError={() => {
if (typeof src === "string") setErrorSrc(src);
onError?.();
}}
unoptimized={false}
{...props}
/>
)}
</div>
);
@ -162,8 +162,8 @@ type ImageZoomProps = PropsWithChildren<{
alt?: string;
}>;
function isDataOrBlobUrl(src: string | undefined): boolean {
if (!src || typeof src !== "string") return false;
return src.startsWith("data:") || src.startsWith("blob:");
if (!src || typeof src !== "string") return false;
return src.startsWith("data:") || src.startsWith("blob:");
}
function ImageZoom({ src, alt = "Image preview", children }: ImageZoomProps) {
const [isMounted, setIsMounted] = useState(false);
@ -216,38 +216,38 @@ function ImageZoom({ src, alt = "Image preview", children }: ImageZoomProps) {
>
{/** biome-ignore lint/performance/noImgElement: <explanation> */}
{isDataOrBlobUrl(src) ? (
// biome-ignore lint/performance/noImgElement: data/blob URLs need plain img
<img
data-slot="image-zoom-content"
src={src}
alt={alt}
className="aui-image-zoom-content fade-in zoom-in-95 max-h-[90vh] max-w-[90vw] animate-in object-contain duration-200"
onClick={(e) => {
e.stopPropagation();
handleClose();
}}
onKeyDown={(e) => {
if (e.key === "Enter") {
e.stopPropagation();
handleClose();
}
}}
/>
) : (
// biome-ignore lint/performance/noImgElement: data/blob URLs need plain img
<img
data-slot="image-zoom-content"
src={src}
alt={alt}
className="aui-image-zoom-content fade-in zoom-in-95 max-h-[90vh] max-w-[90vw] animate-in object-contain duration-200"
onClick={(e) => {
e.stopPropagation();
handleClose();
}}
onKeyDown={(e) => {
if (e.key === "Enter") {
e.stopPropagation();
handleClose();
}
}}
/>
) : (
<NextImage
data-slot="image-zoom-content"
fill
src={src}
alt={alt}
sizes="90vw"
className="aui-image-zoom-content fade-in zoom-in-95 object-contain duration-200"
onClick={(e) => {
e.stopPropagation();
handleClose();
}}
unoptimized={false}
/>
)}
data-slot="image-zoom-content"
fill
src={src}
alt={alt}
sizes="90vw"
className="aui-image-zoom-content fade-in zoom-in-95 object-contain duration-200"
onClick={(e) => {
e.stopPropagation();
handleClose();
}}
unoptimized={false}
/>
)}
</button>,
document.body
)}
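Both the preview and the zoom overlay render a plain img for data: and blob: sources and next/image for everything else — as the inline biome-ignore comments note, data/blob URLs need a plain img, since Next's optimizer does not accept them. The selection rule reduced to a sketch; SmartImage is an illustrative name, not a component from this diff:

// Hedged sketch of the data/blob-vs-remote branching used by the image parts.
import NextImage from "next/image";
import type { FC } from "react";

const isDataOrBlobUrl = (src?: string): boolean =>
  !!src && (src.startsWith("data:") || src.startsWith("blob:"));

export const SmartImage: FC<{ src?: string; alt: string }> = ({ src, alt }) => (
  <div className="relative h-48 w-full overflow-hidden rounded-lg">
    {isDataOrBlobUrl(src) ? (
      // biome-ignore lint/performance/noImgElement: data/blob URLs need plain img
      <img src={src} alt={alt} className="h-full w-full object-contain" />
    ) : (
      <NextImage fill src={src ?? ""} alt={alt} sizes="90vw" className="object-contain" />
    )}
  </div>
);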

View file

@ -24,6 +24,7 @@ export interface MentionedDocument {
export interface InlineMentionEditorRef {
focus: () => void;
clear: () => void;
setText: (text: string) => void;
getText: () => string;
getMentionedDocuments: () => MentionedDocument[];
insertDocumentChip: (doc: Pick<Document, "id" | "title" | "document_type">) => void;
@ -397,6 +398,19 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
}
}, []);
// Replace editor content with plain text and place cursor at end
const setText = useCallback(
(text: string) => {
if (!editorRef.current) return;
editorRef.current.innerText = text;
const empty = text.length === 0;
setIsEmpty(empty);
onChange?.(text, Array.from(mentionedDocs.values()));
focusAtEnd();
},
[focusAtEnd, onChange, mentionedDocs]
);
const setDocumentChipStatus = useCallback(
(
docId: number,
@ -469,6 +483,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
useImperativeHandle(ref, () => ({
focus: () => editorRef.current?.focus(),
clear,
setText,
getText,
getMentionedDocuments,
insertDocumentChip,
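setText joins the imperative surface so callers can overwrite the editor in one call; the composer hunks later in this diff call it when a saved prompt or quick-ask action is selected. A hedged usage sketch, assuming a ref is already attached to the editor (the import path is illustrative):

// Usage sketch for the new imperative setText method.
import { useRef } from "react";
import type { InlineMentionEditorRef } from "./inline-mention-editor"; // path assumed

export function usePromptPrefill() {
  const editorRef = useRef<InlineMentionEditorRef>(null);

  const prefill = (prompt: string) => {
    // Overwrites the editor content, fires onChange, and places the caret at
    // the end, matching the setText implementation in the hunk above.
    editorRef.current?.setText(prompt);
  };

  return { editorRef, prefill };
}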

View file

@ -241,9 +241,7 @@ const ThreadListItemComponent = memo(function ThreadListItemComponent({
<MessageSquareIcon className="size-4 shrink-0 text-muted-foreground" />
<div className="flex-1 min-w-0">
<p className="truncate text-sm font-medium">{thread.title || "New Chat"}</p>
<p className="truncate text-xs text-muted-foreground">
{relativeTime}
</p>
<p className="truncate text-xs text-muted-foreground">{relativeTime}</p>
</div>
<DropdownMenu>
<DropdownMenuTrigger asChild>

View file

@ -89,17 +89,10 @@ import type { Document } from "@/contracts/types/document.types";
import { useBatchCommentsPreload } from "@/hooks/use-comments";
import { useCommentsSync } from "@/hooks/use-comments-sync";
import { useMediaQuery } from "@/hooks/use-media-query";
import { useElectronAPI } from "@/hooks/use-platform";
import { cn } from "@/lib/utils";
/** Placeholder texts that cycle in new chats when input is empty */
const CYCLING_PLACEHOLDERS = [
"Ask SurfSense anything or @mention docs",
"Generate a podcast from my vacation ideas in Notion",
"Sum up last week's meeting notes from Drive in a bulleted list",
"Give me a brief overview of the most urgent tickets in Jira and Linear",
"Briefly, what are today's top ten important emails and calendar events?",
"Check if this week's Slack messages reference any GitHub issues",
];
const COMPOSER_PLACEHOLDER = "Ask anything · Type / for prompts · Type @ to mention docs";
export const Thread: FC = () => {
return <ThreadContent />;
@ -362,45 +355,23 @@ const Composer: FC = () => {
};
}, []);
const electronAPI = useElectronAPI();
const [clipboardInitialText, setClipboardInitialText] = useState<string | undefined>();
const clipboardLoadedRef = useRef(false);
useEffect(() => {
if (!window.electronAPI || clipboardLoadedRef.current) return;
if (!electronAPI || clipboardLoadedRef.current) return;
clipboardLoadedRef.current = true;
window.electronAPI.getQuickAskText().then((text) => {
electronAPI.getQuickAskText().then((text) => {
if (text) {
setClipboardInitialText(text);
setShowPromptPicker(true);
}
});
}, []);
}, [electronAPI]);
const isThreadEmpty = useAuiState(({ thread }) => thread.isEmpty);
const isThreadRunning = useAuiState(({ thread }) => thread.isRunning);
// Cycling placeholder state - only cycles in new chats
const [placeholderIndex, setPlaceholderIndex] = useState(0);
// Cycle through placeholders every 4 seconds when thread is empty (new chat)
useEffect(() => {
// Only cycle when thread is empty (new chat)
if (!isThreadEmpty) {
// Reset to first placeholder when chat becomes active
setPlaceholderIndex(0);
return;
}
const intervalId = setInterval(() => {
setPlaceholderIndex((prev) => (prev + 1) % CYCLING_PLACEHOLDERS.length);
}, 6000);
return () => clearInterval(intervalId);
}, [isThreadEmpty]);
// Compute current placeholder - only cycle in new chats
const currentPlaceholder = isThreadEmpty
? CYCLING_PLACEHOLDERS[placeholderIndex]
: CYCLING_PLACEHOLDERS[0];
const currentPlaceholder = COMPOSER_PLACEHOLDER;
// Live collaboration state
const { data: currentUser } = useAtomValue(currentUserAtom);
@ -504,34 +475,28 @@ const Composer: FC = () => {
: userText
? `${action.prompt}\n\n${userText}`
: action.prompt;
editorRef.current?.setText(finalPrompt);
aui.composer().setText(finalPrompt);
aui.composer().send();
editorRef.current?.clear();
setShowPromptPicker(false);
setActionQuery("");
setMentionedDocuments([]);
setSidebarDocs([]);
},
[actionQuery, aui, setMentionedDocuments, setSidebarDocs]
[actionQuery, aui]
);
const handleQuickAskSelect = useCallback(
(action: { name: string; prompt: string; mode: "transform" | "explore" }) => {
if (!clipboardInitialText) return;
window.electronAPI?.setQuickAskMode(action.mode);
electronAPI?.setQuickAskMode(action.mode);
const finalPrompt = action.prompt.includes("{selection}")
? action.prompt.replace("{selection}", () => clipboardInitialText)
: `${action.prompt}\n\n${clipboardInitialText}`;
editorRef.current?.setText(finalPrompt);
aui.composer().setText(finalPrompt);
aui.composer().send();
editorRef.current?.clear();
setShowPromptPicker(false);
setActionQuery("");
setClipboardInitialText(undefined);
setMentionedDocuments([]);
setSidebarDocs([]);
},
[clipboardInitialText, aui, setMentionedDocuments, setSidebarDocs]
[clipboardInitialText, electronAPI, aui]
);
// Keyboard navigation for document/action picker (arrow keys, Enter, Escape)
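When a quick-ask action is chosen, the captured selection is substituted into the saved prompt with a replacer function rather than a plain replacement string; that keeps "$&"-style sequences in the clipboard text literal instead of having them expanded by String.prototype.replace. The same templating, isolated into a helper (buildQuickAskPrompt is an illustrative name, not taken from the diff):

// Hedged sketch of the prompt templating used by handleQuickAskSelect above.
export function buildQuickAskPrompt(promptTemplate: string, selection: string): string {
  // A replacer function makes the substitution literal, so selections that
  // happen to contain "$&" or "$'" are not reinterpreted by replace().
  return promptTemplate.includes("{selection}")
    ? promptTemplate.replace("{selection}", () => selection)
    : `${promptTemplate}\n\n${selection}`;
}

// e.g. buildQuickAskPrompt("Rewrite this more politely:\n\n{selection}", clipboardText)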

View file

@ -26,7 +26,8 @@ export const ToolFallback: ToolCallMessagePartComponent = ({
);
const serializedResult = useMemo(
() => (result !== undefined && typeof result !== "string" ? JSON.stringify(result, null, 2) : null),
() =>
result !== undefined && typeof result !== "string" ? JSON.stringify(result, null, 2) : null,
[result]
);

View file

@ -1,6 +1,6 @@
"use client";
import { ArrowUp, Send, X } from "lucide-react";
import { ArrowUp } from "lucide-react";
import { useCallback, useEffect, useRef, useState } from "react";
import { Button } from "@/components/ui/button";
import { Popover, PopoverAnchor, PopoverContent } from "@/components/ui/popover";
@ -307,7 +307,6 @@ export function CommentComposer({
onClick={onCancel}
disabled={isSubmitting}
>
<X className="mr-1 size-4" />
Cancel
</Button>
)}
@ -318,14 +317,7 @@ export function CommentComposer({
disabled={!canSubmit}
className={cn(!canSubmit && "opacity-50", compact && "size-8 shrink-0 rounded-full")}
>
{compact ? (
<ArrowUp className="size-4" />
) : (
<>
<Send className="mr-1 size-4" />
{submitLabel}
</>
)}
{compact ? <ArrowUp className="size-4" /> : submitLabel}
</Button>
</div>
</div>

Some files were not shown because too many files have changed in this diff.