mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-29 10:56:24 +02:00
Compare commits
53 commits
eb5799336c
...
80f775581b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
80f775581b | ||
|
|
518cacf56e | ||
|
|
be98b395b2 | ||
|
|
7c6e52a0a5 | ||
|
|
b74ac8a608 | ||
|
|
a4a4deeda0 | ||
|
|
27e9e8d873 | ||
|
|
e574b5ec4a | ||
|
|
a05bb4ae0c | ||
|
|
91ea293fa2 | ||
|
|
82b5c7f19e | ||
|
|
bb1dcd32b6 | ||
|
|
49441233e7 | ||
|
|
e920923fa4 | ||
|
|
056fc0e7ff | ||
|
|
8d810467dd | ||
|
|
0a26a6c5bb | ||
|
|
5803fe79da | ||
|
|
7f32dd068f | ||
|
|
1b87719a92 | ||
|
|
e4462292e4 | ||
|
|
aba5f6a124 | ||
|
|
a624c86b04 | ||
|
|
122be76133 | ||
|
|
3a1d700817 | ||
|
|
e7beeb2a36 | ||
|
|
f03bf05aaa | ||
|
|
0fb92b7c56 | ||
|
|
63a75052ca | ||
|
|
dc7047f64d | ||
|
|
47f4be08d9 | ||
|
|
caca491774 | ||
|
|
b5a15b7681 | ||
|
|
be622c417c | ||
|
|
be7e73e615 | ||
|
|
3251f0e98d | ||
|
|
8259fab254 | ||
|
|
02323e7b55 | ||
|
|
46c15c11da | ||
|
|
742548847a | ||
|
|
7fa1810d50 | ||
|
|
c9e5fe9cdb | ||
|
|
1f162f52c3 | ||
|
|
c6e94188eb | ||
|
|
f8913adaa3 | ||
|
|
87af012a60 | ||
|
|
8224360afa | ||
|
|
1248363ca9 | ||
|
|
f40de6b695 | ||
|
|
2824410be2 | ||
|
|
35582c9389 | ||
|
|
02fc6f1d16 | ||
|
|
5d22349dc1 |
142 changed files with 6229 additions and 2838 deletions
11
surfsense_backend/app/agents/autocomplete/__init__.py
Normal file
11
surfsense_backend/app/agents/autocomplete/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""Agent-based vision autocomplete with scoped filesystem exploration."""
|
||||
|
||||
from app.agents.autocomplete.autocomplete_agent import (
|
||||
create_autocomplete_agent,
|
||||
stream_autocomplete_agent,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_autocomplete_agent",
|
||||
"stream_autocomplete_agent",
|
||||
]
|
||||
442
surfsense_backend/app/agents/autocomplete/autocomplete_agent.py
Normal file
442
surfsense_backend/app/agents/autocomplete/autocomplete_agent.py
Normal file
|
|
@ -0,0 +1,442 @@
|
|||
"""Vision autocomplete agent with scoped filesystem exploration.
|
||||
|
||||
Converts the stateless single-shot vision autocomplete into an agent that
|
||||
seeds a virtual filesystem from KB search results and lets the vision LLM
|
||||
explore documents via ``ls``, ``read_file``, ``glob``, ``grep``, etc.
|
||||
before generating the final completion.
|
||||
|
||||
Performance: KB search and agent graph compilation run in parallel so
|
||||
the only sequential latency is KB-search (or agent compile, whichever is
|
||||
slower) + the agent's LLM turns. There is no separate "query extraction"
|
||||
LLM call — the window title is used directly as the KB search query.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from deepagents.graph import BASE_AGENT_PROMPT
|
||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from langchain.agents import create_agent
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
|
||||
from app.agents.new_chat.middleware.knowledge_search import (
|
||||
build_scoped_filesystem,
|
||||
search_knowledge_base,
|
||||
)
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KB_TOP_K = 10
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
AUTOCOMPLETE_SYSTEM_PROMPT = """You are a smart writing assistant that analyzes the user's screen to draft or complete text.
|
||||
|
||||
You will receive a screenshot of the user's screen. Your PRIMARY source of truth is the screenshot itself — the visual context determines what to write.
|
||||
|
||||
Your job:
|
||||
1. Analyze the ENTIRE screenshot to understand what the user is working on (email thread, chat conversation, document, code editor, form, etc.).
|
||||
2. Identify the text area where the user will type.
|
||||
3. Generate the text the user most likely wants to write based on the visual context.
|
||||
|
||||
You also have access to the user's knowledge base documents via filesystem tools. However:
|
||||
- ONLY consult the knowledge base if the screenshot clearly involves a topic where your KB documents are DIRECTLY relevant (e.g., the user is writing about a specific project/topic that matches a document title).
|
||||
- Do NOT explore documents just because they exist. Most autocomplete requests can be answered purely from the screenshot.
|
||||
- If you do read a document, only incorporate information that is 100% relevant to what the user is typing RIGHT NOW. Do not add extra details, background, or tangential information from the KB.
|
||||
- Keep your output SHORT — autocomplete should feel like a natural continuation, not an essay.
|
||||
|
||||
Key behavior:
|
||||
- If the text area is EMPTY, draft a concise response or message based on what you see on screen (e.g., reply to an email, respond to a chat message, continue a document).
|
||||
- If the text area already has text, continue it naturally — typically just a sentence or two.
|
||||
|
||||
Rules:
|
||||
- Output ONLY the text to be inserted. No quotes, no explanations, no meta-commentary.
|
||||
- Be CONCISE. Prefer a single paragraph or a few sentences. Autocomplete is a quick assist, not a full draft.
|
||||
- Match the tone and formality of the surrounding context.
|
||||
- If the screen shows code, write code. If it shows a casual chat, be casual. If it shows a formal email, be formal.
|
||||
- Do NOT describe the screenshot or explain your reasoning.
|
||||
- Do NOT cite or reference documents explicitly — just let the knowledge inform your writing naturally.
|
||||
- If you cannot determine what to write, output nothing.
|
||||
|
||||
## Filesystem Tools `ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`
|
||||
|
||||
All file paths must start with a `/`.
|
||||
- ls: list files and directories at a given path.
|
||||
- read_file: read a file from the filesystem.
|
||||
- write_file: create a temporary file in the session (not persisted).
|
||||
- edit_file: edit a file in the session (not persisted for /documents/ files).
|
||||
- glob: find files matching a pattern (e.g., "**/*.xml").
|
||||
- grep: search for text within files.
|
||||
|
||||
## When to Use Filesystem Tools
|
||||
|
||||
BEFORE reaching for any tool, ask yourself: "Can I write a good completion purely from the screenshot?" If yes, just write it — do NOT explore the KB.
|
||||
|
||||
Only use tools when:
|
||||
- The user is clearly writing about a specific topic that likely has detailed information in their KB.
|
||||
- You need a specific fact, name, number, or reference that the screenshot doesn't provide.
|
||||
|
||||
When you do use tools, be surgical:
|
||||
- Check the `ls` output first. If no document title looks relevant, stop — do not read files just to see what's there.
|
||||
- If a title looks relevant, read only the `<chunk_index>` (first ~20 lines) and jump to matched chunks. Do not read entire documents.
|
||||
- Extract only the specific information you need and move on to generating the completion.
|
||||
|
||||
## Reading Documents Efficiently
|
||||
|
||||
Documents are formatted as XML. Each document contains:
|
||||
- `<document_metadata>` — title, type, URL, etc.
|
||||
- `<chunk_index>` — a table of every chunk with its **line range** and a
|
||||
`matched="true"` flag for chunks that matched the search query.
|
||||
- `<document_content>` — the actual chunks in original document order.
|
||||
|
||||
**Workflow**: read the first ~20 lines to see the `<chunk_index>`, identify
|
||||
chunks marked `matched="true"`, then use `read_file(path, offset=<start_line>,
|
||||
limit=<lines>)` to jump directly to those sections."""
|
||||
|
||||
APP_CONTEXT_BLOCK = """
|
||||
|
||||
The user is currently working in "{app_name}" (window: "{window_title}"). Use this to understand the type of application and adapt your tone and format accordingly."""
|
||||
|
||||
|
||||
def _build_autocomplete_system_prompt(app_name: str, window_title: str) -> str:
    """Assemble the system prompt, adding app context when an app is known.

    Returns ``AUTOCOMPLETE_SYSTEM_PROMPT`` unchanged when ``app_name`` is
    empty. Otherwise ``APP_CONTEXT_BLOCK`` is filled in with the application
    name and window title and appended to the base prompt.
    """
    if not app_name:
        return AUTOCOMPLETE_SYSTEM_PROMPT

    app_context = APP_CONTEXT_BLOCK.format(
        app_name=app_name,
        window_title=window_title,
    )
    return AUTOCOMPLETE_SYSTEM_PROMPT + app_context
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pre-compute KB filesystem (runs in parallel with agent compilation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _KBResult:
    """Holder for the knowledge-base data prepared ahead of the agent run.

    Attributes are all optional and default to ``None``:

    * ``files`` - mapping of virtual file paths to their file payloads.
    * ``ls_ai_msg`` - synthetic assistant message carrying an ``ls`` tool call.
    * ``ls_tool_msg`` - synthetic tool message answering that ``ls`` call.
    """

    __slots__ = ("files", "ls_ai_msg", "ls_tool_msg")

    def __init__(
        self,
        files: dict[str, Any] | None = None,
        ls_ai_msg: AIMessage | None = None,
        ls_tool_msg: ToolMessage | None = None,
    ) -> None:
        self.files, self.ls_ai_msg, self.ls_tool_msg = (
            files,
            ls_ai_msg,
            ls_tool_msg,
        )

    @property
    def has_documents(self) -> bool:
        """Whether ``files`` is set and non-empty."""
        return True if self.files else False
|
||||
|
||||
|
||||
async def precompute_kb_filesystem(
    search_space_id: int,
    query: str,
    top_k: int = KB_TOP_K,
) -> _KBResult:
    """Run the KB search and build the scoped filesystem ahead of the agent.

    Intended to be awaited through ``asyncio.gather`` next to agent graph
    compilation so both proceed concurrently.

    An empty ``_KBResult`` is returned when ``query`` is empty, when the
    search or the filesystem build yields nothing, or when any exception is
    raised (the failure is logged and the caller proceeds without KB data).
    On success the result carries the built files plus a synthetic ``ls``
    tool-call / tool-result message pair listing the ``/documents/`` paths.
    """
    if not query:
        return _KBResult()

    try:
        hits = await search_knowledge_base(
            query=query,
            search_space_id=search_space_id,
            top_k=top_k,
        )
        if not hits:
            return _KBResult()

        scoped_files, _ = await build_scoped_filesystem(
            documents=hits,
            search_space_id=search_space_id,
        )
        if not scoped_files:
            return _KBResult()

        # Only real document entries (non-None payloads) are advertised.
        listed_paths = []
        for path, payload in scoped_files.items():
            if path.startswith("/documents/") and payload is not None:
                listed_paths.append(path)

        # The same id ties the fabricated tool call to its fabricated result.
        call_id = f"auto_ls_{uuid.uuid4().hex[:12]}"
        ls_call = {"name": "ls", "args": {"path": "/documents"}, "id": call_id}
        listing = str(listed_paths) if listed_paths else "No documents found."

        return _KBResult(
            files=scoped_files,
            ls_ai_msg=AIMessage(content="", tool_calls=[ls_call]),
            ls_tool_msg=ToolMessage(content=listing, tool_call_id=call_id),
        )

    except Exception:
        # Best-effort: autocomplete still works without KB context.
        logger.warning(
            "KB pre-computation failed, proceeding without KB", exc_info=True
        )
        return _KBResult()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filesystem middleware — no save_document, no persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AutocompleteFilesystemMiddleware(SurfSenseFilesystemMiddleware):
    """Filesystem middleware variant used by the autocomplete agent.

    Initialises the parent with ``search_space_id=None`` and
    ``created_by_id=None`` and removes the ``save_document`` tool from the
    inherited tool list, so the agent cannot persist documents to the KB.
    NOTE(review): the original author states that the ``None`` search space
    keeps ``write_file`` / ``edit_file`` ephemeral; that behaviour lives in
    the parent class and is not visible here.
    """

    def __init__(self) -> None:
        super().__init__(search_space_id=None, created_by_id=None)
        kept_tools = []
        for tool in self.tools:
            if tool.name == "save_document":
                continue
            kept_tools.append(tool)
        self.tools = kept_tools
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _compile_agent(
    llm: BaseChatModel,
    app_name: str,
    window_title: str,
) -> Any:
    """Build the autocomplete agent graph in a worker thread.

    The autocomplete system prompt (with optional app context) is joined to
    ``BASE_AGENT_PROMPT`` by a blank line. ``create_agent`` is run through
    ``asyncio.to_thread`` so the event loop is not blocked while the graph is
    built. The agent gets no explicit tools (``tools=[]``); it is configured
    with the filesystem, tool-call patching and Anthropic prompt-caching
    middleware. The returned agent has ``recursion_limit`` set to 200.
    """
    prompt_sections = (
        _build_autocomplete_system_prompt(app_name, window_title),
        BASE_AGENT_PROMPT,
    )
    final_system_prompt = "\n\n".join(prompt_sections)

    compiled = await asyncio.to_thread(
        create_agent,
        llm,
        system_prompt=final_system_prompt,
        tools=[],
        middleware=[
            AutocompleteFilesystemMiddleware(),
            PatchToolCallsMiddleware(),
            AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
        ],
    )
    return compiled.with_config({"recursion_limit": 200})
|
||||
|
||||
|
||||
async def create_autocomplete_agent(
    llm: BaseChatModel,
    *,
    search_space_id: int,
    kb_query: str,
    app_name: str = "",
    window_title: str = "",
) -> tuple[Any, _KBResult]:
    """Compile the agent and pre-compute the KB filesystem concurrently.

    Both coroutines are awaited together with ``asyncio.gather``. The result
    is ``(agent, kb_result)`` so the caller can seed the agent's initial
    state with the already-built filesystem.
    """
    compile_coro = _compile_agent(llm, app_name, window_title)
    kb_coro = precompute_kb_filesystem(search_space_id, kb_query)
    results = await asyncio.gather(compile_coro, kb_coro)
    return results[0], results[1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def stream_autocomplete_agent(
    agent: Any,
    input_data: dict[str, Any],
    streaming_service: VercelStreamingService,
    *,
    emit_message_start: bool = True,
) -> AsyncGenerator[str, None]:
    """Stream agent events as Vercel SSE, with thinking steps for tool calls.

    Args:
        agent: Object exposing ``astream_events(input, config=..., version=...)``.
        input_data: Initial input passed straight to ``astream_events``.
        streaming_service: Formatter that turns stream state into SSE strings.
        emit_message_start: When False the caller has already sent the
            ``message_start`` event (e.g. to show preparation steps before
            the agent runs), so it is not emitted again.

    Yields:
        Formatted SSE strings: thinking-step updates, text start/delta/end,
        and finally either ``finish`` + ``done`` or ``error`` + ``done``.

    Model text is dropped while a tool is running or when the event is
    tagged ``surfsense:internal``. Exceptions raised while streaming are
    logged and reported to the client as a generic error event; before that
    event any open text block and any in-progress thinking step are closed
    so the client is not left with dangling state.
    """
    thread_id = uuid.uuid4().hex
    config = {"configurable": {"thread_id": thread_id}}

    current_text_id: str | None = None
    active_tool_depth = 0
    thinking_step_counter = 0
    tool_step_ids: dict[str, str] = {}
    step_titles: dict[str, str] = {}
    completed_step_ids: set[str] = set()
    last_active_step_id: str | None = None

    def next_thinking_step_id() -> str:
        nonlocal thinking_step_counter
        thinking_step_counter += 1
        return f"autocomplete-step-{thinking_step_counter}"

    def complete_current_step() -> str | None:
        """Mark the most recent open step complete; return its event or None."""
        nonlocal last_active_step_id
        if last_active_step_id and last_active_step_id not in completed_step_ids:
            completed_step_ids.add(last_active_step_id)
            title = step_titles.get(last_active_step_id, "Done")
            event = streaming_service.format_thinking_step(
                step_id=last_active_step_id,
                title=title,
                status="complete",
            )
            last_active_step_id = None
            return event
        return None

    if emit_message_start:
        yield streaming_service.format_message_start()

    # Emit an initial "Generating completion" step so the UI immediately
    # shows activity once the agent starts its first LLM call.
    gen_step_id = next_thinking_step_id()
    last_active_step_id = gen_step_id
    step_titles[gen_step_id] = "Generating completion"
    yield streaming_service.format_thinking_step(
        step_id=gen_step_id,
        title="Generating completion",
        status="in_progress",
    )

    try:
        async for event in agent.astream_events(
            input_data, config=config, version="v2"
        ):
            event_type = event.get("event", "")
            # ``or`` also covers keys that are present but explicitly None.
            event_data = event.get("data") or {}

            if event_type == "on_chat_model_stream":
                if active_tool_depth > 0:
                    continue
                if "surfsense:internal" in (event.get("tags") or []):
                    continue
                chunk = event_data.get("chunk")
                if chunk and hasattr(chunk, "content"):
                    content = chunk.content
                    if content and isinstance(content, str):
                        if current_text_id is None:
                            # First text after a step: close the step, open text.
                            step_event = complete_current_step()
                            if step_event:
                                yield step_event
                            current_text_id = streaming_service.generate_text_id()
                            yield streaming_service.format_text_start(current_text_id)
                        yield streaming_service.format_text_delta(
                            current_text_id, content
                        )

            elif event_type == "on_tool_start":
                active_tool_depth += 1
                tool_name = event.get("name", "unknown_tool")
                run_id = event.get("run_id", "")
                tool_input = event_data.get("input", {})

                if current_text_id is not None:
                    yield streaming_service.format_text_end(current_text_id)
                    current_text_id = None

                step_event = complete_current_step()
                if step_event:
                    yield step_event

                tool_step_id = next_thinking_step_id()
                tool_step_ids[run_id] = tool_step_id
                last_active_step_id = tool_step_id

                title, items = _describe_tool_call(tool_name, tool_input)
                step_titles[tool_step_id] = title
                yield streaming_service.format_thinking_step(
                    step_id=tool_step_id,
                    title=title,
                    status="in_progress",
                    items=items,
                )

            elif event_type == "on_tool_end":
                active_tool_depth = max(0, active_tool_depth - 1)
                run_id = event.get("run_id", "")
                step_id = tool_step_ids.pop(run_id, None)
                if step_id and step_id not in completed_step_ids:
                    completed_step_ids.add(step_id)
                    title = step_titles.get(step_id, "Done")
                    yield streaming_service.format_thinking_step(
                        step_id=step_id,
                        title=title,
                        status="complete",
                    )
                    if last_active_step_id == step_id:
                        last_active_step_id = None

        if current_text_id is not None:
            yield streaming_service.format_text_end(current_text_id)
            # Reset so the error handler cannot close the same block twice.
            current_text_id = None
        step_event = complete_current_step()
        if step_event:
            yield step_event

        yield streaming_service.format_finish()
        yield streaming_service.format_done()

    except Exception as e:
        logger.error("Autocomplete agent streaming error: %s", e, exc_info=True)
        if current_text_id is not None:
            yield streaming_service.format_text_end(current_text_id)
            current_text_id = None
        # Close whichever step was still in progress so the client does not
        # keep showing it as running after the error.
        step_event = complete_current_step()
        if step_event:
            yield step_event
        yield streaming_service.format_error("Autocomplete failed. Please try again.")
        yield streaming_service.format_done()
|
||||
|
||||
|
||||
def _describe_tool_call(tool_name: str, tool_input: Any) -> tuple[str, list[str]]:
    """Return a human-readable (title, items) for a tool call thinking step.

    Args:
        tool_name: Name of the tool being invoked.
        tool_input: Raw tool arguments; anything that is not a dict is
            treated as "no arguments".

    Returns:
        A short title plus a list of display strings (empty for tools this
        function does not know). File paths longer than 80 characters are
        shortened to an ellipsis followed by their last 77 characters; grep
        patterns are cut to 60 characters followed by an ellipsis.

    Arguments that are ``None`` fall back to their default and other
    non-string values are converted with ``str`` so a malformed tool call
    can never raise here and abort the surrounding event stream.
    """
    inp = tool_input if isinstance(tool_input, dict) else {}

    def text_arg(key: str, default: str = "") -> str:
        """Fetch ``key`` from the tool input as a string."""
        value = inp.get(key)
        if value is None:
            return default
        return value if isinstance(value, str) else str(value)

    def shorten_path(path: str) -> str:
        """Keep the tail of over-long paths, where the file name lives."""
        return path if len(path) <= 80 else "…" + path[-77:]

    file_tool_titles = {
        "read_file": "Reading file",
        "write_file": "Writing file",
        "edit_file": "Editing file",
    }

    if tool_name == "ls":
        return "Listing files", [text_arg("path", "/")]

    if tool_name in file_tool_titles:
        return file_tool_titles[tool_name], [shorten_path(text_arg("file_path"))]

    if tool_name == "glob":
        pat = text_arg("pattern")
        base = text_arg("path", "/")
        return "Searching files", [f"{pat} in {base}"]

    if tool_name == "grep":
        pat = text_arg("pattern")
        path = text_arg("path")
        display_pat = pat[:60] + ("…" if len(pat) > 60 else "")
        return "Searching content", [
            f'"{display_pat}"' + (f" in {path}" if path else "")
        ]

    return f"Using {tool_name}", []
|
||||
|
|
@ -225,6 +225,55 @@ class DropboxClient:
|
|||
|
||||
return all_items, None
|
||||
|
||||
async def get_latest_cursor(self, path: str = "") -> tuple[str | None, str | None]:
    """Fetch a cursor for the current state of a folder.

    Posts to ``/2/files/list_folder/get_latest_cursor`` (non-recursive,
    non-downloadable files included). The cursor can later be handed to
    ``get_changes`` to receive only incremental updates.

    Returns:
        ``(cursor, None)`` on HTTP 200, where ``cursor`` is the response's
        ``"cursor"`` field (``None`` if absent); ``(None, error_message)``
        on any other status.
    """
    payload = {
        "path": path,
        "recursive": False,
        "include_non_downloadable_files": True,
    }
    resp = await self._request("/2/files/list_folder/get_latest_cursor", payload)
    if resp.status_code == 200:
        return resp.json().get("cursor"), None
    return None, f"Failed to get cursor: {resp.status_code} - {resp.text}"
|
||||
|
||||
async def get_changes(
    self, cursor: str
) -> tuple[list[dict[str, Any]], str | None, str | None]:
    """Collect every change recorded since ``cursor``.

    Calls ``/2/files/list_folder/continue`` and keeps following the cursor
    of each page while the response reports ``has_more``.

    Returns:
        ``(entries, new_cursor, error)``. If the first request fails, the
        entries are empty and the cursor is ``None``. If a later page fails,
        the entries gathered so far are returned together with the cursor of
        the last successful page and an error message.
    """
    endpoint = "/2/files/list_folder/continue"

    resp = await self._request(endpoint, {"cursor": cursor})
    if resp.status_code == 401:
        return [], None, "Dropbox authentication expired (401)"
    if resp.status_code != 200:
        return [], None, f"Failed to get changes: {resp.status_code} - {resp.text}"

    page = resp.json()
    collected: list[dict[str, Any]] = list(page.get("entries", []))

    while page.get("has_more"):
        resp = await self._request(endpoint, {"cursor": page["cursor"]})
        if resp.status_code != 200:
            # ``page`` is still the last page that was fetched successfully.
            failure = f"Pagination failed: {resp.status_code}"
            return collected, page.get("cursor"), failure
        page = resp.json()
        collected += page.get("entries", [])

    return collected, page.get("cursor"), None
|
||||
|
||||
async def get_metadata(self, path: str) -> tuple[dict[str, Any] | None, str | None]:
|
||||
resp = await self._request("/2/files/get_metadata", {"path": path})
|
||||
if resp.status_code != 200:
|
||||
|
|
|
|||
|
|
@ -53,7 +53,8 @@ async def download_and_extract_content(
|
|||
file_name = file.get("name", "Unknown")
|
||||
file_id = file.get("id", "")
|
||||
|
||||
if should_skip_file(file):
|
||||
skip, _unsup_ext = should_skip_file(file)
|
||||
if skip:
|
||||
return None, {}, "Skipping non-indexable item"
|
||||
|
||||
logger.info(f"Downloading file for content extraction: {file_name}")
|
||||
|
|
@ -87,9 +88,13 @@ async def download_and_extract_content(
|
|||
if error:
|
||||
return None, metadata, error
|
||||
|
||||
from app.connectors.onedrive.content_extractor import _parse_file_to_markdown
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
markdown = await _parse_file_to_markdown(temp_file_path, file_name)
|
||||
result = await EtlPipelineService().extract(
|
||||
EtlRequest(file_path=temp_file_path, filename=file_name)
|
||||
)
|
||||
markdown = result.markdown_content
|
||||
return markdown, metadata, None
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"""File type handlers for Dropbox."""
|
||||
|
||||
PAPER_EXTENSION = ".paper"
|
||||
from app.etl_pipeline.file_classifier import should_skip_for_service
|
||||
|
||||
SKIP_EXTENSIONS: frozenset[str] = frozenset()
|
||||
PAPER_EXTENSION = ".paper"
|
||||
|
||||
MIME_TO_EXTENSION: dict[str, str] = {
|
||||
"application/pdf": ".pdf",
|
||||
|
|
@ -42,17 +42,25 @@ def is_paper_file(item: dict) -> bool:
|
|||
return ext == PAPER_EXTENSION
|
||||
|
||||
|
||||
def should_skip_file(item: dict) -> bool:
|
||||
def should_skip_file(item: dict) -> tuple[bool, str | None]:
|
||||
"""Skip folders and truly non-indexable files.
|
||||
|
||||
Paper docs are non-downloadable but exportable, so they are NOT skipped.
|
||||
Returns (should_skip, unsupported_extension_or_None).
|
||||
"""
|
||||
if is_folder(item):
|
||||
return True
|
||||
return True, None
|
||||
if is_paper_file(item):
|
||||
return False
|
||||
return False, None
|
||||
if not item.get("is_downloadable", True):
|
||||
return True
|
||||
return True, None
|
||||
|
||||
from pathlib import PurePosixPath
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
name = item.get("name", "")
|
||||
ext = get_extension_from_name(name).lower()
|
||||
return ext in SKIP_EXTENSIONS
|
||||
if should_skip_for_service(name, app_config.ETL_SERVICE):
|
||||
ext = PurePosixPath(name).suffix.lower()
|
||||
return True, ext
|
||||
return False, None
|
||||
|
|
|
|||
|
|
@ -64,8 +64,10 @@ async def get_files_in_folder(
|
|||
)
|
||||
continue
|
||||
files.extend(sub_files)
|
||||
elif not should_skip_file(item):
|
||||
files.append(item)
|
||||
else:
|
||||
skip, _unsup_ext = should_skip_file(item)
|
||||
if not skip:
|
||||
files.append(item)
|
||||
|
||||
return files, None
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,9 @@
|
|||
"""Content extraction for Google Drive files."""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -20,6 +17,7 @@ from .file_types import (
|
|||
get_export_mime_type,
|
||||
get_extension_from_mime,
|
||||
is_google_workspace_file,
|
||||
should_skip_by_extension,
|
||||
should_skip_file,
|
||||
)
|
||||
|
||||
|
|
@ -45,6 +43,11 @@ async def download_and_extract_content(
|
|||
if should_skip_file(mime_type):
|
||||
return None, {}, f"Skipping {mime_type}"
|
||||
|
||||
if not is_google_workspace_file(mime_type):
|
||||
ext_skip, _unsup_ext = should_skip_by_extension(file_name)
|
||||
if ext_skip:
|
||||
return None, {}, f"Skipping unsupported extension: {file_name}"
|
||||
|
||||
logger.info(f"Downloading file for content extraction: {file_name} ({mime_type})")
|
||||
|
||||
drive_metadata: dict[str, Any] = {
|
||||
|
|
@ -97,7 +100,10 @@ async def download_and_extract_content(
|
|||
if error:
|
||||
return None, drive_metadata, error
|
||||
|
||||
markdown = await _parse_file_to_markdown(temp_file_path, file_name)
|
||||
etl_filename = (
|
||||
file_name + extension if is_google_workspace_file(mime_type) else file_name
|
||||
)
|
||||
markdown = await _parse_file_to_markdown(temp_file_path, etl_filename)
|
||||
return markdown, drive_metadata, None
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -110,99 +116,14 @@ async def download_and_extract_content(
|
|||
|
||||
|
||||
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
|
||||
"""Parse a local file to markdown using the configured ETL service."""
|
||||
lower = filename.lower()
|
||||
"""Parse a local file to markdown using the unified ETL pipeline."""
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
if lower.endswith((".md", ".markdown", ".txt")):
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
|
||||
from litellm import atranscription
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
stt_service_type = (
|
||||
"local"
|
||||
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
|
||||
else "external"
|
||||
)
|
||||
if stt_service_type == "local":
|
||||
from app.services.stt_service import stt_service
|
||||
|
||||
t0 = time.monotonic()
|
||||
logger.info(
|
||||
f"[local-stt] START file={filename} thread={threading.current_thread().name}"
|
||||
)
|
||||
result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
|
||||
logger.info(
|
||||
f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
|
||||
)
|
||||
text = result.get("text", "")
|
||||
else:
|
||||
with open(file_path, "rb") as audio_file:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": app_config.STT_SERVICE,
|
||||
"file": audio_file,
|
||||
"api_key": app_config.STT_SERVICE_API_KEY,
|
||||
}
|
||||
if app_config.STT_SERVICE_API_BASE:
|
||||
kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
|
||||
resp = await atranscription(**kwargs)
|
||||
text = resp.get("text", "")
|
||||
|
||||
if not text:
|
||||
raise ValueError("Transcription returned empty text")
|
||||
return f"# Transcription of {filename}\n\n{text}"
|
||||
|
||||
# Document files -- use configured ETL service
|
||||
from app.config import config as app_config
|
||||
|
||||
if app_config.ETL_SERVICE == "UNSTRUCTURED":
|
||||
from langchain_unstructured import UnstructuredLoader
|
||||
|
||||
from app.utils.document_converters import convert_document_to_markdown
|
||||
|
||||
loader = UnstructuredLoader(
|
||||
file_path,
|
||||
mode="elements",
|
||||
post_processors=[],
|
||||
languages=["eng"],
|
||||
include_orig_elements=False,
|
||||
include_metadata=False,
|
||||
strategy="auto",
|
||||
)
|
||||
docs = await loader.aload()
|
||||
return await convert_document_to_markdown(docs)
|
||||
|
||||
if app_config.ETL_SERVICE == "LLAMACLOUD":
|
||||
from app.tasks.document_processors.file_processors import (
|
||||
parse_with_llamacloud_retry,
|
||||
)
|
||||
|
||||
result = await parse_with_llamacloud_retry(
|
||||
file_path=file_path, estimated_pages=50
|
||||
)
|
||||
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
|
||||
if not markdown_documents:
|
||||
raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
|
||||
return markdown_documents[0].text
|
||||
|
||||
if app_config.ETL_SERVICE == "DOCLING":
|
||||
from docling.document_converter import DocumentConverter
|
||||
|
||||
converter = DocumentConverter()
|
||||
t0 = time.monotonic()
|
||||
logger.info(
|
||||
f"[docling] START file={filename} thread={threading.current_thread().name}"
|
||||
)
|
||||
result = await asyncio.to_thread(converter.convert, file_path)
|
||||
logger.info(
|
||||
f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
|
||||
)
|
||||
return result.document.export_to_markdown()
|
||||
|
||||
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
|
||||
result = await EtlPipelineService().extract(
|
||||
EtlRequest(file_path=file_path, filename=filename)
|
||||
)
|
||||
return result.markdown_content
|
||||
|
||||
|
||||
async def download_and_process_file(
|
||||
|
|
@ -236,10 +157,14 @@ async def download_and_process_file(
|
|||
file_name = file.get("name", "Unknown")
|
||||
mime_type = file.get("mimeType", "")
|
||||
|
||||
# Skip folders and shortcuts
|
||||
if should_skip_file(mime_type):
|
||||
return None, f"Skipping {mime_type}", None
|
||||
|
||||
if not is_google_workspace_file(mime_type):
|
||||
ext_skip, _unsup_ext = should_skip_by_extension(file_name)
|
||||
if ext_skip:
|
||||
return None, f"Skipping unsupported extension: {file_name}", None
|
||||
|
||||
logger.info(f"Downloading file: {file_name} ({mime_type})")
|
||||
|
||||
temp_file_path = None
|
||||
|
|
@ -310,10 +235,13 @@ async def download_and_process_file(
|
|||
"."
|
||||
)[-1]
|
||||
|
||||
etl_filename = (
|
||||
file_name + extension if is_google_workspace_file(mime_type) else file_name
|
||||
)
|
||||
logger.info(f"Processing {file_name} with Surfsense's file processor")
|
||||
await process_file_in_background(
|
||||
file_path=temp_file_path,
|
||||
filename=file_name,
|
||||
filename=etl_filename,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
"""File type handlers for Google Drive."""
|
||||
|
||||
from app.etl_pipeline.file_classifier import should_skip_for_service
|
||||
|
||||
GOOGLE_DOC = "application/vnd.google-apps.document"
|
||||
GOOGLE_SHEET = "application/vnd.google-apps.spreadsheet"
|
||||
GOOGLE_SLIDE = "application/vnd.google-apps.presentation"
|
||||
|
|
@ -46,6 +48,21 @@ def should_skip_file(mime_type: str) -> bool:
|
|||
return mime_type in [GOOGLE_FOLDER, GOOGLE_SHORTCUT]
|
||||
|
||||
|
||||
def should_skip_by_extension(filename: str) -> tuple[bool, str | None]:
    """Decide whether *filename* must be skipped for the configured ETL service.

    Returns a ``(should_skip, extension)`` pair. ``extension`` is the
    lower-cased final suffix of *filename* when the file is skipped and
    ``None`` when it is not.
    """
    from pathlib import PurePosixPath

    from app.config import config as app_config

    # Guard clause: anything the configured service can handle is kept.
    if not should_skip_for_service(filename, app_config.ETL_SERVICE):
        return False, None

    rejected_suffix = PurePosixPath(filename).suffix.lower()
    return True, rejected_suffix
|
||||
|
||||
|
||||
def get_export_mime_type(mime_type: str) -> str | None:
|
||||
"""Get export MIME type for Google Workspace files."""
|
||||
return EXPORT_FORMATS.get(mime_type)
|
||||
|
|
|
|||
|
|
@ -1,16 +1,9 @@
|
|||
"""Content extraction for OneDrive files.
|
||||
"""Content extraction for OneDrive files."""
|
||||
|
||||
Reuses the same ETL parsing logic as Google Drive since file parsing is
|
||||
extension-based, not provider-specific.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -31,7 +24,8 @@ async def download_and_extract_content(
|
|||
item_id = file.get("id")
|
||||
file_name = file.get("name", "Unknown")
|
||||
|
||||
if should_skip_file(file):
|
||||
skip, _unsup_ext = should_skip_file(file)
|
||||
if skip:
|
||||
return None, {}, "Skipping non-indexable item"
|
||||
|
||||
file_info = file.get("file", {})
|
||||
|
|
@ -84,98 +78,11 @@ async def download_and_extract_content(
|
|||
|
||||
|
||||
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
|
||||
"""Parse a local file to markdown using the configured ETL service.
|
||||
"""Parse a local file to markdown using the unified ETL pipeline."""
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
Same logic as Google Drive -- file parsing is extension-based.
|
||||
"""
|
||||
lower = filename.lower()
|
||||
|
||||
if lower.endswith((".md", ".markdown", ".txt")):
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
|
||||
from litellm import atranscription
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
stt_service_type = (
|
||||
"local"
|
||||
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
|
||||
else "external"
|
||||
)
|
||||
if stt_service_type == "local":
|
||||
from app.services.stt_service import stt_service
|
||||
|
||||
t0 = time.monotonic()
|
||||
logger.info(
|
||||
f"[local-stt] START file={filename} thread={threading.current_thread().name}"
|
||||
)
|
||||
result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
|
||||
logger.info(
|
||||
f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
|
||||
)
|
||||
text = result.get("text", "")
|
||||
else:
|
||||
with open(file_path, "rb") as audio_file:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": app_config.STT_SERVICE,
|
||||
"file": audio_file,
|
||||
"api_key": app_config.STT_SERVICE_API_KEY,
|
||||
}
|
||||
if app_config.STT_SERVICE_API_BASE:
|
||||
kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
|
||||
resp = await atranscription(**kwargs)
|
||||
text = resp.get("text", "")
|
||||
|
||||
if not text:
|
||||
raise ValueError("Transcription returned empty text")
|
||||
return f"# Transcription of {filename}\n\n{text}"
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
if app_config.ETL_SERVICE == "UNSTRUCTURED":
|
||||
from langchain_unstructured import UnstructuredLoader
|
||||
|
||||
from app.utils.document_converters import convert_document_to_markdown
|
||||
|
||||
loader = UnstructuredLoader(
|
||||
file_path,
|
||||
mode="elements",
|
||||
post_processors=[],
|
||||
languages=["eng"],
|
||||
include_orig_elements=False,
|
||||
include_metadata=False,
|
||||
strategy="auto",
|
||||
)
|
||||
docs = await loader.aload()
|
||||
return await convert_document_to_markdown(docs)
|
||||
|
||||
if app_config.ETL_SERVICE == "LLAMACLOUD":
|
||||
from app.tasks.document_processors.file_processors import (
|
||||
parse_with_llamacloud_retry,
|
||||
)
|
||||
|
||||
result = await parse_with_llamacloud_retry(
|
||||
file_path=file_path, estimated_pages=50
|
||||
)
|
||||
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
|
||||
if not markdown_documents:
|
||||
raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
|
||||
return markdown_documents[0].text
|
||||
|
||||
if app_config.ETL_SERVICE == "DOCLING":
|
||||
from docling.document_converter import DocumentConverter
|
||||
|
||||
converter = DocumentConverter()
|
||||
t0 = time.monotonic()
|
||||
logger.info(
|
||||
f"[docling] START file={filename} thread={threading.current_thread().name}"
|
||||
)
|
||||
result = await asyncio.to_thread(converter.convert, file_path)
|
||||
logger.info(
|
||||
f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
|
||||
)
|
||||
return result.document.export_to_markdown()
|
||||
|
||||
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
|
||||
result = await EtlPipelineService().extract(
|
||||
EtlRequest(file_path=file_path, filename=filename)
|
||||
)
|
||||
return result.markdown_content
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
"""File type handlers for Microsoft OneDrive."""
|
||||
|
||||
from app.etl_pipeline.file_classifier import should_skip_for_service
|
||||
|
||||
ONEDRIVE_FOLDER_FACET = "folder"
|
||||
ONENOTE_MIME = "application/msonenote"
|
||||
|
||||
|
|
@ -38,13 +40,28 @@ def is_folder(item: dict) -> bool:
|
|||
return ONEDRIVE_FOLDER_FACET in item
|
||||
|
||||
|
||||
def should_skip_file(item: dict) -> bool:
|
||||
"""Skip folders, OneNote files, remote items (shared links), and packages."""
|
||||
def should_skip_file(item: dict) -> tuple[bool, str | None]:
|
||||
"""Skip folders, OneNote files, remote items, packages, and unsupported extensions.
|
||||
|
||||
Returns (should_skip, unsupported_extension_or_None).
|
||||
The second element is only set when the skip is due to an unsupported extension.
|
||||
"""
|
||||
if is_folder(item):
|
||||
return True
|
||||
return True, None
|
||||
if "remoteItem" in item:
|
||||
return True
|
||||
return True, None
|
||||
if "package" in item:
|
||||
return True
|
||||
return True, None
|
||||
mime = item.get("file", {}).get("mimeType", "")
|
||||
return mime in SKIP_MIME_TYPES
|
||||
if mime in SKIP_MIME_TYPES:
|
||||
return True, None
|
||||
|
||||
from pathlib import PurePosixPath
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
name = item.get("name", "")
|
||||
if should_skip_for_service(name, app_config.ETL_SERVICE):
|
||||
ext = PurePosixPath(name).suffix.lower()
|
||||
return True, ext
|
||||
return False, None
|
||||
|
|
|
|||
|
|
@ -71,8 +71,10 @@ async def get_files_in_folder(
|
|||
)
|
||||
continue
|
||||
files.extend(sub_files)
|
||||
elif not should_skip_file(item):
|
||||
files.append(item)
|
||||
else:
|
||||
skip, _unsup_ext = should_skip_file(item)
|
||||
if not skip:
|
||||
files.append(item)
|
||||
|
||||
return files, None
|
||||
|
||||
|
|
|
|||
0
surfsense_backend/app/etl_pipeline/__init__.py
Normal file
0
surfsense_backend/app/etl_pipeline/__init__.py
Normal file
39
surfsense_backend/app/etl_pipeline/constants.py
Normal file
39
surfsense_backend/app/etl_pipeline/constants.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
import ssl
|
||||
|
||||
import httpx
|
||||
|
||||
# Retry policy for LlamaCloud parsing: number of attempts and the bounds of
# the exponential back-off delay (presumably seconds — the parser feeds the
# computed delay to asyncio.sleep).
LLAMACLOUD_MAX_RETRIES = 5
LLAMACLOUD_BASE_DELAY = 10
LLAMACLOUD_MAX_DELAY = 120
# Exception types treated as transient and therefore retried.
# NOTE(review): OSError is the base class of ssl.SSLError, ConnectionError,
# ConnectionResetError and TimeoutError, so listing it makes those entries
# redundant and also makes every other OSError (e.g. PermissionError)
# retryable — confirm that this breadth is intended.
LLAMACLOUD_RETRYABLE_EXCEPTIONS = (
    ssl.SSLError,
    httpx.ConnectError,
    httpx.ConnectTimeout,
    httpx.ReadError,
    httpx.ReadTimeout,
    httpx.WriteError,
    httpx.WriteTimeout,
    httpx.RemoteProtocolError,
    httpx.LocalProtocolError,
    ConnectionError,
    ConnectionResetError,
    TimeoutError,
    OSError,
)

# Inputs to calculate_upload_timeout / calculate_job_timeout below.
UPLOAD_BYTES_PER_SECOND_SLOW = 100 * 1024  # pessimistic upload speed: 100 KiB per second
MIN_UPLOAD_TIMEOUT = 120  # lower clamp for the upload timeout
MAX_UPLOAD_TIMEOUT = 1800  # upper clamp for the upload timeout
BASE_JOB_TIMEOUT = 600  # fixed part of every job timeout
PER_PAGE_JOB_TIMEOUT = 60  # added to the job timeout per estimated page
|
||||
|
||||
|
||||
def calculate_upload_timeout(file_size_bytes: int) -> float:
    """Return an upload timeout scaled to the file size.

    The estimate assumes the slow-connection throughput
    ``UPLOAD_BYTES_PER_SECOND_SLOW`` with a 50% safety margin, clamped to
    the ``[MIN_UPLOAD_TIMEOUT, MAX_UPLOAD_TIMEOUT]`` range.
    """
    seconds_at_slow_speed = file_size_bytes / UPLOAD_BYTES_PER_SECOND_SLOW
    padded_estimate = seconds_at_slow_speed * 1.5
    capped_estimate = min(padded_estimate, MAX_UPLOAD_TIMEOUT)
    return max(MIN_UPLOAD_TIMEOUT, capped_estimate)
|
||||
|
||||
|
||||
def calculate_job_timeout(estimated_pages: int, file_size_bytes: int) -> float:
    """Return a parse-job timeout: the larger of a page-based and a size-based budget.

    Both budgets start from ``BASE_JOB_TIMEOUT``. The page budget adds
    ``PER_PAGE_JOB_TIMEOUT`` per estimated page; the size budget adds 60 for
    every 10 MiB of file size.
    """
    ten_mib = 10 * 1024 * 1024
    budget_by_pages = BASE_JOB_TIMEOUT + (estimated_pages * PER_PAGE_JOB_TIMEOUT)
    budget_by_size = BASE_JOB_TIMEOUT + (file_size_bytes / ten_mib) * 60
    return max(budget_by_pages, budget_by_size)
|
||||
21
surfsense_backend/app/etl_pipeline/etl_document.py
Normal file
21
surfsense_backend/app/etl_pipeline/etl_document.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class EtlRequest(BaseModel):
    """Input to the ETL pipeline: the local file to parse and the name it goes by.

    The pipeline classifies the file by the extension of ``filename`` and
    forwards ``estimated_pages`` to the LlamaCloud parser.
    """

    file_path: str
    filename: str
    estimated_pages: int = 0

    @field_validator("filename")
    @classmethod
    def filename_must_not_be_empty(cls, v: str) -> str:
        """Reject empty or whitespace-only filenames; return the value unchanged otherwise."""
        if v.strip():
            return v
        raise ValueError("filename must not be empty")
|
||||
|
||||
|
||||
class EtlResult(BaseModel):
    """Output of the ETL pipeline: the extracted markdown and how it was produced."""

    # Extracted text, ready for downstream indexing.
    markdown_content: str
    # Which extractor ran: "PLAINTEXT", "DIRECT_CONVERT", "AUDIO", or the
    # configured ETL_SERVICE name for documents.
    etl_service: str
    # NOTE(review): the pipeline service never sets this, so it stays at its
    # default there — confirm whether another caller fills it in.
    actual_pages: int = 0
    # Category label: "plaintext", "direct_convert", "audio" or "document".
    content_type: str
|
||||
90
surfsense_backend/app/etl_pipeline/etl_pipeline_service.py
Normal file
90
surfsense_backend/app/etl_pipeline/etl_pipeline_service.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
from app.config import config as app_config
|
||||
from app.etl_pipeline.etl_document import EtlRequest, EtlResult
|
||||
from app.etl_pipeline.exceptions import (
|
||||
EtlServiceUnavailableError,
|
||||
EtlUnsupportedFileError,
|
||||
)
|
||||
from app.etl_pipeline.file_classifier import FileCategory, classify_file
|
||||
from app.etl_pipeline.parsers.audio import transcribe_audio
|
||||
from app.etl_pipeline.parsers.direct_convert import convert_file_directly
|
||||
from app.etl_pipeline.parsers.plaintext import read_plaintext
|
||||
|
||||
|
||||
class EtlPipelineService:
    """Single entry point for turning a local file into markdown.

    Files are routed by extension: plaintext is read as-is, direct-convert
    formats go through ``convert_file_directly``, audio is transcribed, and
    everything else is handed to the configured document parser.
    """

    async def extract(self, request: EtlRequest) -> EtlResult:
        """Classify ``request.filename`` and run the matching extractor.

        Raises:
            EtlUnsupportedFileError: the extension is not recognised, or the
                configured document parser does not handle it.
            EtlServiceUnavailableError: a document needs parsing but
                ``ETL_SERVICE`` is unset or unknown.
        """
        kind = classify_file(request.filename)

        if kind == FileCategory.UNSUPPORTED:
            raise EtlUnsupportedFileError(
                f"File type not supported for parsing: {request.filename}"
            )

        if kind == FileCategory.PLAINTEXT:
            markdown = read_plaintext(request.file_path)
            service_label, content_label = "PLAINTEXT", "plaintext"
        elif kind == FileCategory.DIRECT_CONVERT:
            markdown = convert_file_directly(request.file_path, request.filename)
            service_label, content_label = "DIRECT_CONVERT", "direct_convert"
        elif kind == FileCategory.AUDIO:
            markdown = await transcribe_audio(request.file_path, request.filename)
            service_label, content_label = "AUDIO", "audio"
        else:
            # Only documents remain; they depend on the configured parser.
            return await self._extract_document(request)

        return EtlResult(
            markdown_content=markdown,
            etl_service=service_label,
            content_type=content_label,
        )

    async def _extract_document(self, request: EtlRequest) -> EtlResult:
        """Parse a document with whichever parser ``ETL_SERVICE`` selects."""
        from pathlib import PurePosixPath

        from app.utils.file_extensions import get_document_extensions_for_service

        service_name = app_config.ETL_SERVICE
        if not service_name:
            raise EtlServiceUnavailableError(
                "No ETL_SERVICE configured. "
                "Set ETL_SERVICE to UNSTRUCTURED, LLAMACLOUD, or DOCLING in your .env"
            )

        ext = PurePosixPath(request.filename).suffix.lower()
        if ext not in get_document_extensions_for_service(service_name):
            raise EtlUnsupportedFileError(
                f"File type {ext} is not supported by {service_name}"
            )

        # Parser modules are imported lazily so only the selected backend is loaded.
        if service_name == "DOCLING":
            from app.etl_pipeline.parsers.docling import parse_with_docling

            markdown = await parse_with_docling(request.file_path, request.filename)
        elif service_name == "UNSTRUCTURED":
            from app.etl_pipeline.parsers.unstructured import parse_with_unstructured

            markdown = await parse_with_unstructured(request.file_path)
        elif service_name == "LLAMACLOUD":
            from app.etl_pipeline.parsers.llamacloud import parse_with_llamacloud

            markdown = await parse_with_llamacloud(
                request.file_path, request.estimated_pages
            )
        else:
            raise EtlServiceUnavailableError(f"Unknown ETL_SERVICE: {service_name}")

        return EtlResult(
            markdown_content=markdown,
            etl_service=service_name,
            content_type="document",
        )
|
||||
10
surfsense_backend/app/etl_pipeline/exceptions.py
Normal file
10
surfsense_backend/app/etl_pipeline/exceptions.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
class EtlPipelineError(Exception):
    """Common base for every error raised by the ETL pipeline.

    Lets callers handle any pipeline failure with a single ``except`` clause
    while the concrete subclasses stay individually catchable.
    """


class EtlParseError(EtlPipelineError):
    """Raised when an ETL parser fails to produce content."""


class EtlServiceUnavailableError(EtlPipelineError):
    """Raised when ETL_SERVICE is not configured or names an unknown service."""


class EtlUnsupportedFileError(EtlPipelineError):
    """Raised when a file type is unrecognised or unsupported by the configured parser."""
|
||||
137
surfsense_backend/app/etl_pipeline/file_classifier.py
Normal file
137
surfsense_backend/app/etl_pipeline/file_classifier.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
from enum import Enum
|
||||
from pathlib import PurePosixPath
|
||||
|
||||
from app.utils.file_extensions import (
|
||||
DOCUMENT_EXTENSIONS,
|
||||
get_document_extensions_for_service,
|
||||
)
|
||||
|
||||
# Suffixes of files that are read verbatim as text, with no parser involved.
# classify_file compares the lower-cased PurePosixPath(filename).suffix
# against this set, so every entry is a lower-case suffix with a leading dot.
PLAINTEXT_EXTENSIONS = frozenset(
    {
        ".md",
        ".markdown",
        ".txt",
        ".text",
        ".json",
        ".jsonl",
        ".yaml",
        ".yml",
        ".toml",
        ".ini",
        ".cfg",
        ".conf",
        ".xml",
        ".css",
        ".scss",
        ".less",
        ".sass",
        ".py",
        ".pyw",
        ".pyi",
        ".pyx",
        ".js",
        ".jsx",
        ".ts",
        ".tsx",
        ".mjs",
        ".cjs",
        ".java",
        ".kt",
        ".kts",
        ".scala",
        ".groovy",
        ".c",
        ".h",
        ".cpp",
        ".cxx",
        ".cc",
        ".hpp",
        ".hxx",
        ".cs",
        ".fs",
        ".fsx",
        ".go",
        ".rs",
        ".rb",
        ".php",
        ".pl",
        ".pm",
        ".lua",
        ".swift",
        ".m",
        ".mm",
        ".r",
        ".jl",
        ".sh",
        ".bash",
        ".zsh",
        ".fish",
        ".bat",
        ".cmd",
        ".ps1",
        ".sql",
        ".graphql",
        ".gql",
        ".env",
        ".gitignore",
        ".dockerignore",
        ".editorconfig",
        ".makefile",
        ".cmake",
        ".log",
        ".rst",
        ".tex",
        ".bib",
        ".org",
        ".adoc",
        ".asciidoc",
        ".vue",
        ".svelte",
        ".astro",
        ".tf",
        ".hcl",
        ".proto",
    }
)
|
||||
|
||||
# Suffixes classified as FileCategory.AUDIO. Note that the set also contains
# the container formats ".mp4" and ".webm".
AUDIO_EXTENSIONS = frozenset(
    {".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm"}
)

# Suffixes classified as FileCategory.DIRECT_CONVERT (tabular and HTML files).
DIRECT_CONVERT_EXTENSIONS = frozenset({".csv", ".tsv", ".html", ".htm", ".xhtml"})
|
||||
|
||||
|
||||
class FileCategory(Enum):
    """Extraction route for a file, as returned by ``classify_file``."""

    PLAINTEXT = "plaintext"  # suffix is in PLAINTEXT_EXTENSIONS
    AUDIO = "audio"  # suffix is in AUDIO_EXTENSIONS
    DIRECT_CONVERT = "direct_convert"  # suffix is in DIRECT_CONVERT_EXTENSIONS
    UNSUPPORTED = "unsupported"  # suffix matched none of the known sets
    DOCUMENT = "document"  # suffix is in DOCUMENT_EXTENSIONS
|
||||
|
||||
|
||||
def classify_file(filename: str) -> FileCategory:
    """Map *filename* to the extraction route its extension calls for.

    The lower-cased final suffix is checked against the plaintext, audio,
    direct-convert and document extension sets, in that order. Anything
    that matches none of them is ``FileCategory.UNSUPPORTED``.

    Dotfiles are special-cased: ``pathlib`` reports an empty suffix for a
    name such as ``.gitignore`` or ``.env``, which would make the matching
    entries in ``PLAINTEXT_EXTENSIONS`` unreachable. For such names the
    whole lower-cased file name is used as the lookup key.
    """
    path = PurePosixPath(filename)
    suffix = path.suffix.lower()
    if not suffix and path.name.startswith("."):
        # ".gitignore" has stem ".gitignore" and suffix "" — key on the name.
        suffix = path.name.lower()

    if suffix in PLAINTEXT_EXTENSIONS:
        return FileCategory.PLAINTEXT
    if suffix in AUDIO_EXTENSIONS:
        return FileCategory.AUDIO
    if suffix in DIRECT_CONVERT_EXTENSIONS:
        return FileCategory.DIRECT_CONVERT
    if suffix in DOCUMENT_EXTENSIONS:
        return FileCategory.DOCUMENT
    return FileCategory.UNSUPPORTED
|
||||
|
||||
|
||||
def should_skip_for_service(filename: str, etl_service: str | None) -> bool:
    """Tell whether *filename* is out of reach for *etl_service*.

    Unsupported files are always skipped. Plaintext, audio and
    direct-convert files do not depend on the document parser and are never
    skipped. Document files are skipped only when their suffix is missing
    from the extension set of the given service.
    """
    kind = classify_file(filename)

    if kind == FileCategory.UNSUPPORTED:
        return True
    if kind != FileCategory.DOCUMENT:
        return False

    ext = PurePosixPath(filename).suffix.lower()
    parseable = get_document_extensions_for_service(etl_service)
    return ext not in parseable
|
||||
0
surfsense_backend/app/etl_pipeline/parsers/__init__.py
Normal file
0
surfsense_backend/app/etl_pipeline/parsers/__init__.py
Normal file
34
surfsense_backend/app/etl_pipeline/parsers/audio.py
Normal file
34
surfsense_backend/app/etl_pipeline/parsers/audio.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
from litellm import atranscription
|
||||
|
||||
from app.config import config as app_config
|
||||
|
||||
|
||||
async def transcribe_audio(file_path: str, filename: str) -> str:
    """Transcribe an audio file and return the text as a small markdown document.

    When ``STT_SERVICE`` starts with ``"local/"`` the in-process STT service
    is used; otherwise the file is sent to the external model through
    ``litellm.atranscription``.

    Args:
        file_path: Path of the audio file on local disk.
        filename: Display name used in the markdown heading.

    Returns:
        Markdown of the form ``# Transcription of <filename>`` followed by the text.

    Raises:
        ValueError: If the transcription comes back empty.
    """
    import asyncio

    stt_model = app_config.STT_SERVICE
    use_local_stt = bool(stt_model) and stt_model.startswith("local/")

    if use_local_stt:
        from app.services.stt_service import stt_service

        # transcribe_file is a synchronous call; run it in a worker thread so
        # it does not block the event loop while the model is working.
        result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
        text = result.get("text", "")
    else:
        with open(file_path, "rb") as audio_file:
            kwargs: dict = {
                "model": stt_model,
                "file": audio_file,
                "api_key": app_config.STT_SERVICE_API_KEY,
            }
            if app_config.STT_SERVICE_API_BASE:
                kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
            response = await atranscription(**kwargs)
        text = response.get("text", "")

    if not text:
        raise ValueError("Transcription returned empty text")

    return f"# Transcription of {filename}\n\n{text}"
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from app.tasks.document_processors._direct_converters import convert_file_directly
|
||||
|
||||
__all__ = ["convert_file_directly"]
|
||||
26
surfsense_backend/app/etl_pipeline/parsers/docling.py
Normal file
26
surfsense_backend/app/etl_pipeline/parsers/docling.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import warnings
|
||||
from logging import ERROR, getLogger
|
||||
|
||||
|
||||
async def parse_with_docling(file_path: str, filename: str) -> str:
    """Convert a document with the Docling service and return its content.

    While the conversion runs, pdfminer warnings are filtered out and the
    ``pdfminer`` logger is raised to ERROR. The logger level is restored
    afterwards even if the conversion fails.
    """
    from app.services.docling_service import create_docling_service

    service = create_docling_service()

    noisy_logger = getLogger("pdfminer")
    saved_level = noisy_logger.level

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning, module="pdfminer")
        for noise_pattern in (
            ".*Cannot set gray non-stroke color.*",
            ".*invalid float value.*",
        ):
            warnings.filterwarnings("ignore", message=noise_pattern)
        noisy_logger.setLevel(ERROR)

        try:
            parsed = await service.process_document(file_path, filename)
        finally:
            noisy_logger.setLevel(saved_level)

    return parsed["content"]
|
||||
123
surfsense_backend/app/etl_pipeline/parsers/llamacloud.py
Normal file
123
surfsense_backend/app/etl_pipeline/parsers/llamacloud.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.etl_pipeline.constants import (
|
||||
LLAMACLOUD_BASE_DELAY,
|
||||
LLAMACLOUD_MAX_DELAY,
|
||||
LLAMACLOUD_MAX_RETRIES,
|
||||
LLAMACLOUD_RETRYABLE_EXCEPTIONS,
|
||||
PER_PAGE_JOB_TIMEOUT,
|
||||
calculate_job_timeout,
|
||||
calculate_upload_timeout,
|
||||
)
|
||||
|
||||
|
||||
async def parse_with_llamacloud(file_path: str, estimated_pages: int) -> str:
    """Parse a file with LlamaCloud and return markdown, retrying transient failures.

    Timeouts are derived from the file size and the estimated page count.
    Errors listed in ``LLAMACLOUD_RETRYABLE_EXCEPTIONS`` are retried up to
    ``LLAMACLOUD_MAX_RETRIES`` times with jittered exponential back-off; any
    other exception propagates immediately.

    Args:
        file_path: Path of the file on local disk.
        estimated_pages: Page estimate used to size the job timeout.

    Returns:
        The markdown text extracted from the parse result.

    Raises:
        Exception: The last retryable error once all attempts are exhausted,
            or any non-retryable error as soon as it occurs.
    """
    from llama_cloud_services import LlamaParse
    from llama_cloud_services.parse.utils import ResultType

    file_size_bytes = os.path.getsize(file_path)
    file_size_mb = file_size_bytes / (1024 * 1024)

    upload_timeout = calculate_upload_timeout(file_size_bytes)
    job_timeout = calculate_job_timeout(estimated_pages, file_size_bytes)

    custom_timeout = httpx.Timeout(
        connect=120.0,
        read=upload_timeout,
        write=upload_timeout,
        pool=120.0,
    )

    logging.info(
        f"LlamaCloud upload configured: file_size={file_size_mb:.1f}MB, "
        f"pages={estimated_pages}, upload_timeout={upload_timeout:.0f}s, "
        f"job_timeout={job_timeout:.0f}s"
    )

    last_exception = None
    attempt_errors: list[str] = []

    for attempt in range(1, LLAMACLOUD_MAX_RETRIES + 1):
        try:
            # A fresh client per attempt so a broken connection is never reused.
            async with httpx.AsyncClient(timeout=custom_timeout) as custom_client:
                parser = LlamaParse(
                    api_key=app_config.LLAMA_CLOUD_API_KEY,
                    num_workers=1,
                    verbose=True,
                    language="en",
                    result_type=ResultType.MD,
                    max_timeout=int(max(2000, job_timeout + upload_timeout)),
                    job_timeout_in_seconds=job_timeout,
                    job_timeout_extra_time_per_page_in_seconds=PER_PAGE_JOB_TIMEOUT,
                    custom_client=custom_client,
                )
                result = await parser.aparse(file_path)

                if attempt > 1:
                    logging.info(
                        f"LlamaCloud upload succeeded on attempt {attempt} after "
                        f"{len(attempt_errors)} failures"
                    )

                return _llamacloud_result_to_markdown(result)

        except LLAMACLOUD_RETRYABLE_EXCEPTIONS as e:
            last_exception = e
            error_type = type(e).__name__
            error_msg = str(e)[:200]
            attempt_errors.append(f"Attempt {attempt}: {error_type} - {error_msg}")

            if attempt < LLAMACLOUD_MAX_RETRIES:
                delay = _llamacloud_retry_delay(attempt)
                logging.warning(
                    f"LlamaCloud upload failed "
                    f"(attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}): "
                    f"{error_type}. File: {file_size_mb:.1f}MB. "
                    f"Retrying in {delay:.0f}s..."
                )
                await asyncio.sleep(delay)
            else:
                logging.error(
                    f"LlamaCloud upload failed after {LLAMACLOUD_MAX_RETRIES} "
                    f"attempts. File size: {file_size_mb:.1f}MB, "
                    f"Pages: {estimated_pages}. "
                    f"Errors: {'; '.join(attempt_errors)}"
                )
        # Any other exception is not transient and propagates unchanged.

    raise last_exception or RuntimeError(
        f"LlamaCloud parsing failed after {LLAMACLOUD_MAX_RETRIES} retries. "
        f"File size: {file_size_mb:.1f}MB"
    )


def _llamacloud_result_to_markdown(result) -> str:
    """Extract markdown text from whichever result shape ``aparse`` returned."""
    if hasattr(result, "get_markdown_documents"):
        markdown_docs = result.get_markdown_documents(split_by_page=False)
        if markdown_docs and hasattr(markdown_docs[0], "text"):
            return markdown_docs[0].text
        if hasattr(result, "pages") and result.pages:
            return "\n\n".join(
                p.md for p in result.pages if hasattr(p, "md") and p.md
            )
        return str(result)

    if isinstance(result, list):
        if result and hasattr(result[0], "text"):
            return result[0].text
        return "\n\n".join(
            doc.page_content if hasattr(doc, "page_content") else str(doc)
            for doc in result
        )

    return str(result)


def _llamacloud_retry_delay(attempt: int) -> float:
    """Return the capped exponential back-off for a 1-based attempt, with +/-25% jitter."""
    base_delay = min(
        LLAMACLOUD_BASE_DELAY * (2 ** (attempt - 1)),
        LLAMACLOUD_MAX_DELAY,
    )
    jitter = base_delay * 0.25 * (2 * random.random() - 1)
    return base_delay + jitter
|
||||
8
surfsense_backend/app/etl_pipeline/parsers/plaintext.py
Normal file
8
surfsense_backend/app/etl_pipeline/parsers/plaintext.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
def read_plaintext(file_path: str) -> str:
    """Read a text file as UTF-8 and return its content.

    A leading UTF-8 byte-order mark, as written by some editors, is dropped
    so it does not end up as an invisible first character of the content.
    Undecodable bytes are replaced with U+FFFD instead of failing the read.

    Args:
        file_path: Path of the file on local disk.

    Returns:
        The decoded file content.

    Raises:
        ValueError: If the decoded content contains a null character, which
            indicates a binary file was passed in as text.
    """
    # "utf-8-sig" decodes exactly like "utf-8" except that it strips a BOM.
    with open(file_path, encoding="utf-8-sig", errors="replace") as f:
        content = f.read()
    if "\x00" in content:
        raise ValueError(
            f"File contains null bytes — likely a binary file opened as text: {file_path}"
        )
    return content
|
||||
14
surfsense_backend/app/etl_pipeline/parsers/unstructured.py
Normal file
14
surfsense_backend/app/etl_pipeline/parsers/unstructured.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
async def parse_with_unstructured(file_path: str) -> str:
    """Load a document with Unstructured and join its non-empty elements.

    Elements are loaded in "elements" mode and their text is concatenated
    with blank lines in between. Elements with empty text are dropped.
    """
    from langchain_unstructured import UnstructuredLoader

    loader_options = {
        "mode": "elements",
        "post_processors": [],
        "languages": ["eng"],
        "include_orig_elements": False,
        "include_metadata": False,
        "strategy": "auto",
    }
    loader = UnstructuredLoader(file_path, **loader_options)

    elements = await loader.aload()
    text_chunks = [element.page_content for element in elements if element.page_content]
    return "\n\n".join(text_chunks)
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -31,8 +31,11 @@ async def vision_autocomplete_stream(
|
|||
|
||||
return StreamingResponse(
|
||||
stream_vision_autocomplete(
|
||||
body.screenshot, body.search_space_id, session,
|
||||
app_name=body.app_name, window_title=body.window_title,
|
||||
body.screenshot,
|
||||
body.search_space_id,
|
||||
session,
|
||||
app_name=body.app_name,
|
||||
window_title=body.window_title,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
|
|||
|
|
@ -311,9 +311,11 @@ async def dropbox_callback(
|
|||
)
|
||||
|
||||
existing_cursor = db_connector.config.get("cursor")
|
||||
existing_folder_cursors = db_connector.config.get("folder_cursors")
|
||||
db_connector.config = {
|
||||
**connector_config,
|
||||
"cursor": existing_cursor,
|
||||
"folder_cursors": existing_folder_cursors,
|
||||
"auth_expired": False,
|
||||
}
|
||||
flag_modified(db_connector, "config")
|
||||
|
|
|
|||
|
|
@ -2477,6 +2477,8 @@ async def run_google_drive_indexing(
|
|||
stage="fetching",
|
||||
)
|
||||
|
||||
total_unsupported = 0
|
||||
|
||||
# Index each folder with indexing options
|
||||
for folder in items.folders:
|
||||
try:
|
||||
|
|
@ -2484,6 +2486,7 @@ async def run_google_drive_indexing(
|
|||
indexed_count,
|
||||
skipped_count,
|
||||
error_message,
|
||||
unsupported_count,
|
||||
) = await index_google_drive_files(
|
||||
session,
|
||||
connector_id,
|
||||
|
|
@ -2497,6 +2500,7 @@ async def run_google_drive_indexing(
|
|||
include_subfolders=indexing_options.include_subfolders,
|
||||
)
|
||||
total_skipped += skipped_count
|
||||
total_unsupported += unsupported_count
|
||||
if error_message:
|
||||
errors.append(f"Folder '{folder.name}': {error_message}")
|
||||
else:
|
||||
|
|
@ -2572,6 +2576,7 @@ async def run_google_drive_indexing(
|
|||
indexed_count=total_indexed,
|
||||
error_message=error_message,
|
||||
skipped_count=total_skipped,
|
||||
unsupported_count=total_unsupported,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -2642,7 +2647,12 @@ async def run_onedrive_indexing(
|
|||
stage="fetching",
|
||||
)
|
||||
|
||||
total_indexed, total_skipped, error_message = await index_onedrive_files(
|
||||
(
|
||||
total_indexed,
|
||||
total_skipped,
|
||||
error_message,
|
||||
total_unsupported,
|
||||
) = await index_onedrive_files(
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
|
|
@ -2683,6 +2693,7 @@ async def run_onedrive_indexing(
|
|||
indexed_count=total_indexed,
|
||||
error_message=error_message,
|
||||
skipped_count=total_skipped,
|
||||
unsupported_count=total_unsupported,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -2750,7 +2761,12 @@ async def run_dropbox_indexing(
|
|||
stage="fetching",
|
||||
)
|
||||
|
||||
total_indexed, total_skipped, error_message = await index_dropbox_files(
|
||||
(
|
||||
total_indexed,
|
||||
total_skipped,
|
||||
error_message,
|
||||
total_unsupported,
|
||||
) = await index_dropbox_files(
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
|
|
@ -2791,6 +2807,7 @@ async def run_dropbox_indexing(
|
|||
indexed_count=total_indexed,
|
||||
error_message=error_message,
|
||||
skipped_count=total_skipped,
|
||||
unsupported_count=total_unsupported,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -111,9 +111,8 @@ class DoclingService:
|
|||
pipeline_options=pipeline_options, backend=PyPdfiumDocumentBackend
|
||||
)
|
||||
|
||||
# Initialize DocumentConverter
|
||||
self.converter = DocumentConverter(
|
||||
format_options={InputFormat.PDF: pdf_format_option}
|
||||
format_options={InputFormat.PDF: pdf_format_option},
|
||||
)
|
||||
|
||||
acceleration_type = "GPU (WSL2)" if self.use_gpu else "CPU"
|
||||
|
|
|
|||
|
|
@ -421,6 +421,7 @@ class ConnectorIndexingNotificationHandler(BaseNotificationHandler):
|
|||
error_message: str | None = None,
|
||||
is_warning: bool = False,
|
||||
skipped_count: int | None = None,
|
||||
unsupported_count: int | None = None,
|
||||
) -> Notification:
|
||||
"""
|
||||
Update notification when connector indexing completes.
|
||||
|
|
@ -428,10 +429,11 @@ class ConnectorIndexingNotificationHandler(BaseNotificationHandler):
|
|||
Args:
|
||||
session: Database session
|
||||
notification: Notification to update
|
||||
indexed_count: Total number of items indexed
|
||||
indexed_count: Total number of files indexed
|
||||
error_message: Error message if indexing failed, or warning message (optional)
|
||||
is_warning: If True, treat error_message as a warning (success case) rather than an error
|
||||
skipped_count: Number of items skipped (e.g., duplicates) - optional
|
||||
skipped_count: Number of files skipped (e.g., unchanged) - optional
|
||||
unsupported_count: Number of files skipped because the ETL parser doesn't support them
|
||||
|
||||
Returns:
|
||||
Updated notification
|
||||
|
|
@ -440,52 +442,45 @@ class ConnectorIndexingNotificationHandler(BaseNotificationHandler):
|
|||
"connector_name", "Connector"
|
||||
)
|
||||
|
||||
# Build the skipped text if there are skipped items
|
||||
skipped_text = ""
|
||||
if skipped_count and skipped_count > 0:
|
||||
skipped_item_text = "item" if skipped_count == 1 else "items"
|
||||
skipped_text = (
|
||||
f" ({skipped_count} {skipped_item_text} skipped - already indexed)"
|
||||
)
|
||||
unsupported_text = ""
|
||||
if unsupported_count and unsupported_count > 0:
|
||||
file_word = "file was" if unsupported_count == 1 else "files were"
|
||||
unsupported_text = f" {unsupported_count} {file_word} not supported."
|
||||
|
||||
# If there's an error message but items were indexed, treat it as a warning (partial success)
|
||||
# If is_warning is True, treat it as success even with 0 items (e.g., duplicates found)
|
||||
# Otherwise, treat it as a failure
|
||||
if error_message:
|
||||
if indexed_count > 0:
|
||||
# Partial success with warnings (e.g., duplicate content from other connectors)
|
||||
title = f"Ready: {connector_name}"
|
||||
item_text = "item" if indexed_count == 1 else "items"
|
||||
message = f"Now searchable! {indexed_count} {item_text} synced{skipped_text}. Note: {error_message}"
|
||||
file_text = "file" if indexed_count == 1 else "files"
|
||||
message = f"Now searchable! {indexed_count} {file_text} synced.{unsupported_text} Note: {error_message}"
|
||||
status = "completed"
|
||||
elif is_warning:
|
||||
# Warning case (e.g., duplicates found) - treat as success
|
||||
title = f"Ready: {connector_name}"
|
||||
message = f"Sync completed{skipped_text}. {error_message}"
|
||||
message = f"Sync complete.{unsupported_text} {error_message}"
|
||||
status = "completed"
|
||||
else:
|
||||
# Complete failure
|
||||
title = f"Failed: {connector_name}"
|
||||
message = f"Sync failed: {error_message}"
|
||||
if unsupported_text:
|
||||
message += unsupported_text
|
||||
status = "failed"
|
||||
else:
|
||||
title = f"Ready: {connector_name}"
|
||||
if indexed_count == 0:
|
||||
if skipped_count and skipped_count > 0:
|
||||
skipped_item_text = "item" if skipped_count == 1 else "items"
|
||||
message = f"Already up to date! {skipped_count} {skipped_item_text} skipped (already indexed)."
|
||||
if unsupported_count and unsupported_count > 0:
|
||||
message = f"Sync complete.{unsupported_text}"
|
||||
else:
|
||||
message = "Already up to date! No new items to sync."
|
||||
message = "Already up to date!"
|
||||
else:
|
||||
item_text = "item" if indexed_count == 1 else "items"
|
||||
message = (
|
||||
f"Now searchable! {indexed_count} {item_text} synced{skipped_text}."
|
||||
)
|
||||
file_text = "file" if indexed_count == 1 else "files"
|
||||
message = f"Now searchable! {indexed_count} {file_text} synced."
|
||||
if unsupported_text:
|
||||
message += unsupported_text
|
||||
status = "completed"
|
||||
|
||||
metadata_updates = {
|
||||
"indexed_count": indexed_count,
|
||||
"skipped_count": skipped_count or 0,
|
||||
"unsupported_count": unsupported_count or 0,
|
||||
"sync_stage": "completed"
|
||||
if (not error_message or is_warning or indexed_count > 0)
|
||||
else "failed",
|
||||
|
|
|
|||
|
|
@ -1,139 +1,40 @@
|
|||
import logging
|
||||
from typing import AsyncGenerator
|
||||
"""Vision autocomplete service — agent-based with scoped filesystem.
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
Optimized pipeline:
|
||||
1. Start the SSE stream immediately so the UI shows progress.
|
||||
2. Derive a KB search query from window_title (no separate LLM call).
|
||||
3. Run KB filesystem pre-computation and agent graph compilation in PARALLEL.
|
||||
4. Inject pre-computed KB files as initial state and stream the agent.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
from app.agents.autocomplete import create_autocomplete_agent, stream_autocomplete_agent
|
||||
from app.services.llm_service import get_vision_llm
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KB_TOP_K = 5
|
||||
KB_MAX_CHARS = 4000
|
||||
|
||||
EXTRACT_QUERY_PROMPT = """Look at this screenshot and describe in 1-2 short sentences what the user is working on and what topic they need to write about. Be specific about the subject matter. Output ONLY the description, nothing else."""
|
||||
|
||||
EXTRACT_QUERY_PROMPT_WITH_APP = """The user is currently in the application "{app_name}" with the window titled "{window_title}".
|
||||
|
||||
Look at this screenshot and describe in 1-2 short sentences what the user is working on and what topic they need to write about. Be specific about the subject matter. Output ONLY the description, nothing else."""
|
||||
|
||||
VISION_SYSTEM_PROMPT = """You are a smart writing assistant that analyzes the user's screen to draft or complete text.
|
||||
|
||||
You will receive a screenshot of the user's screen. Your job:
|
||||
1. Analyze the ENTIRE screenshot to understand what the user is working on (email thread, chat conversation, document, code editor, form, etc.).
|
||||
2. Identify the text area where the user will type.
|
||||
3. Based on the full visual context, generate the text the user most likely wants to write.
|
||||
|
||||
Key behavior:
|
||||
- If the text area is EMPTY, draft a full response or message based on what you see on screen (e.g., reply to an email, respond to a chat message, continue a document).
|
||||
- If the text area already has text, continue it naturally.
|
||||
|
||||
Rules:
|
||||
- Output ONLY the text to be inserted. No quotes, no explanations, no meta-commentary.
|
||||
- Be concise but complete — a full thought, not a fragment.
|
||||
- Match the tone and formality of the surrounding context.
|
||||
- If the screen shows code, write code. If it shows a casual chat, be casual. If it shows a formal email, be formal.
|
||||
- Do NOT describe the screenshot or explain your reasoning.
|
||||
- If you cannot determine what to write, output nothing."""
|
||||
|
||||
APP_CONTEXT_BLOCK = """
|
||||
|
||||
The user is currently working in "{app_name}" (window: "{window_title}"). Use this to understand the type of application and adapt your tone and format accordingly."""
|
||||
|
||||
KB_CONTEXT_BLOCK = """
|
||||
|
||||
You also have access to the user's knowledge base documents below. Use them to write more accurate, informed, and contextually relevant text. Do NOT cite or reference the documents explicitly — just let the knowledge inform your writing naturally.
|
||||
|
||||
<knowledge_base>
|
||||
{kb_context}
|
||||
</knowledge_base>"""
|
||||
PREP_STEP_ID = "autocomplete-prep"
|
||||
|
||||
|
||||
def _build_system_prompt(app_name: str, window_title: str, kb_context: str) -> str:
    """Compose the system prompt for the vision completion call.

    Starts from the base vision prompt and appends the application-context
    block when ``app_name`` is non-empty, then the knowledge-base block when
    ``kb_context`` is non-empty.
    """
    sections = [VISION_SYSTEM_PROMPT]
    if app_name:
        sections.append(
            APP_CONTEXT_BLOCK.format(app_name=app_name, window_title=window_title)
        )
    if kb_context:
        sections.append(KB_CONTEXT_BLOCK.format(kb_context=kb_context))
    return "".join(sections)
|
||||
def _derive_kb_query(app_name: str, window_title: str) -> str:
|
||||
parts = [p for p in (window_title, app_name) if p]
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def _is_vision_unsupported_error(e: Exception) -> bool:
|
||||
"""Check if an exception indicates the model doesn't support vision/images."""
|
||||
msg = str(e).lower()
|
||||
return "content must be a string" in msg or "does not support image" in msg
|
||||
|
||||
|
||||
async def _extract_query_from_screenshot(
    llm,
    screenshot_data_url: str,
    app_name: str = "",
    window_title: str = "",
) -> str | None:
    """Ask the vision LLM for a short description of what the user is working on.

    Returns the stripped description, or ``None`` when the model returns
    nothing or the call fails. Vision-unsupported errors are re-raised so the
    caller can report them immediately instead of retrying with ``astream``.
    """
    prompt_text = (
        EXTRACT_QUERY_PROMPT_WITH_APP.format(
            app_name=app_name,
            window_title=window_title,
        )
        if app_name
        else EXTRACT_QUERY_PROMPT
    )

    try:
        request = HumanMessage(
            content=[
                {"type": "text", "text": prompt_text},
                {"type": "image_url", "image_url": {"url": screenshot_data_url}},
            ]
        )
        response = await llm.ainvoke([request])
        query = response.content.strip() if hasattr(response, "content") else ""
        return query or None
    except Exception as e:
        if _is_vision_unsupported_error(e):
            raise
        # Any other failure is non-fatal: the caller proceeds without a query.
        logger.warning(f"Failed to extract query from screenshot: {e}")
        return None
|
||||
|
||||
|
||||
async def _search_knowledge_base(
    session: AsyncSession, search_space_id: int, query: str
) -> str:
    """Run a hybrid KB search and format the hits as a prompt context string.

    Each non-empty chunk becomes a ``[title]`` header followed by its content,
    and entries are joined with a ``---`` separator. Collection is bounded by
    ``KB_MAX_CHARS`` (separators are not counted). Any failure is logged and
    yields an empty string so the caller can continue without KB context.
    """
    try:
        hits = await ChucksHybridSearchRetriever(session).hybrid_search(
            query_text=query,
            top_k=KB_TOP_K,
            search_space_id=search_space_id,
        )
        if not hits:
            return ""

        entries: list[str] = []
        used_chars = 0
        for hit in hits:
            doc_title = hit.get("document", {}).get("title", "Untitled")
            for chunk in hit.get("chunks", []):
                text = chunk.get("content", "").strip()
                if text:
                    entry = f"[{doc_title}]\n{text}"
                    # Stop reading this document once the next entry would
                    # overflow the character budget.
                    if used_chars + len(entry) > KB_MAX_CHARS:
                        break
                    entries.append(entry)
                    used_chars += len(entry)
            # Stop scanning further documents only when the budget is fully used.
            if used_chars >= KB_MAX_CHARS:
                break

        return "\n\n---\n\n".join(entries)
    except Exception as e:
        logger.warning(f"KB search failed, proceeding without context: {e}")
        return ""
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def stream_vision_autocomplete(
|
||||
|
|
@ -144,13 +45,7 @@ async def stream_vision_autocomplete(
|
|||
app_name: str = "",
|
||||
window_title: str = "",
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Analyze a screenshot with the vision LLM and stream a text completion.
|
||||
|
||||
Pipeline:
|
||||
1. Extract a search query from the screenshot (non-streaming)
|
||||
2. Search the knowledge base for relevant context
|
||||
3. Stream the final completion with screenshot + KB + app context
|
||||
"""
|
||||
"""Analyze a screenshot with a vision-LLM agent and stream a text completion."""
|
||||
streaming = VercelStreamingService()
|
||||
vision_error_msg = (
|
||||
"The selected model does not support vision. "
|
||||
|
|
@ -164,62 +59,100 @@ async def stream_vision_autocomplete(
|
|||
yield streaming.format_done()
|
||||
return
|
||||
|
||||
kb_context = ""
|
||||
# Start SSE stream immediately so the UI has something to show
|
||||
yield streaming.format_message_start()
|
||||
|
||||
kb_query = _derive_kb_query(app_name, window_title)
|
||||
|
||||
# Show a preparation step while KB search + agent compile run
|
||||
yield streaming.format_thinking_step(
|
||||
step_id=PREP_STEP_ID,
|
||||
title="Searching knowledge base",
|
||||
status="in_progress",
|
||||
items=[kb_query] if kb_query else [],
|
||||
)
|
||||
|
||||
try:
|
||||
query = await _extract_query_from_screenshot(
|
||||
llm, screenshot_data_url, app_name=app_name, window_title=window_title,
|
||||
agent, kb = await create_autocomplete_agent(
|
||||
llm,
|
||||
search_space_id=search_space_id,
|
||||
kb_query=kb_query,
|
||||
app_name=app_name,
|
||||
window_title=window_title,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Vision autocomplete: selected model does not support vision: {e}")
|
||||
yield streaming.format_message_start()
|
||||
yield streaming.format_error(vision_error_msg)
|
||||
if _is_vision_unsupported_error(e):
|
||||
logger.warning("Vision autocomplete: model does not support vision: %s", e)
|
||||
yield streaming.format_error(vision_error_msg)
|
||||
yield streaming.format_done()
|
||||
return
|
||||
logger.error("Failed to create autocomplete agent: %s", e, exc_info=True)
|
||||
yield streaming.format_error("Autocomplete failed. Please try again.")
|
||||
yield streaming.format_done()
|
||||
return
|
||||
|
||||
if query:
|
||||
kb_context = await _search_knowledge_base(session, search_space_id, query)
|
||||
has_kb = kb.has_documents
|
||||
doc_count = len(kb.files) if has_kb else 0 # type: ignore[arg-type]
|
||||
|
||||
system_prompt = _build_system_prompt(app_name, window_title, kb_context)
|
||||
yield streaming.format_thinking_step(
|
||||
step_id=PREP_STEP_ID,
|
||||
title="Searching knowledge base",
|
||||
status="complete",
|
||||
items=[f"Found {doc_count} document{'s' if doc_count != 1 else ''}"]
|
||||
if kb_query
|
||||
else ["Skipped"],
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(content=[
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Analyze this screenshot. Understand the full context of what the user is working on, then generate the text they most likely want to write in the active text area.",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": screenshot_data_url},
|
||||
},
|
||||
]),
|
||||
]
|
||||
# Build agent input with pre-computed KB as initial state
|
||||
if has_kb:
|
||||
instruction = (
|
||||
"Analyze this screenshot, then explore the knowledge base documents "
|
||||
"listed above — read the chunk index of any document whose title "
|
||||
"looks relevant and check matched chunks for useful facts. "
|
||||
"Finally, generate a concise autocomplete for the active text area, "
|
||||
"enhanced with any relevant KB information you found."
|
||||
)
|
||||
else:
|
||||
instruction = (
|
||||
"Analyze this screenshot and generate a concise autocomplete "
|
||||
"for the active text area based on what you see."
|
||||
)
|
||||
|
||||
text_started = False
|
||||
text_id = ""
|
||||
user_message = HumanMessage(
|
||||
content=[
|
||||
{"type": "text", "text": instruction},
|
||||
{"type": "image_url", "image_url": {"url": screenshot_data_url}},
|
||||
]
|
||||
)
|
||||
|
||||
input_data: dict = {"messages": [user_message]}
|
||||
|
||||
if has_kb:
|
||||
input_data["files"] = kb.files
|
||||
input_data["messages"] = [kb.ls_ai_msg, kb.ls_tool_msg, user_message]
|
||||
logger.info(
|
||||
"Autocomplete: injected %d KB files into agent initial state", doc_count
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Autocomplete: no KB documents found, proceeding with screenshot only"
|
||||
)
|
||||
|
||||
# Stream the agent (message_start already sent above)
|
||||
try:
|
||||
yield streaming.format_message_start()
|
||||
text_id = streaming.generate_text_id()
|
||||
yield streaming.format_text_start(text_id)
|
||||
text_started = True
|
||||
|
||||
async for chunk in llm.astream(messages):
|
||||
token = chunk.content if hasattr(chunk, "content") else str(chunk)
|
||||
if token:
|
||||
yield streaming.format_text_delta(text_id, token)
|
||||
|
||||
yield streaming.format_text_end(text_id)
|
||||
yield streaming.format_finish()
|
||||
yield streaming.format_done()
|
||||
|
||||
async for sse in stream_autocomplete_agent(
|
||||
agent,
|
||||
input_data,
|
||||
streaming,
|
||||
emit_message_start=False,
|
||||
):
|
||||
yield sse
|
||||
except Exception as e:
|
||||
if text_started:
|
||||
yield streaming.format_text_end(text_id)
|
||||
|
||||
if _is_vision_unsupported_error(e):
|
||||
logger.warning(f"Vision autocomplete: selected model does not support vision: {e}")
|
||||
logger.warning("Vision autocomplete: model does not support vision: %s", e)
|
||||
yield streaming.format_error(vision_error_msg)
|
||||
yield streaming.format_done()
|
||||
else:
|
||||
logger.error(f"Vision autocomplete streaming error: {e}", exc_info=True)
|
||||
logger.error("Vision autocomplete streaming error: %s", e, exc_info=True)
|
||||
yield streaming.format_error("Autocomplete failed. Please try again.")
|
||||
yield streaming.format_done()
|
||||
yield streaming.format_done()
|
||||
|
|
|
|||
|
|
@ -51,7 +51,10 @@ async def _should_skip_file(
|
|||
file_id = file.get("id", "")
|
||||
file_name = file.get("name", "Unknown")
|
||||
|
||||
if skip_item(file):
|
||||
skip, unsup_ext = skip_item(file)
|
||||
if skip:
|
||||
if unsup_ext:
|
||||
return True, f"unsupported:{unsup_ext}"
|
||||
return True, "folder/non-downloadable"
|
||||
if not file_id:
|
||||
return True, "missing file_id"
|
||||
|
|
@ -251,6 +254,121 @@ async def _download_and_index(
|
|||
return batch_indexed, download_failed + batch_failed
|
||||
|
||||
|
||||
async def _remove_document(session: AsyncSession, file_id: str, search_space_id: int):
    """Delete the indexed document for a file that was removed in Dropbox.

    The document is looked up by its identifier hash first. When that finds
    nothing, a fallback query matches on the ``dropbox_file_id`` stored in the
    document metadata. Nothing happens if neither lookup finds a document.
    """
    identifier_hash = compute_identifier_hash(
        DocumentType.DROPBOX_FILE.value, file_id, search_space_id
    )
    document = await check_document_by_unique_identifier(session, identifier_hash)

    if not document:
        fallback_query = select(Document).where(
            Document.search_space_id == search_space_id,
            Document.document_type == DocumentType.DROPBOX_FILE,
            cast(Document.document_metadata["dropbox_file_id"], String) == file_id,
        )
        lookup = await session.execute(fallback_query)
        document = lookup.scalar_one_or_none()

    if not document:
        return
    await session.delete(document)
|
||||
|
||||
|
||||
async def _index_with_delta_sync(
    dropbox_client: DropboxClient,
    session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    cursor: str,
    task_logger: TaskLoggingService,
    log_entry: object,
    max_files: int,
    on_heartbeat_callback: HeartbeatCallbackType | None = None,
    enable_summary: bool = True,
) -> tuple[int, int, int, str]:
    """Delta sync using Dropbox cursor-based change tracking.

    Fetches the change entries recorded since ``cursor``, removes documents
    for deleted entries, and downloads/indexes the changed files that are not
    skipped.

    Args:
        dropbox_client: Client used to list changes and download files.
        session: Database session used for skip checks and deletions.
        connector_id: Forwarded to the download/index step.
        search_space_id: Search space the documents belong to.
        user_id: Forwarded to the download/index step.
        cursor: Cursor saved by the previous sync.
        task_logger: Receives the progress entry for this sync.
        log_entry: Task log entry the progress is attached to.
        max_files: Maximum number of change entries examined in this run.
        on_heartbeat_callback: Forwarded to the download/index step.
        enable_summary: Forwarded to the download/index step.

    Returns:
        ``(indexed_count, skipped_count, unsupported_count, new_cursor)``.
        ``indexed_count`` includes renamed files. ``new_cursor`` falls back
        to the input ``cursor`` when the client does not return one.

    Raises:
        Exception: If listing changes fails. Authentication failures carry a
            message asking the user to re-authenticate.
    """
    await task_logger.log_task_progress(
        log_entry,
        f"Starting delta sync from cursor: {cursor[:20]}...",
        {"stage": "delta_sync", "cursor_prefix": cursor[:20]},
    )

    entries, new_cursor, error = await dropbox_client.get_changes(cursor)
    if error:
        err_lower = error.lower()
        if "401" in error or "authentication expired" in err_lower:
            raise Exception(
                f"Dropbox authentication failed. Please re-authenticate. (Error: {error})"
            )
        raise Exception(f"Failed to fetch Dropbox changes: {error}")

    if not entries:
        logger.info("No changes detected since last sync")
        return 0, 0, 0, new_cursor or cursor

    logger.info(f"Processing {len(entries)} change entries")

    renamed_count = 0
    skipped = 0
    unsupported_count = 0
    files_to_download: list[dict] = []
    files_processed = 0

    for entry in entries:
        # The cap counts every change entry (deletions and non-file entries
        # included), not only downloadable files.
        # NOTE(review): entries past the cap are dropped while new_cursor is
        # still returned — confirm those changes are picked up by a later sync.
        if files_processed >= max_files:
            break
        files_processed += 1

        tag = entry.get(".tag")

        if tag == "deleted":
            path_lower = entry.get("path_lower", "")
            name = entry.get("name", "")
            file_id = entry.get("id", "")
            # NOTE(review): removal is keyed on the entry "id"; confirm that
            # get_changes supplies one for deleted entries, otherwise nothing
            # is removed here and only the debug line below is emitted.
            if file_id:
                await _remove_document(session, file_id, search_space_id)
            logger.debug(f"Processed deletion: {name or path_lower}")
            continue

        # Anything that is neither a deletion nor a file is ignored.
        if tag != "file":
            continue

        skip, msg = await _should_skip_file(session, entry, search_space_id)
        if skip:
            # The skip reason decides which counter the entry lands in.
            if msg and msg.startswith("unsupported:"):
                unsupported_count += 1
            elif msg and "renamed" in msg.lower():
                renamed_count += 1
            else:
                skipped += 1
            continue

        files_to_download.append(entry)

    batch_indexed, failed = await _download_and_index(
        dropbox_client,
        session,
        files_to_download,
        connector_id=connector_id,
        search_space_id=search_space_id,
        user_id=user_id,
        enable_summary=enable_summary,
        on_heartbeat=on_heartbeat_callback,
    )

    # Renamed files count as indexed.
    indexed = renamed_count + batch_indexed
    logger.info(
        f"Delta sync complete: {indexed} indexed, {skipped} skipped, "
        f"{unsupported_count} unsupported, {failed} failed"
    )
    return indexed, skipped, unsupported_count, new_cursor or cursor
|
||||
|
||||
|
||||
async def _index_full_scan(
|
||||
dropbox_client: DropboxClient,
|
||||
session: AsyncSession,
|
||||
|
|
@ -266,8 +384,11 @@ async def _index_full_scan(
|
|||
incremental_sync: bool = True,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[int, int]:
|
||||
"""Full scan indexing of a folder."""
|
||||
) -> tuple[int, int, int]:
|
||||
"""Full scan indexing of a folder.
|
||||
|
||||
Returns (indexed, skipped, unsupported_count).
|
||||
"""
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Starting full scan of folder: {folder_name}",
|
||||
|
|
@ -287,6 +408,7 @@ async def _index_full_scan(
|
|||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
files_to_download: list[dict] = []
|
||||
|
||||
all_files, error = await get_files_in_folder(
|
||||
|
|
@ -306,14 +428,21 @@ async def _index_full_scan(
|
|||
if incremental_sync:
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
elif skip_item(file):
|
||||
skipped += 1
|
||||
continue
|
||||
else:
|
||||
item_skip, item_unsup = skip_item(file)
|
||||
if item_skip:
|
||||
if item_unsup:
|
||||
unsupported_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
|
|
@ -352,9 +481,10 @@ async def _index_full_scan(
|
|||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, "
|
||||
f"{unsupported_count} unsupported, {failed} failed"
|
||||
)
|
||||
return indexed, skipped
|
||||
return indexed, skipped, unsupported_count
|
||||
|
||||
|
||||
async def _index_selected_files(
|
||||
|
|
@ -368,7 +498,7 @@ async def _index_selected_files(
|
|||
enable_summary: bool,
|
||||
incremental_sync: bool = True,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
) -> tuple[int, int, int, list[str]]:
|
||||
"""Index user-selected files using the parallel pipeline."""
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
|
|
@ -379,6 +509,7 @@ async def _index_selected_files(
|
|||
errors: list[str] = []
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
|
||||
for file_path, file_name in file_paths:
|
||||
file, error = await get_file_by_path(dropbox_client, file_path)
|
||||
|
|
@ -390,14 +521,21 @@ async def _index_selected_files(
|
|||
if incremental_sync:
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
elif skip_item(file):
|
||||
skipped += 1
|
||||
continue
|
||||
else:
|
||||
item_skip, item_unsup = skip_item(file)
|
||||
if item_skip:
|
||||
if item_unsup:
|
||||
unsupported_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
file_pages = PageLimitService.estimate_pages_from_metadata(
|
||||
file.get("name", ""), file.get("size")
|
||||
|
|
@ -429,7 +567,7 @@ async def _index_selected_files(
|
|||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, errors
|
||||
return renamed_count + batch_indexed, skipped, unsupported_count, errors
|
||||
|
||||
|
||||
async def index_dropbox_files(
|
||||
|
|
@ -438,7 +576,7 @@ async def index_dropbox_files(
|
|||
search_space_id: int,
|
||||
user_id: str,
|
||||
items_dict: dict,
|
||||
) -> tuple[int, int, str | None]:
|
||||
) -> tuple[int, int, str | None, int]:
|
||||
"""Index Dropbox files for a specific connector.
|
||||
|
||||
items_dict format:
|
||||
|
|
@ -469,7 +607,7 @@ async def index_dropbox_files(
|
|||
await task_logger.log_task_failure(
|
||||
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
return 0, 0, error_msg, 0
|
||||
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted and not config.SECRET_KEY:
|
||||
|
|
@ -480,7 +618,7 @@ async def index_dropbox_files(
|
|||
"Missing SECRET_KEY",
|
||||
{"error_type": "MissingSecretKey"},
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
return 0, 0, error_msg, 0
|
||||
|
||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||
dropbox_client = DropboxClient(session, connector_id)
|
||||
|
|
@ -489,9 +627,13 @@ async def index_dropbox_files(
|
|||
max_files = indexing_options.get("max_files", 500)
|
||||
incremental_sync = indexing_options.get("incremental_sync", True)
|
||||
include_subfolders = indexing_options.get("include_subfolders", True)
|
||||
use_delta_sync = indexing_options.get("use_delta_sync", True)
|
||||
|
||||
folder_cursors: dict = connector.config.get("folder_cursors", {})
|
||||
|
||||
total_indexed = 0
|
||||
total_skipped = 0
|
||||
total_unsupported = 0
|
||||
|
||||
selected_files = items_dict.get("files", [])
|
||||
if selected_files:
|
||||
|
|
@ -499,7 +641,7 @@ async def index_dropbox_files(
|
|||
(f.get("path", f.get("path_lower", f.get("id", ""))), f.get("name"))
|
||||
for f in selected_files
|
||||
]
|
||||
indexed, skipped, file_errors = await _index_selected_files(
|
||||
indexed, skipped, unsupported, file_errors = await _index_selected_files(
|
||||
dropbox_client,
|
||||
session,
|
||||
file_tuples,
|
||||
|
|
@ -511,6 +653,7 @@ async def index_dropbox_files(
|
|||
)
|
||||
total_indexed += indexed
|
||||
total_skipped += skipped
|
||||
total_unsupported += unsupported
|
||||
if file_errors:
|
||||
logger.warning(
|
||||
f"File indexing errors for connector {connector_id}: {file_errors}"
|
||||
|
|
@ -523,25 +666,66 @@ async def index_dropbox_files(
|
|||
)
|
||||
folder_name = folder.get("name", "Root")
|
||||
|
||||
logger.info(f"Using full scan for folder {folder_name}")
|
||||
indexed, skipped = await _index_full_scan(
|
||||
dropbox_client,
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
folder_path,
|
||||
folder_name,
|
||||
task_logger,
|
||||
log_entry,
|
||||
max_files,
|
||||
include_subfolders,
|
||||
incremental_sync=incremental_sync,
|
||||
enable_summary=connector_enable_summary,
|
||||
saved_cursor = folder_cursors.get(folder_path)
|
||||
can_use_delta = (
|
||||
use_delta_sync and saved_cursor and connector.last_indexed_at
|
||||
)
|
||||
|
||||
if can_use_delta:
|
||||
logger.info(f"Using delta sync for folder {folder_name}")
|
||||
indexed, skipped, unsup, new_cursor = await _index_with_delta_sync(
|
||||
dropbox_client,
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
saved_cursor,
|
||||
task_logger,
|
||||
log_entry,
|
||||
max_files,
|
||||
enable_summary=connector_enable_summary,
|
||||
)
|
||||
folder_cursors[folder_path] = new_cursor
|
||||
total_unsupported += unsup
|
||||
else:
|
||||
logger.info(f"Using full scan for folder {folder_name}")
|
||||
indexed, skipped, unsup = await _index_full_scan(
|
||||
dropbox_client,
|
||||
session,
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
folder_path,
|
||||
folder_name,
|
||||
task_logger,
|
||||
log_entry,
|
||||
max_files,
|
||||
include_subfolders,
|
||||
incremental_sync=incremental_sync,
|
||||
enable_summary=connector_enable_summary,
|
||||
)
|
||||
total_unsupported += unsup
|
||||
|
||||
total_indexed += indexed
|
||||
total_skipped += skipped
|
||||
|
||||
# Persist latest cursor for this folder
|
||||
try:
|
||||
latest_cursor, cursor_err = await dropbox_client.get_latest_cursor(
|
||||
folder_path
|
||||
)
|
||||
if latest_cursor and not cursor_err:
|
||||
folder_cursors[folder_path] = latest_cursor
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get latest cursor for {folder_path}: {e}")
|
||||
|
||||
# Persist folder cursors to connector config
|
||||
if folders:
|
||||
cfg = dict(connector.config)
|
||||
cfg["folder_cursors"] = folder_cursors
|
||||
connector.config = cfg
|
||||
flag_modified(connector, "config")
|
||||
|
||||
if total_indexed > 0 or folders:
|
||||
await update_connector_last_indexed(session, connector, True)
|
||||
|
||||
|
|
@ -550,12 +734,18 @@ async def index_dropbox_files(
|
|||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully completed Dropbox indexing for connector {connector_id}",
|
||||
{"files_processed": total_indexed, "files_skipped": total_skipped},
|
||||
{
|
||||
"files_processed": total_indexed,
|
||||
"files_skipped": total_skipped,
|
||||
"files_unsupported": total_unsupported,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"Dropbox indexing completed: {total_indexed} indexed, {total_skipped} skipped"
|
||||
f"Dropbox indexing completed: {total_indexed} indexed, "
|
||||
f"{total_skipped} skipped, {total_unsupported} unsupported"
|
||||
)
|
||||
return total_indexed, total_skipped, None
|
||||
|
||||
return total_indexed, total_skipped, None, total_unsupported
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
@ -566,7 +756,7 @@ async def index_dropbox_files(
|
|||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||
return 0, 0, f"Database error: {db_error!s}"
|
||||
return 0, 0, f"Database error: {db_error!s}", 0
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
|
|
@ -576,4 +766,4 @@ async def index_dropbox_files(
|
|||
{"error_type": type(e).__name__},
|
||||
)
|
||||
logger.error(f"Failed to index Dropbox files: {e!s}", exc_info=True)
|
||||
return 0, 0, f"Failed to index Dropbox files: {e!s}"
|
||||
return 0, 0, f"Failed to index Dropbox files: {e!s}", 0
|
||||
|
|
|
|||
|
|
@ -25,7 +25,11 @@ from app.connectors.google_drive import (
|
|||
get_files_in_folder,
|
||||
get_start_page_token,
|
||||
)
|
||||
from app.connectors.google_drive.file_types import should_skip_file as skip_mime
|
||||
from app.connectors.google_drive.file_types import (
|
||||
is_google_workspace_file,
|
||||
should_skip_by_extension,
|
||||
should_skip_file as skip_mime,
|
||||
)
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
|
|
@ -78,6 +82,10 @@ async def _should_skip_file(
|
|||
|
||||
if skip_mime(mime_type):
|
||||
return True, "folder/shortcut"
|
||||
if not is_google_workspace_file(mime_type):
|
||||
ext_skip, unsup_ext = should_skip_by_extension(file_name)
|
||||
if ext_skip:
|
||||
return True, f"unsupported:{unsup_ext}"
|
||||
if not file_id:
|
||||
return True, "missing file_id"
|
||||
|
||||
|
|
@ -468,13 +476,13 @@ async def _index_selected_files(
|
|||
user_id: str,
|
||||
enable_summary: bool,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
) -> tuple[int, int, int, list[str]]:
|
||||
"""Index user-selected files using the parallel pipeline.
|
||||
|
||||
Phase 1 (serial): fetch metadata + skip checks.
|
||||
Phase 2+3 (parallel): download, ETL, index via _download_and_index.
|
||||
|
||||
Returns (indexed_count, skipped_count, errors).
|
||||
Returns (indexed_count, skipped_count, unsupported_count, errors).
|
||||
"""
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
|
|
@ -485,6 +493,7 @@ async def _index_selected_files(
|
|||
errors: list[str] = []
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
|
||||
for file_id, file_name in file_ids:
|
||||
file, error = await get_file_by_id(drive_client, file_id)
|
||||
|
|
@ -495,7 +504,9 @@ async def _index_selected_files(
|
|||
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
|
@ -539,7 +550,7 @@ async def _index_selected_files(
|
|||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, errors
|
||||
return renamed_count + batch_indexed, skipped, unsupported_count, errors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -562,8 +573,11 @@ async def _index_full_scan(
|
|||
include_subfolders: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[int, int]:
|
||||
"""Full scan indexing of a folder."""
|
||||
) -> tuple[int, int, int]:
|
||||
"""Full scan indexing of a folder.
|
||||
|
||||
Returns (indexed, skipped, unsupported_count).
|
||||
"""
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Starting full scan of folder: {folder_name} (include_subfolders={include_subfolders})",
|
||||
|
|
@ -585,6 +599,7 @@ async def _index_full_scan(
|
|||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
files_processed = 0
|
||||
files_to_download: list[dict] = []
|
||||
folders_to_process = [(folder_id, folder_name)]
|
||||
|
|
@ -625,7 +640,9 @@ async def _index_full_scan(
|
|||
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
|
@ -698,9 +715,10 @@ async def _index_full_scan(
|
|||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, "
|
||||
f"{unsupported_count} unsupported, {failed} failed"
|
||||
)
|
||||
return indexed, skipped
|
||||
return indexed, skipped, unsupported_count
|
||||
|
||||
|
||||
async def _index_with_delta_sync(
|
||||
|
|
@ -718,8 +736,11 @@ async def _index_with_delta_sync(
|
|||
include_subfolders: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[int, int]:
|
||||
"""Delta sync using change tracking."""
|
||||
) -> tuple[int, int, int]:
|
||||
"""Delta sync using change tracking.
|
||||
|
||||
Returns (indexed, skipped, unsupported_count).
|
||||
"""
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Starting delta sync from token: {start_page_token[:20]}...",
|
||||
|
|
@ -739,7 +760,7 @@ async def _index_with_delta_sync(
|
|||
|
||||
if not changes:
|
||||
logger.info("No changes detected since last sync")
|
||||
return 0, 0
|
||||
return 0, 0, 0
|
||||
|
||||
logger.info(f"Processing {len(changes)} changes")
|
||||
|
||||
|
|
@ -754,6 +775,7 @@ async def _index_with_delta_sync(
|
|||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
files_to_download: list[dict] = []
|
||||
files_processed = 0
|
||||
|
||||
|
|
@ -775,7 +797,9 @@ async def _index_with_delta_sync(
|
|||
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
|
@ -832,9 +856,10 @@ async def _index_with_delta_sync(
|
|||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
f"Delta sync complete: {indexed} indexed, {skipped} skipped, "
|
||||
f"{unsupported_count} unsupported, {failed} failed"
|
||||
)
|
||||
return indexed, skipped
|
||||
return indexed, skipped, unsupported_count
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -854,8 +879,11 @@ async def index_google_drive_files(
|
|||
max_files: int = 500,
|
||||
include_subfolders: bool = False,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, str | None]:
|
||||
"""Index Google Drive files for a specific connector."""
|
||||
) -> tuple[int, int, str | None, int]:
|
||||
"""Index Google Drive files for a specific connector.
|
||||
|
||||
Returns (indexed, skipped, error_or_none, unsupported_count).
|
||||
"""
|
||||
task_logger = TaskLoggingService(session, search_space_id)
|
||||
log_entry = await task_logger.log_task_start(
|
||||
task_name="google_drive_files_indexing",
|
||||
|
|
@ -881,7 +909,7 @@ async def index_google_drive_files(
|
|||
await task_logger.log_task_failure(
|
||||
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
return 0, 0, error_msg, 0
|
||||
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
|
|
@ -900,7 +928,7 @@ async def index_google_drive_files(
|
|||
"Missing Composio account",
|
||||
{"error_type": "MissingComposioAccount"},
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
return 0, 0, error_msg, 0
|
||||
pre_built_credentials = build_composio_credentials(connected_account_id)
|
||||
else:
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
|
|
@ -915,6 +943,7 @@ async def index_google_drive_files(
|
|||
0,
|
||||
0,
|
||||
"SECRET_KEY not configured but credentials are marked as encrypted",
|
||||
0,
|
||||
)
|
||||
|
||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||
|
|
@ -927,7 +956,7 @@ async def index_google_drive_files(
|
|||
await task_logger.log_task_failure(
|
||||
log_entry, error_msg, {"error_type": "MissingParameter"}
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
return 0, 0, error_msg, 0
|
||||
|
||||
target_folder_id = folder_id
|
||||
target_folder_name = folder_name or "Selected Folder"
|
||||
|
|
@ -938,9 +967,11 @@ async def index_google_drive_files(
|
|||
use_delta_sync and start_page_token and connector.last_indexed_at
|
||||
)
|
||||
|
||||
documents_unsupported = 0
|
||||
|
||||
if can_use_delta:
|
||||
logger.info(f"Using delta sync for connector {connector_id}")
|
||||
documents_indexed, documents_skipped = await _index_with_delta_sync(
|
||||
documents_indexed, documents_skipped, du = await _index_with_delta_sync(
|
||||
drive_client,
|
||||
session,
|
||||
connector,
|
||||
|
|
@ -956,8 +987,9 @@ async def index_google_drive_files(
|
|||
on_heartbeat_callback,
|
||||
connector_enable_summary,
|
||||
)
|
||||
documents_unsupported += du
|
||||
logger.info("Running reconciliation scan after delta sync")
|
||||
ri, rs = await _index_full_scan(
|
||||
ri, rs, ru = await _index_full_scan(
|
||||
drive_client,
|
||||
session,
|
||||
connector,
|
||||
|
|
@ -975,9 +1007,14 @@ async def index_google_drive_files(
|
|||
)
|
||||
documents_indexed += ri
|
||||
documents_skipped += rs
|
||||
documents_unsupported += ru
|
||||
else:
|
||||
logger.info(f"Using full scan for connector {connector_id}")
|
||||
documents_indexed, documents_skipped = await _index_full_scan(
|
||||
(
|
||||
documents_indexed,
|
||||
documents_skipped,
|
||||
documents_unsupported,
|
||||
) = await _index_full_scan(
|
||||
drive_client,
|
||||
session,
|
||||
connector,
|
||||
|
|
@ -1012,14 +1049,17 @@ async def index_google_drive_files(
|
|||
{
|
||||
"files_processed": documents_indexed,
|
||||
"files_skipped": documents_skipped,
|
||||
"files_unsupported": documents_unsupported,
|
||||
"sync_type": "delta" if can_use_delta else "full",
|
||||
"folder": target_folder_name,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"Google Drive indexing completed: {documents_indexed} indexed, {documents_skipped} skipped"
|
||||
f"Google Drive indexing completed: {documents_indexed} indexed, "
|
||||
f"{documents_skipped} skipped, {documents_unsupported} unsupported"
|
||||
)
|
||||
return documents_indexed, documents_skipped, None
|
||||
|
||||
return documents_indexed, documents_skipped, None, documents_unsupported
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
@ -1030,7 +1070,7 @@ async def index_google_drive_files(
|
|||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||
return 0, 0, f"Database error: {db_error!s}"
|
||||
return 0, 0, f"Database error: {db_error!s}", 0
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
|
|
@ -1040,7 +1080,7 @@ async def index_google_drive_files(
|
|||
{"error_type": type(e).__name__},
|
||||
)
|
||||
logger.error(f"Failed to index Google Drive files: {e!s}", exc_info=True)
|
||||
return 0, 0, f"Failed to index Google Drive files: {e!s}"
|
||||
return 0, 0, f"Failed to index Google Drive files: {e!s}", 0
|
||||
|
||||
|
||||
async def index_google_drive_single_file(
|
||||
|
|
@ -1242,7 +1282,7 @@ async def index_google_drive_selected_files(
|
|||
session, connector_id, credentials=pre_built_credentials
|
||||
)
|
||||
|
||||
indexed, skipped, errors = await _index_selected_files(
|
||||
indexed, skipped, unsupported, errors = await _index_selected_files(
|
||||
drive_client,
|
||||
session,
|
||||
files,
|
||||
|
|
@ -1253,6 +1293,11 @@ async def index_google_drive_selected_files(
|
|||
on_heartbeat=on_heartbeat_callback,
|
||||
)
|
||||
|
||||
if unsupported > 0:
|
||||
file_text = "file was" if unsupported == 1 else "files were"
|
||||
unsup_msg = f"{unsupported} {file_text} not supported"
|
||||
errors.append(unsup_msg)
|
||||
|
||||
await session.commit()
|
||||
|
||||
if errors:
|
||||
|
|
@ -1260,7 +1305,12 @@ async def index_google_drive_selected_files(
|
|||
log_entry,
|
||||
f"Batch file indexing completed with {len(errors)} error(s)",
|
||||
"; ".join(errors),
|
||||
{"indexed": indexed, "skipped": skipped, "error_count": len(errors)},
|
||||
{
|
||||
"indexed": indexed,
|
||||
"skipped": skipped,
|
||||
"unsupported": unsupported,
|
||||
"error_count": len(errors),
|
||||
},
|
||||
)
|
||||
else:
|
||||
await task_logger.log_task_success(
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ from sqlalchemy import select
|
|||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentStatus,
|
||||
|
|
@ -44,132 +43,6 @@ from .base import (
|
|||
logger,
|
||||
)
|
||||
|
||||
PLAINTEXT_EXTENSIONS = frozenset(
|
||||
{
|
||||
".md",
|
||||
".markdown",
|
||||
".txt",
|
||||
".text",
|
||||
".json",
|
||||
".jsonl",
|
||||
".yaml",
|
||||
".yml",
|
||||
".toml",
|
||||
".ini",
|
||||
".cfg",
|
||||
".conf",
|
||||
".xml",
|
||||
".css",
|
||||
".scss",
|
||||
".less",
|
||||
".sass",
|
||||
".py",
|
||||
".pyw",
|
||||
".pyi",
|
||||
".pyx",
|
||||
".js",
|
||||
".jsx",
|
||||
".ts",
|
||||
".tsx",
|
||||
".mjs",
|
||||
".cjs",
|
||||
".java",
|
||||
".kt",
|
||||
".kts",
|
||||
".scala",
|
||||
".groovy",
|
||||
".c",
|
||||
".h",
|
||||
".cpp",
|
||||
".cxx",
|
||||
".cc",
|
||||
".hpp",
|
||||
".hxx",
|
||||
".cs",
|
||||
".fs",
|
||||
".fsx",
|
||||
".go",
|
||||
".rs",
|
||||
".rb",
|
||||
".php",
|
||||
".pl",
|
||||
".pm",
|
||||
".lua",
|
||||
".swift",
|
||||
".m",
|
||||
".mm",
|
||||
".r",
|
||||
".R",
|
||||
".jl",
|
||||
".sh",
|
||||
".bash",
|
||||
".zsh",
|
||||
".fish",
|
||||
".bat",
|
||||
".cmd",
|
||||
".ps1",
|
||||
".sql",
|
||||
".graphql",
|
||||
".gql",
|
||||
".env",
|
||||
".gitignore",
|
||||
".dockerignore",
|
||||
".editorconfig",
|
||||
".makefile",
|
||||
".cmake",
|
||||
".log",
|
||||
".rst",
|
||||
".tex",
|
||||
".bib",
|
||||
".org",
|
||||
".adoc",
|
||||
".asciidoc",
|
||||
".vue",
|
||||
".svelte",
|
||||
".astro",
|
||||
".tf",
|
||||
".hcl",
|
||||
".proto",
|
||||
}
|
||||
)
|
||||
|
||||
AUDIO_EXTENSIONS = frozenset(
|
||||
{
|
||||
".mp3",
|
||||
".mp4",
|
||||
".mpeg",
|
||||
".mpga",
|
||||
".m4a",
|
||||
".wav",
|
||||
".webm",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
DIRECT_CONVERT_EXTENSIONS = frozenset({".csv", ".tsv", ".html", ".htm"})
|
||||
|
||||
|
||||
def _is_plaintext_file(filename: str) -> bool:
|
||||
return Path(filename).suffix.lower() in PLAINTEXT_EXTENSIONS
|
||||
|
||||
|
||||
def _is_audio_file(filename: str) -> bool:
|
||||
return Path(filename).suffix.lower() in AUDIO_EXTENSIONS
|
||||
|
||||
|
||||
def _is_direct_convert_file(filename: str) -> bool:
|
||||
return Path(filename).suffix.lower() in DIRECT_CONVERT_EXTENSIONS
|
||||
|
||||
|
||||
def _needs_etl(filename: str) -> bool:
|
||||
"""File is not plaintext, not audio, and not direct-convert — requires ETL."""
|
||||
return (
|
||||
not _is_plaintext_file(filename)
|
||||
and not _is_audio_file(filename)
|
||||
and not _is_direct_convert_file(filename)
|
||||
)
|
||||
|
||||
|
||||
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
|
||||
|
||||
|
||||
|
|
@ -279,57 +152,19 @@ def scan_folder(
|
|||
return files
|
||||
|
||||
|
||||
def _read_plaintext_file(file_path: str) -> str:
|
||||
"""Read a plaintext/text-based file as UTF-8."""
|
||||
with open(file_path, encoding="utf-8", errors="replace") as f:
|
||||
content = f.read()
|
||||
if "\x00" in content:
|
||||
raise ValueError(
|
||||
f"File contains null bytes — likely a binary file opened as text: {file_path}"
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
async def _read_file_content(file_path: str, filename: str) -> str:
|
||||
"""Read file content, using ETL for binary formats.
|
||||
"""Read file content via the unified ETL pipeline.
|
||||
|
||||
Plaintext files are read directly. Audio and document files (PDF, DOCX, etc.)
|
||||
are routed through the configured ETL service (same as Google Drive / OneDrive).
|
||||
|
||||
Raises ValueError if the file cannot be parsed (e.g. no ETL service configured
|
||||
for a binary file).
|
||||
All file types (plaintext, audio, direct-convert, document) are handled
|
||||
by ``EtlPipelineService``.
|
||||
"""
|
||||
if _is_plaintext_file(filename):
|
||||
return _read_plaintext_file(file_path)
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
if _is_direct_convert_file(filename):
|
||||
from app.tasks.document_processors._direct_converters import (
|
||||
convert_file_directly,
|
||||
)
|
||||
|
||||
return convert_file_directly(file_path, filename)
|
||||
|
||||
if _is_audio_file(filename):
|
||||
etl_service = config.ETL_SERVICE if hasattr(config, "ETL_SERVICE") else None
|
||||
stt_service_val = config.STT_SERVICE if hasattr(config, "STT_SERVICE") else None
|
||||
if not stt_service_val and not etl_service:
|
||||
raise ValueError(
|
||||
f"No STT_SERVICE configured — cannot transcribe audio file: {filename}"
|
||||
)
|
||||
|
||||
if _needs_etl(filename):
|
||||
etl_service = getattr(config, "ETL_SERVICE", None)
|
||||
if not etl_service:
|
||||
raise ValueError(
|
||||
f"No ETL_SERVICE configured — cannot parse binary file: {filename}. "
|
||||
f"Set ETL_SERVICE to UNSTRUCTURED, LLAMACLOUD, or DOCLING in your .env"
|
||||
)
|
||||
|
||||
from app.connectors.onedrive.content_extractor import (
|
||||
_parse_file_to_markdown,
|
||||
result = await EtlPipelineService().extract(
|
||||
EtlRequest(file_path=file_path, filename=filename)
|
||||
)
|
||||
|
||||
return await _parse_file_to_markdown(file_path, filename)
|
||||
return result.markdown_content
|
||||
|
||||
|
||||
def _content_hash(content: str, search_space_id: int) -> str:
|
||||
|
|
|
|||
|
|
@ -56,7 +56,10 @@ async def _should_skip_file(
|
|||
file_id = file.get("id")
|
||||
file_name = file.get("name", "Unknown")
|
||||
|
||||
if skip_item(file):
|
||||
skip, unsup_ext = skip_item(file)
|
||||
if skip:
|
||||
if unsup_ext:
|
||||
return True, f"unsupported:{unsup_ext}"
|
||||
return True, "folder/onenote/remote"
|
||||
if not file_id:
|
||||
return True, "missing file_id"
|
||||
|
|
@ -290,7 +293,7 @@ async def _index_selected_files(
|
|||
user_id: str,
|
||||
enable_summary: bool,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[int, int, list[str]]:
|
||||
) -> tuple[int, int, int, list[str]]:
|
||||
"""Index user-selected files using the parallel pipeline."""
|
||||
page_limit_service = PageLimitService(session)
|
||||
pages_used, pages_limit = await page_limit_service.get_page_usage(user_id)
|
||||
|
|
@ -301,6 +304,7 @@ async def _index_selected_files(
|
|||
errors: list[str] = []
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
|
||||
for file_id, file_name in file_ids:
|
||||
file, error = await get_file_by_id(onedrive_client, file_id)
|
||||
|
|
@ -311,7 +315,9 @@ async def _index_selected_files(
|
|||
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
|
@ -347,7 +353,7 @@ async def _index_selected_files(
|
|||
user_id, pages_to_deduct, allow_exceed=True
|
||||
)
|
||||
|
||||
return renamed_count + batch_indexed, skipped, errors
|
||||
return renamed_count + batch_indexed, skipped, unsupported_count, errors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -369,8 +375,11 @@ async def _index_full_scan(
|
|||
include_subfolders: bool = True,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[int, int]:
|
||||
"""Full scan indexing of a folder."""
|
||||
) -> tuple[int, int, int]:
|
||||
"""Full scan indexing of a folder.
|
||||
|
||||
Returns (indexed, skipped, unsupported_count).
|
||||
"""
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Starting full scan of folder: {folder_name}",
|
||||
|
|
@ -389,6 +398,7 @@ async def _index_full_scan(
|
|||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
files_to_download: list[dict] = []
|
||||
|
||||
all_files, error = await get_files_in_folder(
|
||||
|
|
@ -407,7 +417,9 @@ async def _index_full_scan(
|
|||
for file in all_files[:max_files]:
|
||||
skip, msg = await _should_skip_file(session, file, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
|
@ -450,9 +462,10 @@ async def _index_full_scan(
|
|||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
f"Full scan complete: {indexed} indexed, {skipped} skipped, "
|
||||
f"{unsupported_count} unsupported, {failed} failed"
|
||||
)
|
||||
return indexed, skipped
|
||||
return indexed, skipped, unsupported_count
|
||||
|
||||
|
||||
async def _index_with_delta_sync(
|
||||
|
|
@ -468,8 +481,11 @@ async def _index_with_delta_sync(
|
|||
max_files: int,
|
||||
on_heartbeat_callback: HeartbeatCallbackType | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> tuple[int, int, str | None]:
|
||||
"""Delta sync using OneDrive change tracking. Returns (indexed, skipped, new_delta_link)."""
|
||||
) -> tuple[int, int, int, str | None]:
|
||||
"""Delta sync using OneDrive change tracking.
|
||||
|
||||
Returns (indexed, skipped, unsupported_count, new_delta_link).
|
||||
"""
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
"Starting delta sync",
|
||||
|
|
@ -489,7 +505,7 @@ async def _index_with_delta_sync(
|
|||
|
||||
if not changes:
|
||||
logger.info("No changes detected since last sync")
|
||||
return 0, 0, new_delta_link
|
||||
return 0, 0, 0, new_delta_link
|
||||
|
||||
logger.info(f"Processing {len(changes)} delta changes")
|
||||
|
||||
|
|
@ -501,6 +517,7 @@ async def _index_with_delta_sync(
|
|||
|
||||
renamed_count = 0
|
||||
skipped = 0
|
||||
unsupported_count = 0
|
||||
files_to_download: list[dict] = []
|
||||
files_processed = 0
|
||||
|
||||
|
|
@ -523,7 +540,9 @@ async def _index_with_delta_sync(
|
|||
|
||||
skip, msg = await _should_skip_file(session, change, search_space_id)
|
||||
if skip:
|
||||
if msg and "renamed" in msg.lower():
|
||||
if msg and msg.startswith("unsupported:"):
|
||||
unsupported_count += 1
|
||||
elif msg and "renamed" in msg.lower():
|
||||
renamed_count += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
|
@ -566,9 +585,10 @@ async def _index_with_delta_sync(
|
|||
|
||||
indexed = renamed_count + batch_indexed
|
||||
logger.info(
|
||||
f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||
f"Delta sync complete: {indexed} indexed, {skipped} skipped, "
|
||||
f"{unsupported_count} unsupported, {failed} failed"
|
||||
)
|
||||
return indexed, skipped, new_delta_link
|
||||
return indexed, skipped, unsupported_count, new_delta_link
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -582,7 +602,7 @@ async def index_onedrive_files(
|
|||
search_space_id: int,
|
||||
user_id: str,
|
||||
items_dict: dict,
|
||||
) -> tuple[int, int, str | None]:
|
||||
) -> tuple[int, int, str | None, int]:
|
||||
"""Index OneDrive files for a specific connector.
|
||||
|
||||
items_dict format:
|
||||
|
|
@ -609,7 +629,7 @@ async def index_onedrive_files(
|
|||
await task_logger.log_task_failure(
|
||||
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
return 0, 0, error_msg, 0
|
||||
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted and not config.SECRET_KEY:
|
||||
|
|
@ -620,7 +640,7 @@ async def index_onedrive_files(
|
|||
"Missing SECRET_KEY",
|
||||
{"error_type": "MissingSecretKey"},
|
||||
)
|
||||
return 0, 0, error_msg
|
||||
return 0, 0, error_msg, 0
|
||||
|
||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||
onedrive_client = OneDriveClient(session, connector_id)
|
||||
|
|
@ -632,12 +652,13 @@ async def index_onedrive_files(
|
|||
|
||||
total_indexed = 0
|
||||
total_skipped = 0
|
||||
total_unsupported = 0
|
||||
|
||||
# Index selected individual files
|
||||
selected_files = items_dict.get("files", [])
|
||||
if selected_files:
|
||||
file_tuples = [(f["id"], f.get("name")) for f in selected_files]
|
||||
indexed, skipped, _errors = await _index_selected_files(
|
||||
indexed, skipped, unsupported, _errors = await _index_selected_files(
|
||||
onedrive_client,
|
||||
session,
|
||||
file_tuples,
|
||||
|
|
@ -648,6 +669,7 @@ async def index_onedrive_files(
|
|||
)
|
||||
total_indexed += indexed
|
||||
total_skipped += skipped
|
||||
total_unsupported += unsupported
|
||||
|
||||
# Index selected folders
|
||||
folders = items_dict.get("folders", [])
|
||||
|
|
@ -661,7 +683,7 @@ async def index_onedrive_files(
|
|||
|
||||
if can_use_delta:
|
||||
logger.info(f"Using delta sync for folder {folder_name}")
|
||||
indexed, skipped, new_delta_link = await _index_with_delta_sync(
|
||||
indexed, skipped, unsup, new_delta_link = await _index_with_delta_sync(
|
||||
onedrive_client,
|
||||
session,
|
||||
connector_id,
|
||||
|
|
@ -676,6 +698,7 @@ async def index_onedrive_files(
|
|||
)
|
||||
total_indexed += indexed
|
||||
total_skipped += skipped
|
||||
total_unsupported += unsup
|
||||
|
||||
if new_delta_link:
|
||||
await session.refresh(connector)
|
||||
|
|
@ -685,7 +708,7 @@ async def index_onedrive_files(
|
|||
flag_modified(connector, "config")
|
||||
|
||||
# Reconciliation full scan
|
||||
ri, rs = await _index_full_scan(
|
||||
ri, rs, ru = await _index_full_scan(
|
||||
onedrive_client,
|
||||
session,
|
||||
connector_id,
|
||||
|
|
@ -701,9 +724,10 @@ async def index_onedrive_files(
|
|||
)
|
||||
total_indexed += ri
|
||||
total_skipped += rs
|
||||
total_unsupported += ru
|
||||
else:
|
||||
logger.info(f"Using full scan for folder {folder_name}")
|
||||
indexed, skipped = await _index_full_scan(
|
||||
indexed, skipped, unsup = await _index_full_scan(
|
||||
onedrive_client,
|
||||
session,
|
||||
connector_id,
|
||||
|
|
@ -719,6 +743,7 @@ async def index_onedrive_files(
|
|||
)
|
||||
total_indexed += indexed
|
||||
total_skipped += skipped
|
||||
total_unsupported += unsup
|
||||
|
||||
# Store new delta link for this folder
|
||||
_, new_delta_link, _ = await onedrive_client.get_delta(folder_id=folder_id)
|
||||
|
|
@ -737,12 +762,18 @@ async def index_onedrive_files(
|
|||
await task_logger.log_task_success(
|
||||
log_entry,
|
||||
f"Successfully completed OneDrive indexing for connector {connector_id}",
|
||||
{"files_processed": total_indexed, "files_skipped": total_skipped},
|
||||
{
|
||||
"files_processed": total_indexed,
|
||||
"files_skipped": total_skipped,
|
||||
"files_unsupported": total_unsupported,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"OneDrive indexing completed: {total_indexed} indexed, {total_skipped} skipped"
|
||||
f"OneDrive indexing completed: {total_indexed} indexed, "
|
||||
f"{total_skipped} skipped, {total_unsupported} unsupported"
|
||||
)
|
||||
return total_indexed, total_skipped, None
|
||||
|
||||
return total_indexed, total_skipped, None, total_unsupported
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
|
|
@ -753,7 +784,7 @@ async def index_onedrive_files(
|
|||
{"error_type": "SQLAlchemyError"},
|
||||
)
|
||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||
return 0, 0, f"Database error: {db_error!s}"
|
||||
return 0, 0, f"Database error: {db_error!s}", 0
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await task_logger.log_task_failure(
|
||||
|
|
@ -763,4 +794,4 @@ async def index_onedrive_files(
|
|||
{"error_type": type(e).__name__},
|
||||
)
|
||||
logger.error(f"Failed to index OneDrive files: {e!s}", exc_info=True)
|
||||
return 0, 0, f"Failed to index OneDrive files: {e!s}"
|
||||
return 0, 0, f"Failed to index OneDrive files: {e!s}", 0
|
||||
|
|
|
|||
|
|
@ -1,41 +1,17 @@
|
|||
"""
|
||||
Document processors module for background tasks.
|
||||
|
||||
This module provides a collection of document processors for different content types
|
||||
and sources. Each processor is responsible for handling a specific type of document
|
||||
processing task in the background.
|
||||
|
||||
Available processors:
|
||||
- Extension processor: Handle documents from browser extension
|
||||
- Markdown processor: Process markdown files
|
||||
- File processors: Handle files using different ETL services (Unstructured, LlamaCloud, Docling)
|
||||
- YouTube processor: Process YouTube videos and extract transcripts
|
||||
Content extraction is handled by ``app.etl_pipeline.EtlPipelineService``.
|
||||
This package keeps orchestration (save, notify, page-limit) and
|
||||
non-ETL processors (extension, markdown, youtube).
|
||||
"""
|
||||
|
||||
# Extension processor
|
||||
# File processors (backward-compatible re-exports from _save)
|
||||
from ._save import (
|
||||
add_received_file_document_using_docling,
|
||||
add_received_file_document_using_llamacloud,
|
||||
add_received_file_document_using_unstructured,
|
||||
)
|
||||
from .extension_processor import add_extension_received_document
|
||||
|
||||
# Markdown processor
|
||||
from .markdown_processor import add_received_markdown_file_document
|
||||
|
||||
# YouTube processor
|
||||
from .youtube_processor import add_youtube_video_document
|
||||
|
||||
__all__ = [
|
||||
# Extension processing
|
||||
"add_extension_received_document",
|
||||
# File processing with different ETL services
|
||||
"add_received_file_document_using_docling",
|
||||
"add_received_file_document_using_llamacloud",
|
||||
"add_received_file_document_using_unstructured",
|
||||
# Markdown file processing
|
||||
"add_received_markdown_file_document",
|
||||
# YouTube video processing
|
||||
"add_youtube_video_document",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,74 +0,0 @@
|
|||
"""
|
||||
Constants for file document processing.
|
||||
|
||||
Centralizes file type classification, LlamaCloud retry configuration,
|
||||
and timeout calculation parameters.
|
||||
"""
|
||||
|
||||
import ssl
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File type classification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Extension groups used by ``classify_file``. Each tuple is handed directly to
# ``str.endswith``, which accepts a tuple of candidate suffixes.
MARKDOWN_EXTENSIONS = (".md", ".markdown", ".txt")
AUDIO_EXTENSIONS = (".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")
DIRECT_CONVERT_EXTENSIONS = (".csv", ".tsv", ".html", ".htm")


class FileCategory(Enum):
    """Processing category assigned to a file based on its extension."""

    MARKDOWN = "markdown"
    AUDIO = "audio"
    DIRECT_CONVERT = "direct_convert"
    DOCUMENT = "document"


def classify_file(filename: str) -> FileCategory:
    """Return the processing category for *filename*.

    The name is lower-cased first, so matching is case-insensitive. The
    extension groups are checked in a fixed order (markdown, audio, direct
    convert); anything that matches none of them is a ``DOCUMENT``.
    """
    normalized = filename.lower()
    ordered_rules = (
        (MARKDOWN_EXTENSIONS, FileCategory.MARKDOWN),
        (AUDIO_EXTENSIONS, FileCategory.AUDIO),
        (DIRECT_CONVERT_EXTENSIONS, FileCategory.DIRECT_CONVERT),
    )
    for suffixes, category in ordered_rules:
        if normalized.endswith(suffixes):
            return category
    # Fallback for every extension not listed above.
    return FileCategory.DOCUMENT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LlamaCloud retry configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Total number of parse attempts before giving up.
LLAMACLOUD_MAX_RETRIES = 5
# Seconds; starting value of the exponential backoff between attempts.
LLAMACLOUD_BASE_DELAY = 10
# Seconds; upper bound on the backoff delay (2 minutes).
LLAMACLOUD_MAX_DELAY = 120

# Exception types treated as transient and therefore worth retrying.
LLAMACLOUD_RETRYABLE_EXCEPTIONS = (
    # TLS layer
    ssl.SSLError,
    # httpx transport / timeout / protocol errors
    httpx.ConnectError,
    httpx.ConnectTimeout,
    httpx.ReadError,
    httpx.ReadTimeout,
    httpx.WriteError,
    httpx.WriteTimeout,
    httpx.RemoteProtocolError,
    httpx.LocalProtocolError,
    # Built-in connection / timeout / OS-level errors
    ConnectionError,
    ConnectionResetError,
    TimeoutError,
    OSError,
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Timeout calculation constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 100 KB/s -- deliberately conservative throughput assumed for slow connections.
UPLOAD_BYTES_PER_SECOND_SLOW = 100 * 1024

# Upload timeout bounds, in seconds.
MIN_UPLOAD_TIMEOUT = 120  # at least 2 minutes for any file
MAX_UPLOAD_TIMEOUT = 1800  # at most 30 minutes for very large files

# Job-processing timeout parameters, in seconds.
BASE_JOB_TIMEOUT = 600  # 10 minutes base
PER_PAGE_JOB_TIMEOUT = 60  # 1 additional minute per page
|
||||
|
|
@ -4,8 +4,8 @@ Lossless file-to-markdown converters for text-based formats.
|
|||
These converters handle file types that can be faithfully represented as
|
||||
markdown without any external ETL/OCR service:
|
||||
|
||||
- CSV / TSV → markdown table (stdlib ``csv``)
|
||||
- HTML / HTM → markdown (``markdownify``)
|
||||
- CSV / TSV → markdown table (stdlib ``csv``)
|
||||
- HTML / HTM / XHTML → markdown (``markdownify``)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -73,6 +73,7 @@ _CONVERTER_MAP: dict[str, Callable[..., str]] = {
|
|||
".tsv": tsv_to_markdown,
|
||||
".html": html_to_markdown,
|
||||
".htm": html_to_markdown,
|
||||
".xhtml": html_to_markdown,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,209 +0,0 @@
|
|||
"""
|
||||
ETL parsing strategies for different document processing services.
|
||||
|
||||
Provides parse functions for Unstructured, LlamaCloud, and Docling, along with
|
||||
LlamaCloud retry logic and dynamic timeout calculations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from logging import ERROR, getLogger
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import Log
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
|
||||
from ._constants import (
|
||||
LLAMACLOUD_BASE_DELAY,
|
||||
LLAMACLOUD_MAX_DELAY,
|
||||
LLAMACLOUD_MAX_RETRIES,
|
||||
LLAMACLOUD_RETRYABLE_EXCEPTIONS,
|
||||
PER_PAGE_JOB_TIMEOUT,
|
||||
)
|
||||
from ._helpers import calculate_job_timeout, calculate_upload_timeout
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LlamaCloud parsing with retry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _llamacloud_retry_delay(attempt: int) -> float:
    """Return the jittered backoff delay, in seconds, after failed *attempt* (1-based)."""
    base_delay = min(
        LLAMACLOUD_BASE_DELAY * (2 ** (attempt - 1)),
        LLAMACLOUD_MAX_DELAY,
    )
    # Spread retries by up to +/-25% of the base delay.
    jitter = base_delay * 0.25 * (2 * random.random() - 1)
    return base_delay + jitter


async def parse_with_llamacloud_retry(
    file_path: str,
    estimated_pages: int,
    task_logger: TaskLoggingService | None = None,
    log_entry: Log | None = None,
):
    """
    Parse a file with LlamaCloud, retrying on transient SSL/connection errors.

    Upload and job timeouts are derived from the file size and the estimated
    page count so that very large files are given proportionally more time.
    Between failed attempts the coroutine sleeps for an exponentially growing,
    jittered delay.

    Args:
        file_path: Path of the file to upload and parse.
        estimated_pages: Estimated page count, used to size the job timeout.
        task_logger: Optional task logger; when given together with
            ``log_entry``, retry progress is recorded through it instead of
            the ``logging`` module.
        log_entry: Log entry that ``task_logger`` progress is attached to.

    Returns:
        LlamaParse result object

    Raises:
        Exception: The last retryable error once all attempts have failed;
            any non-retryable error propagates immediately.
    """
    from llama_cloud_services import LlamaParse
    from llama_cloud_services.parse.utils import ResultType

    file_size_bytes = os.path.getsize(file_path)
    file_size_mb = file_size_bytes / (1024 * 1024)

    upload_timeout = calculate_upload_timeout(file_size_bytes)
    job_timeout = calculate_job_timeout(estimated_pages, file_size_bytes)

    custom_timeout = httpx.Timeout(
        connect=120.0,
        read=upload_timeout,
        write=upload_timeout,
        pool=120.0,
    )

    logging.info(
        f"LlamaCloud upload configured: file_size={file_size_mb:.1f}MB, "
        f"pages={estimated_pages}, upload_timeout={upload_timeout:.0f}s, "
        f"job_timeout={job_timeout:.0f}s"
    )

    last_exception = None
    attempt_errors: list[str] = []

    for attempt in range(1, LLAMACLOUD_MAX_RETRIES + 1):
        try:
            # A fresh client per attempt, closed when the attempt finishes.
            async with httpx.AsyncClient(timeout=custom_timeout) as custom_client:
                parser = LlamaParse(
                    api_key=app_config.LLAMA_CLOUD_API_KEY,
                    num_workers=1,
                    verbose=True,
                    language="en",
                    result_type=ResultType.MD,
                    max_timeout=int(max(2000, job_timeout + upload_timeout)),
                    job_timeout_in_seconds=job_timeout,
                    job_timeout_extra_time_per_page_in_seconds=PER_PAGE_JOB_TIMEOUT,
                    custom_client=custom_client,
                )
                result = await parser.aparse(file_path)

                if attempt > 1:
                    logging.info(
                        f"LlamaCloud upload succeeded on attempt {attempt} after "
                        f"{len(attempt_errors)} failures"
                    )
                return result

        # Only transient errors are handled; anything else propagates as-is.
        except LLAMACLOUD_RETRYABLE_EXCEPTIONS as e:
            last_exception = e
            error_type = type(e).__name__
            error_msg = str(e)[:200]  # keep logged messages bounded
            attempt_errors.append(f"Attempt {attempt}: {error_type} - {error_msg}")

            if attempt < LLAMACLOUD_MAX_RETRIES:
                delay = _llamacloud_retry_delay(attempt)

                if task_logger and log_entry:
                    await task_logger.log_task_progress(
                        log_entry,
                        f"LlamaCloud upload failed "
                        f"(attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}), "
                        f"retrying in {delay:.0f}s",
                        {
                            "error_type": error_type,
                            "error_message": error_msg,
                            "attempt": attempt,
                            "retry_delay": delay,
                            "file_size_mb": round(file_size_mb, 1),
                            "upload_timeout": upload_timeout,
                        },
                    )
                else:
                    logging.warning(
                        f"LlamaCloud upload failed "
                        f"(attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}): "
                        f"{error_type}. File: {file_size_mb:.1f}MB. "
                        f"Retrying in {delay:.0f}s..."
                    )

                await asyncio.sleep(delay)
            else:
                logging.error(
                    f"LlamaCloud upload failed after {LLAMACLOUD_MAX_RETRIES} "
                    f"attempts. File size: {file_size_mb:.1f}MB, "
                    f"Pages: {estimated_pages}. "
                    f"Errors: {'; '.join(attempt_errors)}"
                )

    # Reached only when every attempt failed, or when no attempt ran at all.
    raise last_exception or RuntimeError(
        f"LlamaCloud parsing failed after {LLAMACLOUD_MAX_RETRIES} retries. "
        f"File size: {file_size_mb:.1f}MB"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-service parse functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def parse_with_unstructured(file_path: str):
    """
    Load *file_path* through ``UnstructuredLoader`` in ``"elements"`` mode.

    Returns:
        The result of the loader's ``aload()`` call (a list of LangChain
        Document elements).
    """
    from langchain_unstructured import UnstructuredLoader

    # Loader options kept together so the configuration reads as one unit.
    loader_options = {
        "mode": "elements",
        "post_processors": [],
        "languages": ["eng"],
        "include_orig_elements": False,
        "include_metadata": False,
        "strategy": "auto",
    }
    unstructured_loader = UnstructuredLoader(file_path, **loader_options)
    elements = await unstructured_loader.aload()
    return elements
|
||||
|
||||
|
||||
async def parse_with_docling(file_path: str, filename: str) -> str:
    """
    Convert a file to markdown through the Docling service wrapper.

    While the document is processed, known-noisy pdfminer warnings are
    filtered and the ``pdfminer`` logger is raised to ``ERROR``; the logger's
    previous level is restored afterwards even if processing fails.

    Returns:
        The ``"content"`` entry of the Docling result.
    """
    from app.services.docling_service import create_docling_service

    service = create_docling_service()

    noisy_logger = getLogger("pdfminer")
    previous_level = noisy_logger.level

    # Message patterns of warnings that are silenced during processing.
    ignored_messages = (
        ".*Cannot set gray non-stroke color.*",
        ".*invalid float value.*",
    )

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning, module="pdfminer")
        for pattern in ignored_messages:
            warnings.filterwarnings("ignore", message=pattern)
        noisy_logger.setLevel(ERROR)

        try:
            processed = await service.process_document(file_path, filename)
        finally:
            noisy_logger.setLevel(previous_level)

    return processed["content"]
|
||||
|
|
@ -11,13 +11,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.utils.document_converters import generate_unique_identifier_hash
|
||||
|
||||
from ._constants import (
|
||||
BASE_JOB_TIMEOUT,
|
||||
MAX_UPLOAD_TIMEOUT,
|
||||
MIN_UPLOAD_TIMEOUT,
|
||||
PER_PAGE_JOB_TIMEOUT,
|
||||
UPLOAD_BYTES_PER_SECOND_SLOW,
|
||||
)
|
||||
from .base import (
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document,
|
||||
|
|
@ -198,21 +191,3 @@ async def update_document_from_connector(
|
|||
if "connector_id" in connector:
|
||||
document.connector_id = connector["connector_id"]
|
||||
await session.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Timeout calculations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def calculate_upload_timeout(file_size_bytes: int) -> float:
    """Return the upload timeout in seconds for a file of *file_size_bytes*.

    The transfer time at ``UPLOAD_BYTES_PER_SECOND_SLOW`` is padded by 50%
    and then clamped to ``[MIN_UPLOAD_TIMEOUT, MAX_UPLOAD_TIMEOUT]``.
    """
    seconds_at_slow_rate = file_size_bytes / UPLOAD_BYTES_PER_SECOND_SLOW
    padded = seconds_at_slow_rate * 1.5
    capped = min(padded, MAX_UPLOAD_TIMEOUT)
    return max(MIN_UPLOAD_TIMEOUT, capped)
|
||||
|
||||
|
||||
def calculate_job_timeout(estimated_pages: int, file_size_bytes: int) -> float:
    """Return the job-processing timeout in seconds.

    Two estimates are computed on top of ``BASE_JOB_TIMEOUT`` -- one adding
    ``PER_PAGE_JOB_TIMEOUT`` per estimated page, one adding 60 seconds per
    10 MiB of file size -- and the larger of the two is returned.
    """
    ten_mib = 10 * 1024 * 1024
    by_pages = BASE_JOB_TIMEOUT + (estimated_pages * PER_PAGE_JOB_TIMEOUT)
    by_size = BASE_JOB_TIMEOUT + (file_size_bytes / ten_mib) * 60
    return max(by_pages, by_size)
|
||||
|
|
|
|||
|
|
@ -1,14 +1,9 @@
|
|||
"""
|
||||
Unified document save/update logic for file processors.
|
||||
|
||||
Replaces the three nearly-identical ``add_received_file_document_using_*``
|
||||
functions with a single ``save_file_document`` function plus thin wrappers
|
||||
for backward compatibility.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from langchain_core.documents import Document as LangChainDocument
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
|
@ -207,79 +202,3 @@ async def save_file_document(
|
|||
raise RuntimeError(
|
||||
f"Failed to process file document using {etl_service}: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backward-compatible wrapper functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def add_received_file_document_using_unstructured(
    session: AsyncSession,
    file_name: str,
    unstructured_processed_elements: list[LangChainDocument],
    search_space_id: int,
    user_id: str,
    connector: dict | None = None,
    enable_summary: bool = True,
) -> Document | None:
    """Convert Unstructured elements to markdown and store them as a document.

    Thin backward-compatible wrapper: the elements are converted with
    ``convert_document_to_markdown`` and the result is handed to
    ``save_file_document`` with the ETL label ``"UNSTRUCTURED"``.
    """
    from app.utils.document_converters import convert_document_to_markdown

    etl_label = "UNSTRUCTURED"
    markdown = await convert_document_to_markdown(unstructured_processed_elements)
    saved_document = await save_file_document(
        session,
        file_name,
        markdown,
        search_space_id,
        user_id,
        etl_label,
        connector,
        enable_summary,
    )
    return saved_document
|
||||
|
||||
|
||||
async def add_received_file_document_using_llamacloud(
    session: AsyncSession,
    file_name: str,
    llamacloud_markdown_document: str,
    search_space_id: int,
    user_id: str,
    connector: dict | None = None,
    enable_summary: bool = True,
) -> Document | None:
    """Store markdown produced by LlamaCloud as a document.

    Thin backward-compatible wrapper around ``save_file_document`` that fixes
    the ETL label to ``"LLAMACLOUD"``.
    """
    etl_label = "LLAMACLOUD"
    saved_document = await save_file_document(
        session,
        file_name,
        llamacloud_markdown_document,
        search_space_id,
        user_id,
        etl_label,
        connector,
        enable_summary,
    )
    return saved_document
|
||||
|
||||
|
||||
async def add_received_file_document_using_docling(
    session: AsyncSession,
    file_name: str,
    docling_markdown_document: str,
    search_space_id: int,
    user_id: str,
    connector: dict | None = None,
    enable_summary: bool = True,
) -> Document | None:
    """Store markdown produced by Docling as a document.

    Thin backward-compatible wrapper around ``save_file_document`` that fixes
    the ETL label to ``"DOCLING"``.
    """
    etl_label = "DOCLING"
    saved_document = await save_file_document(
        session,
        file_name,
        docling_markdown_document,
        search_space_id,
        user_id,
        etl_label,
        connector,
        enable_summary,
    )
    return saved_document
|
||||
|
|
|
|||
|
|
@ -1,14 +1,8 @@
|
|||
"""
|
||||
File document processors orchestrating content extraction and indexing.
|
||||
|
||||
This module is the public entry point for file processing. It delegates to
|
||||
specialised sub-modules that each own a single concern:
|
||||
|
||||
- ``_constants`` — file type classification and configuration constants
|
||||
- ``_helpers`` — document deduplication, migration, connector helpers
|
||||
- ``_direct_converters`` — lossless file-to-markdown for csv/tsv/html
|
||||
- ``_etl`` — ETL parsing strategies (Unstructured, LlamaCloud, Docling)
|
||||
- ``_save`` — unified document creation / update logic
|
||||
Delegates content extraction to ``app.etl_pipeline.EtlPipelineService`` and
|
||||
keeps only orchestration concerns (notifications, logging, page limits, saving).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -17,38 +11,19 @@ import contextlib
|
|||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from logging import ERROR, getLogger
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import Document, Log, Notification
|
||||
from app.services.notification_service import NotificationService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
|
||||
from ._constants import FileCategory, classify_file
|
||||
from ._direct_converters import convert_file_directly
|
||||
from ._etl import (
|
||||
parse_with_docling,
|
||||
parse_with_llamacloud_retry,
|
||||
parse_with_unstructured,
|
||||
)
|
||||
from ._helpers import update_document_from_connector
|
||||
from ._save import (
|
||||
add_received_file_document_using_docling,
|
||||
add_received_file_document_using_llamacloud,
|
||||
add_received_file_document_using_unstructured,
|
||||
save_file_document,
|
||||
)
|
||||
from ._save import save_file_document
|
||||
from .markdown_processor import add_received_markdown_file_document
|
||||
|
||||
# Re-export public API so existing ``from file_processors import …`` keeps working.
|
||||
__all__ = [
|
||||
"add_received_file_document_using_docling",
|
||||
"add_received_file_document_using_llamacloud",
|
||||
"add_received_file_document_using_unstructured",
|
||||
"parse_with_llamacloud_retry",
|
||||
"process_file_in_background",
|
||||
"process_file_in_background_with_document",
|
||||
"save_file_document",
|
||||
|
|
@ -142,35 +117,31 @@ async def _log_page_divergence(
|
|||
# ===================================================================
|
||||
|
||||
|
||||
async def _process_markdown_upload(ctx: _ProcessingContext) -> Document | None:
|
||||
"""Read a markdown / text file and create or update a document."""
|
||||
await _notify(ctx, "parsing", "Reading file")
|
||||
async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | None:
|
||||
"""Extract content from a non-document file (plaintext/direct_convert/audio) via the unified ETL pipeline."""
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
await _notify(ctx, "parsing", "Processing file")
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Processing markdown/text file: {ctx.filename}",
|
||||
{"file_type": "markdown", "processing_stage": "reading_file"},
|
||||
f"Processing file: {ctx.filename}",
|
||||
{"processing_stage": "extracting"},
|
||||
)
|
||||
|
||||
with open(ctx.file_path, encoding="utf-8") as f:
|
||||
markdown_content = f.read()
|
||||
etl_result = await EtlPipelineService().extract(
|
||||
EtlRequest(file_path=ctx.file_path, filename=ctx.filename)
|
||||
)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(ctx.file_path)
|
||||
|
||||
await _notify(ctx, "chunking")
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Creating document from markdown content: {ctx.filename}",
|
||||
{
|
||||
"processing_stage": "creating_document",
|
||||
"content_length": len(markdown_content),
|
||||
},
|
||||
)
|
||||
|
||||
result = await add_received_markdown_file_document(
|
||||
ctx.session,
|
||||
ctx.filename,
|
||||
markdown_content,
|
||||
etl_result.markdown_content,
|
||||
ctx.search_space_id,
|
||||
ctx.user_id,
|
||||
ctx.connector,
|
||||
|
|
@ -181,179 +152,19 @@ async def _process_markdown_upload(ctx: _ProcessingContext) -> Document | None:
|
|||
if result:
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Successfully processed markdown file: {ctx.filename}",
|
||||
f"Successfully processed file: {ctx.filename}",
|
||||
{
|
||||
"document_id": result.id,
|
||||
"content_hash": result.content_hash,
|
||||
"file_type": "markdown",
|
||||
"file_type": etl_result.content_type,
|
||||
"etl_service": etl_result.etl_service,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Markdown file already exists (duplicate): {ctx.filename}",
|
||||
{"duplicate_detected": True, "file_type": "markdown"},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def _process_direct_convert_upload(ctx: _ProcessingContext) -> Document | None:
    """Convert a text-based file (csv/tsv/html) to markdown without ETL.

    The file is converted with ``convert_file_directly``, the source file is
    removed on a best-effort basis, and the markdown is stored through
    ``add_received_markdown_file_document``.

    Returns:
        The stored document, or a falsy value when the save step reported
        nothing to store (logged here as a duplicate).
    """
    await _notify(ctx, "parsing", "Converting file")
    await ctx.task_logger.log_task_progress(
        ctx.log_entry,
        f"Direct-converting file to markdown: {ctx.filename}",
        {"file_type": "direct_convert", "processing_stage": "converting"},
    )

    markdown_content = convert_file_directly(ctx.file_path, ctx.filename)

    # Best-effort cleanup: a failed delete must not abort processing.
    with contextlib.suppress(Exception):
        os.unlink(ctx.file_path)

    await _notify(ctx, "chunking")
    await ctx.task_logger.log_task_progress(
        ctx.log_entry,
        f"Creating document from converted content: {ctx.filename}",
        {
            "processing_stage": "creating_document",
            "content_length": len(markdown_content),
        },
    )

    result = await add_received_markdown_file_document(
        ctx.session,
        ctx.filename,
        markdown_content,
        ctx.search_space_id,
        ctx.user_id,
        ctx.connector,
    )

    if result:
        # Only touch connector fields when a document actually exists; on the
        # duplicate path ``result`` is falsy and there is nothing to update.
        if ctx.connector:
            await update_document_from_connector(result, ctx.connector, ctx.session)

        await ctx.task_logger.log_task_success(
            ctx.log_entry,
            f"Successfully direct-converted file: {ctx.filename}",
            {
                "document_id": result.id,
                "content_hash": result.content_hash,
                "file_type": "direct_convert",
            },
        )
    else:
        await ctx.task_logger.log_task_success(
            ctx.log_entry,
            f"Direct-converted file already exists (duplicate): {ctx.filename}",
            {"duplicate_detected": True, "file_type": "direct_convert"},
        )
    return result
|
||||
|
||||
|
||||
async def _process_audio_upload(ctx: _ProcessingContext) -> Document | None:
|
||||
"""Transcribe an audio file and create or update a document."""
|
||||
await _notify(ctx, "parsing", "Transcribing audio")
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Processing audio file for transcription: {ctx.filename}",
|
||||
{"file_type": "audio", "processing_stage": "starting_transcription"},
|
||||
)
|
||||
|
||||
stt_service_type = (
|
||||
"local"
|
||||
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
|
||||
else "external"
|
||||
)
|
||||
|
||||
if stt_service_type == "local":
|
||||
from app.services.stt_service import stt_service
|
||||
|
||||
try:
|
||||
stt_result = stt_service.transcribe_file(ctx.file_path)
|
||||
transcribed_text = stt_result.get("text", "")
|
||||
if not transcribed_text:
|
||||
raise ValueError("Transcription returned empty text")
|
||||
transcribed_text = (
|
||||
f"# Transcription of {ctx.filename}\n\n{transcribed_text}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Failed to transcribe audio file {ctx.filename}: {e!s}",
|
||||
) from e
|
||||
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Local STT transcription completed: {ctx.filename}",
|
||||
{
|
||||
"processing_stage": "local_transcription_complete",
|
||||
"language": stt_result.get("language"),
|
||||
"confidence": stt_result.get("language_probability"),
|
||||
"duration": stt_result.get("duration"),
|
||||
},
|
||||
)
|
||||
else:
|
||||
from litellm import atranscription
|
||||
|
||||
with open(ctx.file_path, "rb") as audio_file:
|
||||
transcription_kwargs: dict = {
|
||||
"model": app_config.STT_SERVICE,
|
||||
"file": audio_file,
|
||||
"api_key": app_config.STT_SERVICE_API_KEY,
|
||||
}
|
||||
if app_config.STT_SERVICE_API_BASE:
|
||||
transcription_kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
|
||||
|
||||
transcription_response = await atranscription(**transcription_kwargs)
|
||||
transcribed_text = transcription_response.get("text", "")
|
||||
if not transcribed_text:
|
||||
raise ValueError("Transcription returned empty text")
|
||||
|
||||
transcribed_text = f"# Transcription of {ctx.filename}\n\n{transcribed_text}"
|
||||
|
||||
await ctx.task_logger.log_task_progress(
|
||||
ctx.log_entry,
|
||||
f"Transcription completed, creating document: {ctx.filename}",
|
||||
{
|
||||
"processing_stage": "transcription_complete",
|
||||
"transcript_length": len(transcribed_text),
|
||||
},
|
||||
)
|
||||
|
||||
await _notify(ctx, "chunking")
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(ctx.file_path)
|
||||
|
||||
result = await add_received_markdown_file_document(
|
||||
ctx.session,
|
||||
ctx.filename,
|
||||
transcribed_text,
|
||||
ctx.search_space_id,
|
||||
ctx.user_id,
|
||||
ctx.connector,
|
||||
)
|
||||
if ctx.connector:
|
||||
await update_document_from_connector(result, ctx.connector, ctx.session)
|
||||
|
||||
if result:
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Successfully transcribed and processed audio file: {ctx.filename}",
|
||||
{
|
||||
"document_id": result.id,
|
||||
"content_hash": result.content_hash,
|
||||
"file_type": "audio",
|
||||
"transcript_length": len(transcribed_text),
|
||||
"stt_service": stt_service_type,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Audio file transcript already exists (duplicate): {ctx.filename}",
|
||||
{"duplicate_detected": True, "file_type": "audio"},
|
||||
f"File already exists (duplicate): {ctx.filename}",
|
||||
{"duplicate_detected": True, "file_type": etl_result.content_type},
|
||||
)
|
||||
return result
|
||||
|
||||
|
|
@ -363,279 +174,10 @@ async def _process_audio_upload(ctx: _ProcessingContext) -> Document | None:
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _etl_unstructured(
    ctx: _ProcessingContext,
    page_limit_service,
    estimated_pages: int,
) -> Document | None:
    """Parse and save via the Unstructured ETL service.

    Args:
        ctx: Processing context carrying the file, session, user and loggers.
        page_limit_service: Object used here for
            ``estimate_pages_from_elements`` and ``update_page_usage``.
        estimated_pages: Page estimate made before parsing; the larger of
            this and the post-parse estimate is charged to the user.

    Returns:
        The stored document, or a falsy value when the save step reported
        nothing to store (logged here as a duplicate).
    """
    await _notify(ctx, "parsing", "Extracting content")
    await ctx.task_logger.log_task_progress(
        ctx.log_entry,
        f"Processing file with Unstructured ETL: {ctx.filename}",
        {
            "file_type": "document",
            "etl_service": "UNSTRUCTURED",
            "processing_stage": "loading",
        },
    )

    docs = await parse_with_unstructured(ctx.file_path)

    await _notify(ctx, "chunking", chunks_count=len(docs))
    await ctx.task_logger.log_task_progress(
        ctx.log_entry,
        f"Unstructured ETL completed, creating document: {ctx.filename}",
        {"processing_stage": "etl_complete", "elements_count": len(docs)},
    )

    actual_pages = page_limit_service.estimate_pages_from_elements(docs)
    final_pages = max(estimated_pages, actual_pages)
    await _log_page_divergence(
        ctx.task_logger,
        ctx.log_entry,
        ctx.filename,
        estimated_pages,
        actual_pages,
        final_pages,
    )

    # Best-effort cleanup: a failed delete must not abort processing.
    with contextlib.suppress(Exception):
        os.unlink(ctx.file_path)

    result = await add_received_file_document_using_unstructured(
        ctx.session,
        ctx.filename,
        docs,
        ctx.search_space_id,
        ctx.user_id,
        ctx.connector,
        enable_summary=ctx.enable_summary,
    )

    if result:
        # Only touch connector fields when a document actually exists; on the
        # duplicate path ``result`` is falsy and there is nothing to update.
        if ctx.connector:
            await update_document_from_connector(result, ctx.connector, ctx.session)

        await page_limit_service.update_page_usage(
            ctx.user_id, final_pages, allow_exceed=True
        )
        await ctx.task_logger.log_task_success(
            ctx.log_entry,
            f"Successfully processed file with Unstructured: {ctx.filename}",
            {
                "document_id": result.id,
                "content_hash": result.content_hash,
                "file_type": "document",
                "etl_service": "UNSTRUCTURED",
                "pages_processed": final_pages,
            },
        )
    else:
        await ctx.task_logger.log_task_success(
            ctx.log_entry,
            f"Document already exists (duplicate): {ctx.filename}",
            {
                "duplicate_detected": True,
                "file_type": "document",
                "etl_service": "UNSTRUCTURED",
            },
        )
    return result
|
||||
|
||||
|
||||
async def _etl_llamacloud(
    ctx: _ProcessingContext,
    page_limit_service,
    estimated_pages: int,
) -> Document | None:
    """Parse a file with LlamaCloud and store every markdown document it yields.

    Returns the last document that was stored, or ``None`` when none of the
    parsed documents was stored (logged as a duplicate). Raises ``ValueError``
    when LlamaCloud returns no documents at all.
    """
    await _notify(ctx, "parsing", "Extracting content")
    await ctx.task_logger.log_task_progress(
        ctx.log_entry,
        f"Processing file with LlamaCloud ETL: {ctx.filename}",
        {
            "file_type": "document",
            "etl_service": "LLAMACLOUD",
            "processing_stage": "parsing",
            "estimated_pages": estimated_pages,
        },
    )

    parse_result = await parse_with_llamacloud_retry(
        file_path=ctx.file_path,
        estimated_pages=estimated_pages,
        task_logger=ctx.task_logger,
        log_entry=ctx.log_entry,
    )

    # The source file is no longer needed once parsing returned; removal is
    # best-effort.
    with contextlib.suppress(Exception):
        os.unlink(ctx.file_path)

    markdown_documents = await parse_result.aget_markdown_documents(
        split_by_page=False
    )
    documents_count = len(markdown_documents)

    await _notify(ctx, "chunking", chunks_count=documents_count)
    await ctx.task_logger.log_task_progress(
        ctx.log_entry,
        f"LlamaCloud parsing completed, creating documents: {ctx.filename}",
        {
            "processing_stage": "parsing_complete",
            "documents_count": documents_count,
        },
    )

    if not markdown_documents:
        await ctx.task_logger.log_task_failure(
            ctx.log_entry,
            f"LlamaCloud parsing returned no documents: {ctx.filename}",
            "ETL service returned empty document list",
            {"error_type": "EmptyDocumentList", "etl_service": "LLAMACLOUD"},
        )
        raise ValueError(f"LlamaCloud parsing returned no documents for {ctx.filename}")

    actual_pages = page_limit_service.estimate_pages_from_markdown(markdown_documents)
    final_pages = max(estimated_pages, actual_pages)
    await _log_page_divergence(
        ctx.task_logger,
        ctx.log_entry,
        ctx.filename,
        estimated_pages,
        actual_pages,
        final_pages,
    )

    # Remember the most recently stored document; it stays None when every
    # save call returned a falsy result.
    last_doc: Document | None = None
    for markdown_doc in markdown_documents:
        stored = await add_received_file_document_using_llamacloud(
            ctx.session,
            ctx.filename,
            llamacloud_markdown_document=markdown_doc.text,
            search_space_id=ctx.search_space_id,
            user_id=ctx.user_id,
            connector=ctx.connector,
            enable_summary=ctx.enable_summary,
        )
        if stored:
            last_doc = stored

    if last_doc is None:
        await ctx.task_logger.log_task_success(
            ctx.log_entry,
            f"Document already exists (duplicate): {ctx.filename}",
            {
                "duplicate_detected": True,
                "file_type": "document",
                "etl_service": "LLAMACLOUD",
                "documents_count": documents_count,
            },
        )
        return None

    await page_limit_service.update_page_usage(
        ctx.user_id, final_pages, allow_exceed=True
    )
    if ctx.connector:
        await update_document_from_connector(last_doc, ctx.connector, ctx.session)
    await ctx.task_logger.log_task_success(
        ctx.log_entry,
        f"Successfully processed file with LlamaCloud: {ctx.filename}",
        {
            "document_id": last_doc.id,
            "content_hash": last_doc.content_hash,
            "file_type": "document",
            "etl_service": "LLAMACLOUD",
            "pages_processed": final_pages,
            "documents_count": documents_count,
        },
    )
    return last_doc
|
||||
|
||||
|
||||
async def _etl_docling(
    ctx: _ProcessingContext,
    page_limit_service,
    estimated_pages: int,
) -> Document | None:
    """Parse and save via the Docling ETL service.

    Steps, in order: notify/log the parsing stage, parse the file with
    Docling, delete the uploaded temp file (best effort), reconcile the page
    estimate with the parsed content length, then persist the markdown.

    Args:
        ctx: Processing context carrying the session, file path/name, user,
            search space, connector, logger and log entry used below.
        page_limit_service: Service used to estimate pages from content length
            and to record page usage.
        estimated_pages: Page estimate computed before parsing.

    Returns:
        Whatever ``add_received_file_document_using_docling`` returned; a
        falsy result is logged as a duplicate and returned unchanged.
    """
    await _notify(ctx, "parsing", "Extracting content")
    await ctx.task_logger.log_task_progress(
        ctx.log_entry,
        f"Processing file with Docling ETL: {ctx.filename}",
        {
            "file_type": "document",
            "etl_service": "DOCLING",
            "processing_stage": "parsing",
        },
    )

    content = await parse_with_docling(ctx.file_path, ctx.filename)

    # Best-effort cleanup: a failed unlink must not fail the upload.
    with contextlib.suppress(Exception):
        os.unlink(ctx.file_path)

    await ctx.task_logger.log_task_progress(
        ctx.log_entry,
        f"Docling parsing completed, creating document: {ctx.filename}",
        {"processing_stage": "parsing_complete", "content_length": len(content)},
    )

    # Charge the larger of the pre-parse estimate and the post-parse estimate.
    actual_pages = page_limit_service.estimate_pages_from_content_length(len(content))
    final_pages = max(estimated_pages, actual_pages)
    await _log_page_divergence(
        ctx.task_logger,
        ctx.log_entry,
        ctx.filename,
        estimated_pages,
        actual_pages,
        final_pages,
    )

    await _notify(ctx, "chunking")

    result = await add_received_file_document_using_docling(
        ctx.session,
        ctx.filename,
        docling_markdown_document=content,
        search_space_id=ctx.search_space_id,
        user_id=ctx.user_id,
        connector=ctx.connector,
        enable_summary=ctx.enable_summary,
    )

    if result:
        # Page usage is only recorded when a document was actually created.
        await page_limit_service.update_page_usage(
            ctx.user_id, final_pages, allow_exceed=True
        )
        if ctx.connector:
            await update_document_from_connector(result, ctx.connector, ctx.session)
        await ctx.task_logger.log_task_success(
            ctx.log_entry,
            f"Successfully processed file with Docling: {ctx.filename}",
            {
                "document_id": result.id,
                "content_hash": result.content_hash,
                "file_type": "document",
                "etl_service": "DOCLING",
                "pages_processed": final_pages,
            },
        )
    else:
        await ctx.task_logger.log_task_success(
            ctx.log_entry,
            f"Document already exists (duplicate): {ctx.filename}",
            {
                "duplicate_detected": True,
                "file_type": "document",
                "etl_service": "DOCLING",
            },
        )
    return result
|
||||
|
||||
|
||||
async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
|
||||
"""Route a document file to the configured ETL service."""
|
||||
"""Route a document file to the configured ETL service via the unified pipeline."""
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
from app.services.page_limit_service import PageLimitExceededError, PageLimitService
|
||||
|
||||
page_limit_service = PageLimitService(ctx.session)
|
||||
|
|
@ -665,16 +207,60 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None:
|
|||
os.unlink(ctx.file_path)
|
||||
raise HTTPException(status_code=403, detail=str(e)) from e
|
||||
|
||||
etl_dispatch = {
|
||||
"UNSTRUCTURED": _etl_unstructured,
|
||||
"LLAMACLOUD": _etl_llamacloud,
|
||||
"DOCLING": _etl_docling,
|
||||
}
|
||||
handler = etl_dispatch.get(app_config.ETL_SERVICE)
|
||||
if handler is None:
|
||||
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
|
||||
await _notify(ctx, "parsing", "Extracting content")
|
||||
|
||||
return await handler(ctx, page_limit_service, estimated_pages)
|
||||
etl_result = await EtlPipelineService().extract(
|
||||
EtlRequest(
|
||||
file_path=ctx.file_path,
|
||||
filename=ctx.filename,
|
||||
estimated_pages=estimated_pages,
|
||||
)
|
||||
)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(ctx.file_path)
|
||||
|
||||
await _notify(ctx, "chunking")
|
||||
|
||||
result = await save_file_document(
|
||||
ctx.session,
|
||||
ctx.filename,
|
||||
etl_result.markdown_content,
|
||||
ctx.search_space_id,
|
||||
ctx.user_id,
|
||||
etl_result.etl_service,
|
||||
ctx.connector,
|
||||
enable_summary=ctx.enable_summary,
|
||||
)
|
||||
|
||||
if result:
|
||||
await page_limit_service.update_page_usage(
|
||||
ctx.user_id, estimated_pages, allow_exceed=True
|
||||
)
|
||||
if ctx.connector:
|
||||
await update_document_from_connector(result, ctx.connector, ctx.session)
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Successfully processed file: {ctx.filename}",
|
||||
{
|
||||
"document_id": result.id,
|
||||
"content_hash": result.content_hash,
|
||||
"file_type": "document",
|
||||
"etl_service": etl_result.etl_service,
|
||||
"pages_processed": estimated_pages,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await ctx.task_logger.log_task_success(
|
||||
ctx.log_entry,
|
||||
f"Document already exists (duplicate): {ctx.filename}",
|
||||
{
|
||||
"duplicate_detected": True,
|
||||
"file_type": "document",
|
||||
"etl_service": etl_result.etl_service,
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# ===================================================================
|
||||
|
|
@ -706,15 +292,16 @@ async def process_file_in_background(
|
|||
)
|
||||
|
||||
try:
|
||||
category = classify_file(filename)
|
||||
from app.etl_pipeline.file_classifier import (
|
||||
FileCategory as EtlFileCategory,
|
||||
classify_file as etl_classify,
|
||||
)
|
||||
|
||||
if category == FileCategory.MARKDOWN:
|
||||
return await _process_markdown_upload(ctx)
|
||||
if category == FileCategory.DIRECT_CONVERT:
|
||||
return await _process_direct_convert_upload(ctx)
|
||||
if category == FileCategory.AUDIO:
|
||||
return await _process_audio_upload(ctx)
|
||||
return await _process_document_upload(ctx)
|
||||
category = etl_classify(filename)
|
||||
|
||||
if category == EtlFileCategory.DOCUMENT:
|
||||
return await _process_document_upload(ctx)
|
||||
return await _process_non_document_upload(ctx)
|
||||
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
|
|
@ -758,201 +345,64 @@ async def _extract_file_content(
|
|||
Returns:
|
||||
Tuple of (markdown_content, etl_service_name).
|
||||
"""
|
||||
category = classify_file(filename)
|
||||
|
||||
if category == FileCategory.MARKDOWN:
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Reading file",
|
||||
)
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing markdown/text file: {filename}",
|
||||
{"file_type": "markdown", "processing_stage": "reading_file"},
|
||||
)
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(file_path)
|
||||
return content, "MARKDOWN"
|
||||
|
||||
if category == FileCategory.DIRECT_CONVERT:
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Converting file",
|
||||
)
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Direct-converting file to markdown: {filename}",
|
||||
{"file_type": "direct_convert", "processing_stage": "converting"},
|
||||
)
|
||||
content = convert_file_directly(file_path, filename)
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(file_path)
|
||||
return content, "DIRECT_CONVERT"
|
||||
|
||||
if category == FileCategory.AUDIO:
|
||||
if notification:
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Transcribing audio",
|
||||
)
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing audio file for transcription: {filename}",
|
||||
{"file_type": "audio", "processing_stage": "starting_transcription"},
|
||||
)
|
||||
transcribed_text = await _transcribe_audio(file_path, filename)
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(file_path)
|
||||
return transcribed_text, "AUDIO_TRANSCRIPTION"
|
||||
|
||||
# Document file — use ETL service
|
||||
return await _extract_document_content(
|
||||
file_path,
|
||||
filename,
|
||||
session,
|
||||
user_id,
|
||||
task_logger,
|
||||
log_entry,
|
||||
notification,
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
from app.etl_pipeline.file_classifier import (
|
||||
FileCategory,
|
||||
classify_file as etl_classify,
|
||||
)
|
||||
|
||||
|
||||
async def _transcribe_audio(file_path: str, filename: str) -> str:
    """Transcribe an audio file and return formatted markdown text.

    A ``STT_SERVICE`` value starting with ``"local/"`` selects the in-process
    ``stt_service``; anything else is sent through ``litellm.atranscription``.

    Args:
        file_path: Path of the audio file on disk.
        filename: Display name used in the markdown heading.

    Returns:
        Markdown of the form ``# Transcription of <filename>`` followed by the
        transcribed text.

    Raises:
        ValueError: If the selected backend returns no text.
    """
    use_local_stt = bool(
        app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
    )

    if use_local_stt:
        from app.services.stt_service import stt_service

        # NOTE(review): this is a synchronous call inside a coroutine --
        # confirm whether it should be moved off the event loop.
        result = stt_service.transcribe_file(file_path)
        text = result.get("text", "")
    else:
        from litellm import atranscription

        # Keep the file handle open for the whole request.
        with open(file_path, "rb") as audio_file:
            kwargs: dict = {
                "model": app_config.STT_SERVICE,
                "file": audio_file,
                "api_key": app_config.STT_SERVICE_API_KEY,
            }
            if app_config.STT_SERVICE_API_BASE:
                kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
            response = await atranscription(**kwargs)
        text = response.get("text", "")

    # Single empty-result check shared by both backends.
    if not text:
        raise ValueError("Transcription returned empty text")

    return f"# Transcription of {filename}\n\n{text}"
|
||||
|
||||
|
||||
async def _extract_document_content(
|
||||
file_path: str,
|
||||
filename: str,
|
||||
session: AsyncSession,
|
||||
user_id: str,
|
||||
task_logger: TaskLoggingService,
|
||||
log_entry: Log,
|
||||
notification: Notification | None,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Parse a document file via the configured ETL service.
|
||||
|
||||
Returns:
|
||||
Tuple of (markdown_content, etl_service_name).
|
||||
"""
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
|
||||
page_limit_service = PageLimitService(session)
|
||||
|
||||
try:
|
||||
estimated_pages = page_limit_service.estimate_pages_before_processing(file_path)
|
||||
except Exception:
|
||||
file_size = os.path.getsize(file_path)
|
||||
estimated_pages = max(1, file_size // (80 * 1024))
|
||||
|
||||
await page_limit_service.check_page_limit(user_id, estimated_pages)
|
||||
|
||||
etl_service = app_config.ETL_SERVICE
|
||||
markdown_content: str | None = None
|
||||
category = etl_classify(filename)
|
||||
estimated_pages = 0
|
||||
|
||||
if notification:
|
||||
stage_messages = {
|
||||
FileCategory.PLAINTEXT: "Reading file",
|
||||
FileCategory.DIRECT_CONVERT: "Converting file",
|
||||
FileCategory.AUDIO: "Transcribing audio",
|
||||
FileCategory.UNSUPPORTED: "Unsupported file type",
|
||||
FileCategory.DOCUMENT: "Extracting content",
|
||||
}
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session,
|
||||
notification,
|
||||
stage="parsing",
|
||||
stage_message="Extracting content",
|
||||
stage_message=stage_messages.get(category, "Processing"),
|
||||
)
|
||||
|
||||
if etl_service == "UNSTRUCTURED":
|
||||
from app.utils.document_converters import convert_document_to_markdown
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Processing {category.value} file: {filename}",
|
||||
{"file_type": category.value, "processing_stage": "extracting"},
|
||||
)
|
||||
|
||||
docs = await parse_with_unstructured(file_path)
|
||||
markdown_content = await convert_document_to_markdown(docs)
|
||||
actual_pages = page_limit_service.estimate_pages_from_elements(docs)
|
||||
final_pages = max(estimated_pages, actual_pages)
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, final_pages, allow_exceed=True
|
||||
)
|
||||
if category == FileCategory.DOCUMENT:
|
||||
from app.services.page_limit_service import PageLimitService
|
||||
|
||||
elif etl_service == "LLAMACLOUD":
|
||||
raw_result = await parse_with_llamacloud_retry(
|
||||
page_limit_service = PageLimitService(session)
|
||||
estimated_pages = _estimate_pages_safe(page_limit_service, file_path)
|
||||
await page_limit_service.check_page_limit(user_id, estimated_pages)
|
||||
|
||||
result = await EtlPipelineService().extract(
|
||||
EtlRequest(
|
||||
file_path=file_path,
|
||||
filename=filename,
|
||||
estimated_pages=estimated_pages,
|
||||
task_logger=task_logger,
|
||||
log_entry=log_entry,
|
||||
)
|
||||
markdown_documents = await raw_result.aget_markdown_documents(
|
||||
split_by_page=False
|
||||
)
|
||||
if not markdown_documents:
|
||||
raise RuntimeError(f"LlamaCloud parsing returned no documents: {filename}")
|
||||
markdown_content = markdown_documents[0].text
|
||||
)
|
||||
|
||||
if category == FileCategory.DOCUMENT:
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, estimated_pages, allow_exceed=True
|
||||
)
|
||||
|
||||
elif etl_service == "DOCLING":
|
||||
getLogger("docling.pipeline.base_pipeline").setLevel(ERROR)
|
||||
getLogger("docling.document_converter").setLevel(ERROR)
|
||||
getLogger("docling_core.transforms.chunker.hierarchical_chunker").setLevel(
|
||||
ERROR
|
||||
)
|
||||
|
||||
from docling.document_converter import DocumentConverter
|
||||
|
||||
converter = DocumentConverter()
|
||||
result = converter.convert(file_path)
|
||||
markdown_content = result.document.export_to_markdown()
|
||||
await page_limit_service.update_page_usage(
|
||||
user_id, estimated_pages, allow_exceed=True
|
||||
)
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Unknown ETL_SERVICE: {etl_service}")
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
os.unlink(file_path)
|
||||
|
||||
if not markdown_content:
|
||||
if not result.markdown_content:
|
||||
raise RuntimeError(f"Failed to extract content from file: {filename}")
|
||||
|
||||
return markdown_content, etl_service
|
||||
return result.markdown_content, result.etl_service
|
||||
|
||||
|
||||
async def process_file_in_background_with_document(
|
||||
|
|
|
|||
124
surfsense_backend/app/utils/file_extensions.py
Normal file
124
surfsense_backend/app/utils/file_extensions.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
"""Per-parser document extension sets for the ETL pipeline.
|
||||
|
||||
Every consumer (file_classifier, connector-level skip checks, ETL pipeline
|
||||
validation) imports from here so there is a single source of truth.
|
||||
|
||||
Extensions already covered by PLAINTEXT_EXTENSIONS, AUDIO_EXTENSIONS, or
|
||||
DIRECT_CONVERT_EXTENSIONS in file_classifier are NOT repeated here -- these
|
||||
sets are exclusively for the "document" ETL path (Docling / LlamaParse /
|
||||
Unstructured).
|
||||
"""
|
||||
|
||||
from pathlib import PurePosixPath
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-parser document extension sets (from official documentation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DOCLING_DOCUMENT_EXTENSIONS: frozenset[str] = frozenset(
|
||||
{
|
||||
".pdf",
|
||||
".docx",
|
||||
".xlsx",
|
||||
".pptx",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".tiff",
|
||||
".tif",
|
||||
".bmp",
|
||||
".webp",
|
||||
}
|
||||
)
|
||||
|
||||
LLAMAPARSE_DOCUMENT_EXTENSIONS: frozenset[str] = frozenset(
|
||||
{
|
||||
".pdf",
|
||||
".docx",
|
||||
".doc",
|
||||
".xlsx",
|
||||
".xls",
|
||||
".pptx",
|
||||
".ppt",
|
||||
".docm",
|
||||
".dot",
|
||||
".dotm",
|
||||
".pptm",
|
||||
".pot",
|
||||
".potx",
|
||||
".xlsm",
|
||||
".xlsb",
|
||||
".xlw",
|
||||
".rtf",
|
||||
".epub",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".tif",
|
||||
".webp",
|
||||
".svg",
|
||||
".odt",
|
||||
".ods",
|
||||
".odp",
|
||||
".hwp",
|
||||
".hwpx",
|
||||
}
|
||||
)
|
||||
|
||||
UNSTRUCTURED_DOCUMENT_EXTENSIONS: frozenset[str] = frozenset(
|
||||
{
|
||||
".pdf",
|
||||
".docx",
|
||||
".doc",
|
||||
".xlsx",
|
||||
".xls",
|
||||
".pptx",
|
||||
".ppt",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".tif",
|
||||
".heic",
|
||||
".rtf",
|
||||
".epub",
|
||||
".odt",
|
||||
".eml",
|
||||
".msg",
|
||||
".p7s",
|
||||
}
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Union (used by classify_file for routing) + service lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DOCUMENT_EXTENSIONS: frozenset[str] = (
|
||||
DOCLING_DOCUMENT_EXTENSIONS
|
||||
| LLAMAPARSE_DOCUMENT_EXTENSIONS
|
||||
| UNSTRUCTURED_DOCUMENT_EXTENSIONS
|
||||
)
|
||||
|
||||
_SERVICE_MAP: dict[str, frozenset[str]] = {
|
||||
"DOCLING": DOCLING_DOCUMENT_EXTENSIONS,
|
||||
"LLAMACLOUD": LLAMAPARSE_DOCUMENT_EXTENSIONS,
|
||||
"UNSTRUCTURED": UNSTRUCTURED_DOCUMENT_EXTENSIONS,
|
||||
}
|
||||
|
||||
|
||||
def get_document_extensions_for_service(etl_service: str | None) -> frozenset[str]:
|
||||
"""Return the document extensions supported by *etl_service*.
|
||||
|
||||
Falls back to the full union when the service is ``None`` or unknown.
|
||||
"""
|
||||
return _SERVICE_MAP.get(etl_service or "", DOCUMENT_EXTENSIONS)
|
||||
|
||||
|
||||
def is_supported_document_extension(filename: str) -> bool:
|
||||
"""Return True if the file's extension is in the supported document set."""
|
||||
suffix = PurePosixPath(filename).suffix.lower()
|
||||
return suffix in DOCUMENT_EXTENSIONS
|
||||
|
|
@ -319,31 +319,23 @@ def _mock_etl_parsing(monkeypatch):
|
|||
|
||||
# -- LlamaParse mock (external API) --------------------------------
|
||||
|
||||
class _FakeMarkdownDoc:
|
||||
def __init__(self, text: str):
|
||||
self.text = text
|
||||
|
||||
class _FakeLlamaParseResult:
|
||||
async def aget_markdown_documents(self, *, split_by_page=False):
|
||||
return [_FakeMarkdownDoc(_MOCK_ETL_MARKDOWN)]
|
||||
|
||||
async def _fake_llamacloud_parse(**kwargs):
|
||||
_reject_empty(kwargs["file_path"])
|
||||
return _FakeLlamaParseResult()
|
||||
async def _fake_llamacloud_parse(file_path: str, estimated_pages: int) -> str:
|
||||
_reject_empty(file_path)
|
||||
return _MOCK_ETL_MARKDOWN
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.document_processors.file_processors.parse_with_llamacloud_retry",
|
||||
"app.etl_pipeline.parsers.llamacloud.parse_with_llamacloud",
|
||||
_fake_llamacloud_parse,
|
||||
)
|
||||
|
||||
# -- Docling mock (heavy library boundary) -------------------------
|
||||
|
||||
async def _fake_docling_parse(file_path: str, filename: str):
|
||||
async def _fake_docling_parse(file_path: str, filename: str) -> str:
|
||||
_reject_empty(file_path)
|
||||
return _MOCK_ETL_MARKDOWN
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.tasks.document_processors.file_processors.parse_with_docling",
|
||||
"app.etl_pipeline.parsers.docling.parse_with_docling",
|
||||
_fake_docling_parse,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ async def test_composio_connector_without_account_id_returns_error(
|
|||
|
||||
maker = make_session_factory(async_engine)
|
||||
async with maker() as session:
|
||||
count, _skipped, error = await index_google_drive_files(
|
||||
count, _skipped, error, _unsupported = await index_google_drive_files(
|
||||
session=session,
|
||||
connector_id=data["connector_id"],
|
||||
search_space_id=data["search_space_id"],
|
||||
|
|
|
|||
|
|
@ -0,0 +1,244 @@
|
|||
"""Tests that each cloud connector's download_and_extract_content correctly
|
||||
produces markdown from a real file via the unified ETL pipeline.
|
||||
|
||||
Only the cloud client is mocked (system boundary). The ETL pipeline runs for
|
||||
real so we know the full path from "cloud gives us bytes" to "we get markdown
|
||||
back" actually works.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
# Payloads the mocked cloud clients write to disk in place of a real download.
_TXT_CONTENT = "Hello from the cloud connector test."
_CSV_CONTENT = "name,age\nAlice,30\nBob,25\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _write_file(dest_path: str, content: str) -> None:
    """Simulate a cloud client writing downloaded bytes to disk.

    Args:
        dest_path: File to create or overwrite.
        content: Text written to the file, encoded as UTF-8.
    """
    with open(dest_path, "w", encoding="utf-8") as f:
        f.write(content)
|
||||
|
||||
|
||||
def _make_download_side_effect(content: str):
    """Return an async side-effect that writes *content* to the dest path
    and returns ``None`` (success).

    The destination path is taken from the LAST positional argument of the
    mocked call, so the helper works regardless of how many leading
    arguments the mocked download method receives.
    """

    async def _side_effect(*args):
        dest_path = args[-1]
        await _write_file(dest_path, content)
        return None

    return _side_effect
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Google Drive
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestGoogleDriveContentExtraction:
    """Google Drive ``download_and_extract_content`` with a mocked client.

    The client's ``download_file_to_disk`` is the only mocked piece; it either
    writes a fixture payload to the destination path or returns an error
    string.
    """

    async def test_txt_file_returns_markdown(self):
        """A plain-text download yields its content plus Drive metadata."""
        from app.connectors.google_drive.content_extractor import (
            download_and_extract_content,
        )

        client = MagicMock()
        client.download_file_to_disk = AsyncMock(
            side_effect=_make_download_side_effect(_TXT_CONTENT),
        )

        file = {"id": "f1", "name": "notes.txt", "mimeType": "text/plain"}

        markdown, metadata, error = await download_and_extract_content(client, file)

        assert error is None
        assert _TXT_CONTENT in markdown
        assert metadata["google_drive_file_id"] == "f1"
        assert metadata["google_drive_file_name"] == "notes.txt"

    async def test_csv_file_returns_markdown_table(self):
        """A CSV download is rendered as a pipe-delimited markdown table."""
        from app.connectors.google_drive.content_extractor import (
            download_and_extract_content,
        )

        client = MagicMock()
        client.download_file_to_disk = AsyncMock(
            side_effect=_make_download_side_effect(_CSV_CONTENT),
        )

        file = {"id": "f2", "name": "data.csv", "mimeType": "text/csv"}

        markdown, _metadata, error = await download_and_extract_content(client, file)

        assert error is None
        assert "Alice" in markdown
        assert "Bob" in markdown
        assert "|" in markdown

    async def test_download_error_returns_error_message(self):
        """An error string from the client is surfaced and no markdown is produced."""
        from app.connectors.google_drive.content_extractor import (
            download_and_extract_content,
        )

        client = MagicMock()
        client.download_file_to_disk = AsyncMock(return_value="Network timeout")

        file = {"id": "f3", "name": "doc.txt", "mimeType": "text/plain"}

        markdown, _metadata, error = await download_and_extract_content(client, file)

        assert markdown is None
        assert error == "Network timeout"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# OneDrive
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestOneDriveContentExtraction:
    """OneDrive ``download_and_extract_content`` with a mocked client.

    The client's ``download_file_to_disk`` is the only mocked piece; it either
    writes a fixture payload to the destination path or returns an error
    string.
    """

    async def test_txt_file_returns_markdown(self):
        """A plain-text download yields its content plus OneDrive metadata."""
        from app.connectors.onedrive.content_extractor import (
            download_and_extract_content,
        )

        client = MagicMock()
        client.download_file_to_disk = AsyncMock(
            side_effect=_make_download_side_effect(_TXT_CONTENT),
        )

        file = {
            "id": "od-1",
            "name": "report.txt",
            "file": {"mimeType": "text/plain"},
        }

        markdown, metadata, error = await download_and_extract_content(client, file)

        assert error is None
        assert _TXT_CONTENT in markdown
        assert metadata["onedrive_file_id"] == "od-1"
        assert metadata["onedrive_file_name"] == "report.txt"

    async def test_csv_file_returns_markdown_table(self):
        """A CSV download is rendered as a pipe-delimited markdown table."""
        from app.connectors.onedrive.content_extractor import (
            download_and_extract_content,
        )

        client = MagicMock()
        client.download_file_to_disk = AsyncMock(
            side_effect=_make_download_side_effect(_CSV_CONTENT),
        )

        file = {
            "id": "od-2",
            "name": "data.csv",
            "file": {"mimeType": "text/csv"},
        }

        markdown, _metadata, error = await download_and_extract_content(client, file)

        assert error is None
        assert "Alice" in markdown
        # Check every data row, matching the Google Drive CSV test.
        assert "Bob" in markdown
        assert "|" in markdown

    async def test_download_error_returns_error_message(self):
        """An error string from the client is surfaced and no markdown is produced."""
        from app.connectors.onedrive.content_extractor import (
            download_and_extract_content,
        )

        client = MagicMock()
        client.download_file_to_disk = AsyncMock(return_value="403 Forbidden")

        file = {
            "id": "od-3",
            "name": "secret.txt",
            "file": {"mimeType": "text/plain"},
        }

        markdown, _metadata, error = await download_and_extract_content(client, file)

        assert markdown is None
        assert error == "403 Forbidden"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Dropbox
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestDropboxContentExtraction:
    """Dropbox ``download_and_extract_content`` with a mocked client.

    The client's ``download_file_to_disk`` is the only mocked piece; it either
    writes a fixture payload to the destination path or returns an error
    string.
    """

    async def test_txt_file_returns_markdown(self):
        """A plain-text download yields its content plus Dropbox metadata."""
        from app.connectors.dropbox.content_extractor import (
            download_and_extract_content,
        )

        client = MagicMock()
        client.download_file_to_disk = AsyncMock(
            side_effect=_make_download_side_effect(_TXT_CONTENT),
        )

        file = {
            "id": "dbx-1",
            "name": "memo.txt",
            ".tag": "file",
            "path_lower": "/memo.txt",
        }

        markdown, metadata, error = await download_and_extract_content(client, file)

        assert error is None
        assert _TXT_CONTENT in markdown
        assert metadata["dropbox_file_id"] == "dbx-1"
        assert metadata["dropbox_file_name"] == "memo.txt"

    async def test_csv_file_returns_markdown_table(self):
        """A CSV download is rendered as a pipe-delimited markdown table."""
        from app.connectors.dropbox.content_extractor import (
            download_and_extract_content,
        )

        client = MagicMock()
        client.download_file_to_disk = AsyncMock(
            side_effect=_make_download_side_effect(_CSV_CONTENT),
        )

        file = {
            "id": "dbx-2",
            "name": "data.csv",
            ".tag": "file",
            "path_lower": "/data.csv",
        }

        markdown, _metadata, error = await download_and_extract_content(client, file)

        assert error is None
        assert "Alice" in markdown
        # Check every data row, matching the Google Drive CSV test.
        assert "Bob" in markdown
        assert "|" in markdown

    async def test_download_error_returns_error_message(self):
        """An error string from the client is surfaced and no markdown is produced."""
        from app.connectors.dropbox.content_extractor import (
            download_and_extract_content,
        )

        client = MagicMock()
        client.download_file_to_disk = AsyncMock(return_value="Rate limited")

        file = {
            "id": "dbx-3",
            "name": "big.txt",
            ".tag": "file",
            "path_lower": "/big.txt",
        }

        markdown, _metadata, error = await download_and_extract_content(client, file)

        assert markdown is None
        assert error == "Rate limited"
|
||||
|
|
@ -8,6 +8,10 @@ import pytest
|
|||
from app.db import DocumentType
|
||||
from app.tasks.connector_indexers.dropbox_indexer import (
|
||||
_download_files_parallel,
|
||||
_index_full_scan,
|
||||
_index_selected_files,
|
||||
_index_with_delta_sync,
|
||||
index_dropbox_files,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
|
@ -234,3 +238,610 @@ async def test_heartbeat_fires_during_parallel_downloads(
|
|||
assert len(docs) == 3
|
||||
assert failed == 0
|
||||
assert len(heartbeat_calls) >= 1, "Heartbeat should have fired at least once"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# D1-D2: _index_full_scan tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _folder_dict(name: str) -> dict:
    """Build a minimal folder entry (``.tag == "folder"``) named *name*."""
    return {".tag": "folder", "name": name}
|
||||
|
||||
|
||||
@pytest.fixture
def full_scan_mocks(mock_dropbox_client, monkeypatch):
    """Wire up mocks for _index_full_scan in isolation.

    Patches, on the dropbox_indexer module: ``_should_skip_file``,
    ``_download_and_index`` and ``PageLimitService``; also pins
    ``ETL_SERVICE`` to ``"LLAMACLOUD"``.

    Returns:
        Dict with the client, session, task logger and log entry mocks, the
        mutable ``skip_results`` mapping (file id -> ``(skip, message)``) that
        tests fill in, and ``download_and_index_mock``.
    """
    import app.tasks.connector_indexers.dropbox_indexer as _mod

    mock_session = AsyncMock()
    mock_task_logger = MagicMock()
    mock_task_logger.log_task_progress = AsyncMock()
    mock_log_entry = MagicMock()

    # Per-file skip decisions injected by individual tests.
    skip_results: dict[str, tuple[bool, str | None]] = {}

    monkeypatch.setattr("app.config.config.ETL_SERVICE", "LLAMACLOUD")

    async def _fake_skip(session, file, search_space_id):
        from app.connectors.dropbox.file_types import should_skip_file as _skip

        # The real type-based check runs first; only files it lets through
        # consult the test-provided skip_results.
        item_skip, unsup_ext = _skip(file)
        if item_skip:
            if unsup_ext:
                return True, f"unsupported:{unsup_ext}"
            return True, "folder/non-downloadable"
        return skip_results.get(file.get("id", ""), (False, None))

    monkeypatch.setattr(_mod, "_should_skip_file", _fake_skip)

    download_and_index_mock = AsyncMock(return_value=(0, 0))
    monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)

    from app.services.page_limit_service import PageLimitService as _RealPLS

    mock_page_limit_instance = MagicMock()
    mock_page_limit_instance.get_page_usage = AsyncMock(return_value=(0, 999_999))
    mock_page_limit_instance.update_page_usage = AsyncMock()

    class _MockPageLimitService:
        # Keep the real metadata-based estimator; only usage calls are mocked.
        estimate_pages_from_metadata = staticmethod(
            _RealPLS.estimate_pages_from_metadata
        )

        def __init__(self, session):
            self.get_page_usage = mock_page_limit_instance.get_page_usage
            self.update_page_usage = mock_page_limit_instance.update_page_usage

    monkeypatch.setattr(_mod, "PageLimitService", _MockPageLimitService)

    return {
        "dropbox_client": mock_dropbox_client,
        "session": mock_session,
        "task_logger": mock_task_logger,
        "log_entry": mock_log_entry,
        "skip_results": skip_results,
        "download_and_index_mock": download_and_index_mock,
    }
|
||||
|
||||
|
||||
async def _run_full_scan(mocks, monkeypatch, page_files, *, max_files=500):
    """Run ``_index_full_scan`` with ``get_files_in_folder`` stubbed.

    Args:
        mocks: Dict produced by the ``full_scan_mocks`` fixture.
        monkeypatch: pytest monkeypatch used to stub the folder listing.
        page_files: Entries the stubbed listing returns (with ``None`` as the
            second tuple element).
        max_files: Forwarded to ``_index_full_scan``.

    Returns:
        Whatever ``_index_full_scan`` returns.
    """
    import app.tasks.connector_indexers.dropbox_indexer as _mod

    monkeypatch.setattr(
        _mod,
        "get_files_in_folder",
        AsyncMock(return_value=(page_files, None)),
    )
    return await _index_full_scan(
        mocks["dropbox_client"],
        mocks["session"],
        _CONNECTOR_ID,
        _SEARCH_SPACE_ID,
        _USER_ID,
        "",
        "Root",
        mocks["task_logger"],
        mocks["log_entry"],
        max_files,
        enable_summary=True,
    )
|
||||
|
||||
|
||||
async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch):
    """Skipped files excluded, renames counted as indexed, new files downloaded."""
    # Configure the skip filter: one unchanged file and one rename.
    skip_table = full_scan_mocks["skip_results"]
    skip_table["skip1"] = (True, "unchanged")
    skip_table["rename1"] = (True, "File renamed: 'old' -> 'renamed.txt'")

    batch_mock = full_scan_mocks["download_and_index_mock"]
    batch_mock.return_value = (2, 0)

    listing = [_folder_dict("SubFolder")]
    listing.extend(
        _make_file_dict(file_id, file_name)
        for file_id, file_name in (
            ("skip1", "unchanged.txt"),
            ("rename1", "renamed.txt"),
            ("new1", "new1.txt"),
            ("new2", "new2.txt"),
        )
    )

    indexed, skipped, _unsupported = await _run_full_scan(
        full_scan_mocks, monkeypatch, listing
    )

    assert indexed == 3  # 1 renamed + 2 from batch
    assert skipped == 2  # 1 folder + 1 unchanged

    # Only the two genuinely new files reach the download/index batch.
    forwarded = batch_mock.call_args.args[2]
    assert len(forwarded) == 2
    assert {entry["id"] for entry in forwarded} == {"new1", "new2"}
|
||||
|
||||
|
||||
async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
    """Only max_files non-folder items are considered."""
    batch_mock = full_scan_mocks["download_and_index_mock"]
    batch_mock.return_value = (3, 0)

    ten_files = []
    for index in range(10):
        ten_files.append(_make_file_dict(f"f{index}", f"file{index}.txt"))

    await _run_full_scan(full_scan_mocks, monkeypatch, ten_files, max_files=3)

    # The batch handed to the download step must be capped at max_files.
    forwarded = batch_mock.call_args.args[2]
    assert len(forwarded) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# D3-D5: _index_selected_files tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def selected_files_mocks(mock_dropbox_client, monkeypatch):
    """Wire up mocks for _index_selected_files tests.

    Patches ``get_file_by_path``, ``_should_skip_file``,
    ``_download_and_index`` and ``PageLimitService`` on the Dropbox indexer
    module, and returns a dict exposing the client, the session and the
    mutable tables/mocks that individual tests configure.
    """
    import app.tasks.connector_indexers.dropbox_indexer as indexer_module
    from app.services.page_limit_service import PageLimitService as RealPageLimitService

    session_stub = AsyncMock()

    # path -> (file dict | None, error | None); tests fill this in.
    file_lookup: dict[str, tuple[dict | None, str | None]] = {}
    # file id -> (skip?, reason); tests fill this in.
    skip_table: dict[str, tuple[bool, str | None]] = {}

    async def _fake_get_file(client, path):
        if path in file_lookup:
            return file_lookup[path]
        return (None, f"Not configured: {path}")

    async def _fake_skip(session, file, search_space_id):
        return skip_table.get(file["id"], (False, None))

    batch_mock = AsyncMock(return_value=(0, 0))

    # Shared usage mocks: every service instance exposes the same two mocks.
    usage_holder = MagicMock(
        get_page_usage=AsyncMock(return_value=(0, 999_999)),
        update_page_usage=AsyncMock(),
    )

    class _StubPageLimitService:
        # Keep the real page estimation; only the usage calls are mocked.
        estimate_pages_from_metadata = staticmethod(
            RealPageLimitService.estimate_pages_from_metadata
        )

        def __init__(self, session):
            self.get_page_usage = usage_holder.get_page_usage
            self.update_page_usage = usage_holder.update_page_usage

    monkeypatch.setattr(indexer_module, "get_file_by_path", _fake_get_file)
    monkeypatch.setattr(indexer_module, "_should_skip_file", _fake_skip)
    monkeypatch.setattr(indexer_module, "_download_and_index", batch_mock)
    monkeypatch.setattr(indexer_module, "PageLimitService", _StubPageLimitService)

    return {
        "dropbox_client": mock_dropbox_client,
        "session": session_stub,
        "get_file_results": file_lookup,
        "skip_results": skip_table,
        "download_and_index_mock": batch_mock,
    }
|
||||
|
||||
|
||||
async def _run_selected(mocks, file_tuples):
    """Call ``_index_selected_files`` with the fixture's client and session.

    *file_tuples* is forwarded unchanged as the third positional argument;
    the connector, search-space and user ids come from the module constants
    and summaries are enabled.
    """
    identifiers = {
        "connector_id": _CONNECTOR_ID,
        "search_space_id": _SEARCH_SPACE_ID,
        "user_id": _USER_ID,
    }
    client = mocks["dropbox_client"]
    session = mocks["session"]
    return await _index_selected_files(
        client, session, file_tuples, **identifiers, enable_summary=True
    )
|
||||
|
||||
|
||||
async def test_selected_files_single_file_indexed(selected_files_mocks):
    """One resolvable file yields one indexed document, no skips, no errors."""
    mocks = selected_files_mocks
    mocks["get_file_results"]["/report.pdf"] = (
        _make_file_dict("f1", "report.pdf"),
        None,
    )
    mocks["download_and_index_mock"].return_value = (1, 0)

    requested = [("/report.pdf", "report.pdf")]
    indexed, skipped, _unsupported, errors = await _run_selected(mocks, requested)

    assert (indexed, skipped) == (1, 0)
    assert errors == []
|
||||
|
||||
|
||||
async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
    """A failed lookup for one path is reported once; the other files proceed."""
    lookups = selected_files_mocks["get_file_results"]
    lookups["/first.txt"] = (_make_file_dict("f1", "first.txt"), None)
    lookups["/mid.txt"] = (None, "HTTP 404")  # the one lookup that fails
    lookups["/third.txt"] = (_make_file_dict("f3", "third.txt"), None)

    selected_files_mocks["download_and_index_mock"].return_value = (2, 0)

    requested = [
        (f"/{file_name}", file_name)
        for file_name in ("first.txt", "mid.txt", "third.txt")
    ]
    indexed, skipped, _unsupported, errors = await _run_selected(
        selected_files_mocks, requested
    )

    assert (indexed, skipped) == (2, 0)
    assert len(errors) == 1
    assert "mid.txt" in errors[0]
|
||||
|
||||
|
||||
async def test_selected_files_skip_rename_counting(selected_files_mocks):
    """Unchanged files count as skipped, renames as indexed, new files are batched."""
    catalogue = (
        ("/unchanged.txt", "s1", "unchanged.txt"),
        ("/renamed.txt", "r1", "renamed.txt"),
        ("/new1.txt", "n1", "new1.txt"),
        ("/new2.txt", "n2", "new2.txt"),
    )
    selected_files_mocks["get_file_results"].update(
        {
            path: (_make_file_dict(file_id, file_name), None)
            for path, file_id, file_name in catalogue
        }
    )

    skip_table = selected_files_mocks["skip_results"]
    skip_table["s1"] = (True, "unchanged")
    skip_table["r1"] = (True, "File renamed: 'old' -> 'renamed.txt'")

    batch_mock = selected_files_mocks["download_and_index_mock"]
    batch_mock.return_value = (2, 0)

    requested = [(path, file_name) for path, _file_id, file_name in catalogue]
    indexed, skipped, _unsupported, errors = await _run_selected(
        selected_files_mocks, requested
    )

    assert indexed == 3  # 1 renamed + 2 batch
    assert skipped == 1
    assert errors == []

    # Only the two new files are handed to the download/index batch.
    forwarded = batch_mock.call_args.args[2]
    assert len(forwarded) == 2
    assert {entry["id"] for entry in forwarded} == {"n1", "n2"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E1-E4: _index_with_delta_sync tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_delta_sync_deletions_call_remove_document(monkeypatch):
    """E1: deleted entries are processed via _remove_document."""
    import app.tasks.connector_indexers.dropbox_indexer as indexer_module

    def _deleted(file_name, entry_id):
        return {
            ".tag": "deleted",
            "name": file_name,
            "path_lower": f"/{file_name}",
            "id": entry_id,
        }

    deleted_entries = [
        _deleted("gone.txt", "id:del1"),
        _deleted("also_gone.pdf", "id:del2"),
    ]

    removed_ids: list[str] = []

    async def _fake_remove(session, file_id, search_space_id):
        removed_ids.append(file_id)

    monkeypatch.setattr(indexer_module, "_remove_document", _fake_remove)
    monkeypatch.setattr(
        indexer_module, "_download_and_index", AsyncMock(return_value=(0, 0))
    )

    dropbox = MagicMock(
        get_changes=AsyncMock(return_value=(deleted_entries, "new-cursor", None))
    )
    task_logger = MagicMock(log_task_progress=AsyncMock())

    _indexed, _skipped, _unsupported, cursor = await _index_with_delta_sync(
        dropbox,
        AsyncMock(),
        _CONNECTOR_ID,
        _SEARCH_SPACE_ID,
        _USER_ID,
        "old-cursor",
        task_logger,
        MagicMock(),
        max_files=500,
        enable_summary=True,
    )

    # Order of removal is not asserted, only the set of ids.
    assert sorted(removed_ids) == ["id:del1", "id:del2"]
    assert cursor == "new-cursor"
|
||||
|
||||
|
||||
async def test_delta_sync_upserts_filtered_and_downloaded(monkeypatch):
    """E2: modified/new file entries go through skip filter then download+index."""
    import app.tasks.connector_indexers.dropbox_indexer as indexer_module

    changed_entries = [
        _make_file_dict(file_id, file_name)
        for file_id, file_name in (("mod1", "modified1.txt"), ("mod2", "modified2.txt"))
    ]

    never_skip = AsyncMock(return_value=(False, None))
    batch_mock = AsyncMock(return_value=(2, 0))
    monkeypatch.setattr(indexer_module, "_should_skip_file", never_skip)
    monkeypatch.setattr(indexer_module, "_download_and_index", batch_mock)

    dropbox = MagicMock(
        get_changes=AsyncMock(return_value=(changed_entries, "cursor-v2", None))
    )
    task_logger = MagicMock(log_task_progress=AsyncMock())

    indexed, skipped, _unsupported, cursor = await _index_with_delta_sync(
        dropbox,
        AsyncMock(),
        _CONNECTOR_ID,
        _SEARCH_SPACE_ID,
        _USER_ID,
        "cursor-v1",
        task_logger,
        MagicMock(),
        max_files=500,
        enable_summary=True,
    )

    assert (indexed, skipped) == (2, 0)
    assert cursor == "cursor-v2"

    # Both changed files must have been forwarded to the download batch.
    forwarded = batch_mock.call_args.args[2]
    assert len(forwarded) == 2
    assert {entry["id"] for entry in forwarded} == {"mod1", "mod2"}
|
||||
|
||||
|
||||
async def test_delta_sync_mix_deletions_and_upserts(monkeypatch):
    """E3: deletions processed, then remaining upserts filtered and indexed."""
    import app.tasks.connector_indexers.dropbox_indexer as indexer_module

    def _deleted(file_name, entry_id):
        return {
            ".tag": "deleted",
            "name": file_name,
            "path_lower": f"/{file_name}",
            "id": entry_id,
        }

    change_entries = [
        _deleted("removed.txt", "id:del1"),
        _deleted("trashed.pdf", "id:del2"),
        _make_file_dict("mod1", "updated.txt"),
        _make_file_dict("new1", "brandnew.docx"),
    ]

    removed_ids: list[str] = []

    async def _fake_remove(session, file_id, search_space_id):
        removed_ids.append(file_id)

    batch_mock = AsyncMock(return_value=(2, 0))
    monkeypatch.setattr(indexer_module, "_remove_document", _fake_remove)
    monkeypatch.setattr(
        indexer_module, "_should_skip_file", AsyncMock(return_value=(False, None))
    )
    monkeypatch.setattr(indexer_module, "_download_and_index", batch_mock)

    dropbox = MagicMock(
        get_changes=AsyncMock(return_value=(change_entries, "final-cursor", None))
    )
    task_logger = MagicMock(log_task_progress=AsyncMock())

    indexed, skipped, _unsupported, cursor = await _index_with_delta_sync(
        dropbox,
        AsyncMock(),
        _CONNECTOR_ID,
        _SEARCH_SPACE_ID,
        _USER_ID,
        "old-cursor",
        task_logger,
        MagicMock(),
        max_files=500,
        enable_summary=True,
    )

    assert sorted(removed_ids) == ["id:del1", "id:del2"]
    assert (indexed, skipped) == (2, 0)
    assert cursor == "final-cursor"

    # Deleted entries must not appear in the download batch.
    forwarded = batch_mock.call_args.args[2]
    assert {entry["id"] for entry in forwarded} == {"mod1", "new1"}
|
||||
|
||||
|
||||
async def test_delta_sync_returns_new_cursor(monkeypatch):
    """E4: the new cursor from the API response is returned."""
    import app.tasks.connector_indexers.dropbox_indexer as indexer_module

    monkeypatch.setattr(
        indexer_module, "_download_and_index", AsyncMock(return_value=(0, 0))
    )

    # No changed entries at all: only the cursor should move.
    dropbox = MagicMock(
        get_changes=AsyncMock(return_value=([], "brand-new-cursor-xyz", None))
    )
    task_logger = MagicMock(log_task_progress=AsyncMock())

    indexed, skipped, _unsupported, cursor = await _index_with_delta_sync(
        dropbox,
        AsyncMock(),
        _CONNECTOR_ID,
        _SEARCH_SPACE_ID,
        _USER_ID,
        "old-cursor",
        task_logger,
        MagicMock(),
        max_files=500,
        enable_summary=True,
    )

    assert cursor == "brand-new-cursor-xyz"
    assert (indexed, skipped) == (0, 0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# F1-F3: index_dropbox_files orchestrator tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def orchestrator_mocks(monkeypatch):
    """Wire up mocks for index_dropbox_files orchestrator tests.

    Replaces the connector lookup, task logging service, last-indexed
    updater, both indexing strategies and the ``DropboxClient`` constructor
    on the indexer module, then returns the mocks that tests inspect.
    """
    import app.tasks.connector_indexers.dropbox_indexer as indexer_module

    connector = MagicMock(
        config={"_token_encrypted": False},
        last_indexed_at=None,
        enable_summary=True,
    )
    task_logger = MagicMock(
        log_task_start=AsyncMock(return_value=MagicMock()),
        log_task_progress=AsyncMock(),
        log_task_success=AsyncMock(),
        log_task_failure=AsyncMock(),
    )
    full_scan = AsyncMock(return_value=(5, 2, 0))
    delta_sync = AsyncMock(return_value=(3, 1, 0, "delta-cursor-new"))
    dropbox = MagicMock(
        get_latest_cursor=AsyncMock(return_value=("latest-cursor-abc", None))
    )

    # Module attribute name -> stand-in installed for the duration of the test.
    stand_ins = {
        "get_connector_by_id": AsyncMock(return_value=connector),
        "TaskLoggingService": MagicMock(return_value=task_logger),
        "update_connector_last_indexed": AsyncMock(),
        "_index_full_scan": full_scan,
        "_index_with_delta_sync": delta_sync,
        "DropboxClient": MagicMock(return_value=dropbox),
    }
    for attribute, replacement in stand_ins.items():
        monkeypatch.setattr(indexer_module, attribute, replacement)

    return {
        "connector": connector,
        "full_scan_mock": full_scan,
        "delta_sync_mock": delta_sync,
        "mock_client": dropbox,
    }
|
||||
|
||||
|
||||
async def test_orchestrator_uses_delta_sync_when_cursor_and_last_indexed(
    orchestrator_mocks,
):
    """F1: with cursor + last_indexed_at + use_delta_sync, calls delta sync."""
    from datetime import UTC, datetime

    # Connector already holds a saved cursor for /docs and was indexed before.
    connector = orchestrator_mocks["connector"]
    connector.last_indexed_at = datetime(2026, 1, 1, tzinfo=UTC)
    connector.config = {
        "_token_encrypted": False,
        "folder_cursors": {"/docs": "saved-cursor-123"},
    }

    session = AsyncMock(commit=AsyncMock())
    selection = {
        "folders": [{"path": "/docs", "name": "Docs"}],
        "files": [],
        "indexing_options": {"use_delta_sync": True},
    }

    _indexed, _skipped, error, _unsupported = await index_dropbox_files(
        session, _CONNECTOR_ID, _SEARCH_SPACE_ID, _USER_ID, selection
    )

    assert error is None
    orchestrator_mocks["delta_sync_mock"].assert_called_once()
    orchestrator_mocks["full_scan_mock"].assert_not_called()
|
||||
|
||||
|
||||
async def test_orchestrator_falls_back_to_full_scan_without_cursor(
    orchestrator_mocks,
):
    """F2: without cursor, falls back to full scan."""
    # No folder_cursors in config and never indexed before.
    connector = orchestrator_mocks["connector"]
    connector.config = {"_token_encrypted": False}
    connector.last_indexed_at = None

    session = AsyncMock(commit=AsyncMock())
    selection = {
        "folders": [{"path": "/docs", "name": "Docs"}],
        "files": [],
        "indexing_options": {"use_delta_sync": True},
    }

    _indexed, _skipped, error, _unsupported = await index_dropbox_files(
        session, _CONNECTOR_ID, _SEARCH_SPACE_ID, _USER_ID, selection
    )

    assert error is None
    orchestrator_mocks["full_scan_mock"].assert_called_once()
    orchestrator_mocks["delta_sync_mock"].assert_not_called()
|
||||
|
||||
|
||||
async def test_orchestrator_persists_cursor_after_sync(orchestrator_mocks):
    """F3: after sync, persists new cursor to connector config."""
    connector = orchestrator_mocks["connector"]
    connector.config = {"_token_encrypted": False}
    connector.last_indexed_at = None

    session = AsyncMock(commit=AsyncMock())
    selection = {
        "folders": [{"path": "/docs", "name": "Docs"}],
        "files": [],
    }

    await index_dropbox_files(
        session, _CONNECTOR_ID, _SEARCH_SPACE_ID, _USER_ID, selection
    )

    # The cursor reported by the mocked client's get_latest_cursor must be
    # stored under the folder's path.
    assert "folder_cursors" in connector.config
    assert connector.config["folder_cursors"]["/docs"] == "latest-cursor-abc"
|
||||
|
|
|
|||
|
|
@ -366,7 +366,7 @@ async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch):
|
|||
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
|
||||
full_scan_mocks["batch_mock"].return_value = ([], 2, 0)
|
||||
|
||||
indexed, skipped = await _run_full_scan(full_scan_mocks)
|
||||
indexed, skipped, _unsupported = await _run_full_scan(full_scan_mocks)
|
||||
|
||||
assert indexed == 3 # 1 renamed + 2 from batch
|
||||
assert skipped == 1 # 1 unchanged
|
||||
|
|
@ -497,7 +497,7 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
|||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
||||
indexed, skipped = await _index_with_delta_sync(
|
||||
indexed, skipped, _unsupported = await _index_with_delta_sync(
|
||||
MagicMock(),
|
||||
mock_session,
|
||||
MagicMock(),
|
||||
|
|
@ -589,7 +589,7 @@ async def test_selected_files_single_file_indexed(selected_files_mocks):
|
|||
)
|
||||
selected_files_mocks["download_and_index_mock"].return_value = (1, 0)
|
||||
|
||||
indexed, skipped, errors = await _run_selected(
|
||||
indexed, skipped, _unsup, errors = await _run_selected(
|
||||
selected_files_mocks,
|
||||
[("f1", "report.pdf")],
|
||||
)
|
||||
|
|
@ -613,7 +613,7 @@ async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
|
|||
)
|
||||
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
indexed, skipped, errors = await _run_selected(
|
||||
indexed, skipped, _unsup, errors = await _run_selected(
|
||||
selected_files_mocks,
|
||||
[("f1", "first.txt"), ("f2", "mid.txt"), ("f3", "third.txt")],
|
||||
)
|
||||
|
|
@ -647,7 +647,7 @@ async def test_selected_files_skip_rename_counting(selected_files_mocks):
|
|||
|
||||
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
indexed, skipped, errors = await _run_selected(
|
||||
indexed, skipped, _unsup, errors = await _run_selected(
|
||||
selected_files_mocks,
|
||||
[
|
||||
("s1", "unchanged.txt"),
|
||||
|
|
|
|||
|
|
@ -198,7 +198,7 @@ async def test_gdrive_files_within_quota_are_downloaded(gdrive_selected_mocks):
|
|||
)
|
||||
m["download_and_index_mock"].return_value = (3, 0)
|
||||
|
||||
indexed, _skipped, errors = await _run_gdrive_selected(
|
||||
indexed, _skipped, _unsup, errors = await _run_gdrive_selected(
|
||||
m, [("f1", "f1.xyz"), ("f2", "f2.xyz"), ("f3", "f3.xyz")]
|
||||
)
|
||||
|
||||
|
|
@ -219,7 +219,9 @@ async def test_gdrive_files_exceeding_quota_rejected(gdrive_selected_mocks):
|
|||
None,
|
||||
)
|
||||
|
||||
indexed, _skipped, errors = await _run_gdrive_selected(m, [("big", "huge.pdf")])
|
||||
indexed, _skipped, _unsup, errors = await _run_gdrive_selected(
|
||||
m, [("big", "huge.pdf")]
|
||||
)
|
||||
|
||||
assert indexed == 0
|
||||
assert len(errors) == 1
|
||||
|
|
@ -239,7 +241,7 @@ async def test_gdrive_quota_mix_partial_indexing(gdrive_selected_mocks):
|
|||
)
|
||||
m["download_and_index_mock"].return_value = (2, 0)
|
||||
|
||||
indexed, _skipped, errors = await _run_gdrive_selected(
|
||||
indexed, _skipped, _unsup, errors = await _run_gdrive_selected(
|
||||
m, [("f1", "f1.xyz"), ("f2", "f2.xyz"), ("f3", "f3.xyz")]
|
||||
)
|
||||
|
||||
|
|
@ -299,7 +301,7 @@ async def test_gdrive_zero_quota_rejects_all(gdrive_selected_mocks):
|
|||
None,
|
||||
)
|
||||
|
||||
indexed, _skipped, errors = await _run_gdrive_selected(
|
||||
indexed, _skipped, _unsup, errors = await _run_gdrive_selected(
|
||||
m, [("f1", "f1.xyz"), ("f2", "f2.xyz")]
|
||||
)
|
||||
|
||||
|
|
@ -384,7 +386,7 @@ async def test_gdrive_full_scan_skips_over_quota(gdrive_full_scan_mocks, monkeyp
|
|||
m["download_mock"].return_value = ([], 0)
|
||||
m["batch_mock"].return_value = ([], 2, 0)
|
||||
|
||||
_indexed, skipped = await _run_gdrive_full_scan(m)
|
||||
_indexed, skipped, _unsup = await _run_gdrive_full_scan(m)
|
||||
|
||||
call_files = m["download_mock"].call_args[0][1]
|
||||
assert len(call_files) == 2
|
||||
|
|
@ -459,7 +461,7 @@ async def test_gdrive_delta_sync_skips_over_quota(monkeypatch):
|
|||
mock_task_logger = MagicMock()
|
||||
mock_task_logger.log_task_progress = AsyncMock()
|
||||
|
||||
_indexed, skipped = await _mod._index_with_delta_sync(
|
||||
_indexed, skipped, _unsupported = await _mod._index_with_delta_sync(
|
||||
MagicMock(),
|
||||
session,
|
||||
MagicMock(),
|
||||
|
|
@ -552,7 +554,9 @@ async def test_onedrive_over_quota_rejected(onedrive_selected_mocks):
|
|||
None,
|
||||
)
|
||||
|
||||
indexed, _skipped, errors = await _run_onedrive_selected(m, [("big", "huge.pdf")])
|
||||
indexed, _skipped, _unsup, errors = await _run_onedrive_selected(
|
||||
m, [("big", "huge.pdf")]
|
||||
)
|
||||
|
||||
assert indexed == 0
|
||||
assert len(errors) == 1
|
||||
|
|
@ -652,7 +656,7 @@ async def test_dropbox_over_quota_rejected(dropbox_selected_mocks):
|
|||
None,
|
||||
)
|
||||
|
||||
indexed, _skipped, errors = await _run_dropbox_selected(
|
||||
indexed, _skipped, _unsup, errors = await _run_dropbox_selected(
|
||||
m, [("/huge.pdf", "huge.pdf")]
|
||||
)
|
||||
|
||||
|
|
|
|||
0
surfsense_backend/tests/unit/connectors/__init__.py
Normal file
0
surfsense_backend/tests/unit/connectors/__init__.py
Normal file
123
surfsense_backend/tests/unit/connectors/test_dropbox_client.py
Normal file
123
surfsense_backend/tests/unit/connectors/test_dropbox_client.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
"""Tests for DropboxClient delta-sync methods (get_latest_cursor, get_changes)."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.connectors.dropbox.client import DropboxClient
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _make_client() -> DropboxClient:
    """Build a ``DropboxClient`` without running ``__init__``.

    The instance is allocated via ``__new__`` and given a ``MagicMock``
    session and connector id ``1``, so no real database is required.
    """
    stub_client = DropboxClient.__new__(DropboxClient)
    stub_client._session = MagicMock()
    stub_client._connector_id = 1
    return stub_client
|
||||
|
||||
|
||||
# ---------- C1: get_latest_cursor ----------
|
||||
|
||||
|
||||
async def test_get_latest_cursor_returns_cursor_string(monkeypatch):
    """C1: a 200 response yields the cursor string and the expected request."""
    client = _make_client()

    response = MagicMock(status_code=200)
    response.json.return_value = {"cursor": "AAHbKxRZ9enq…"}
    request_mock = AsyncMock(return_value=response)
    monkeypatch.setattr(client, "_request", request_mock)

    cursor, error = await client.get_latest_cursor("/my-folder")

    assert (cursor, error) == ("AAHbKxRZ9enq…", None)

    expected_payload = {
        "path": "/my-folder",
        "recursive": False,
        "include_non_downloadable_files": True,
    }
    client._request.assert_called_once_with(
        "/2/files/list_folder/get_latest_cursor", expected_payload
    )
|
||||
|
||||
|
||||
# ---------- C2: get_changes returns entries and new cursor ----------
|
||||
|
||||
|
||||
async def test_get_changes_returns_entries_and_cursor(monkeypatch):
    """C2: a single-page response returns its entries and the new cursor."""
    client = _make_client()

    payload = {
        "entries": [
            {".tag": "file", "name": "new.txt", "id": "id:abc"},
            {".tag": "deleted", "name": "old.txt"},
        ],
        "cursor": "cursor-v2",
        "has_more": False,
    }
    response = MagicMock(status_code=200)
    response.json.return_value = payload
    monkeypatch.setattr(client, "_request", AsyncMock(return_value=response))

    entries, new_cursor, error = await client.get_changes("cursor-v1")

    assert error is None
    assert new_cursor == "cursor-v2"
    assert len(entries) == 2
    first, second = entries[0], entries[1]
    assert first["name"] == "new.txt"
    assert second[".tag"] == "deleted"
|
||||
|
||||
|
||||
# ---------- C3: get_changes handles pagination ----------
|
||||
|
||||
|
||||
async def test_get_changes_handles_pagination(monkeypatch):
    """C3: entries from every page are collected and the last cursor returned."""
    client = _make_client()

    def _page(file_name, entry_id, cursor, has_more):
        # One 200 response carrying a single file entry.
        response = MagicMock(status_code=200)
        response.json.return_value = {
            "entries": [{".tag": "file", "name": file_name, "id": entry_id}],
            "cursor": cursor,
            "has_more": has_more,
        }
        return response

    pages = [
        _page("a.txt", "id:a", "cursor-page2", True),
        _page("b.txt", "id:b", "cursor-final", False),
    ]
    request_mock = AsyncMock(side_effect=pages)
    monkeypatch.setattr(client, "_request", request_mock)

    entries, new_cursor, error = await client.get_changes("cursor-v1")

    assert error is None
    assert new_cursor == "cursor-final"
    assert len(entries) == 2
    assert {entry["name"] for entry in entries} == {"a.txt", "b.txt"}
    # One request per page.
    assert request_mock.call_count == 2
|
||||
|
||||
|
||||
# ---------- C4: get_changes returns an error on 401 ----------
|
||||
|
||||
|
||||
async def test_get_changes_returns_error_on_401(monkeypatch):
    """C4: a 401 response yields an error string, no entries and no cursor."""
    client = _make_client()

    unauthorized = MagicMock(status_code=401, text="Unauthorized")
    monkeypatch.setattr(client, "_request", AsyncMock(return_value=unauthorized))

    entries, new_cursor, error = await client.get_changes("old-cursor")

    assert error is not None
    assert "401" in error
    assert entries == []
    assert new_cursor is None
|
||||
|
|
@ -0,0 +1,173 @@
|
|||
"""Tests for Dropbox file type filtering (should_skip_file)."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.connectors.dropbox.file_types import should_skip_file
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Structural skips (independent of ETL service)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_folder_item_is_skipped():
    """A folder entry is skipped and reports no unsupported extension."""
    folder_entry = {".tag": "folder", "name": "My Folder"}
    skipped, unsupported_ext = should_skip_file(folder_entry)
    assert skipped is True
    assert unsupported_ext is None
|
||||
|
||||
|
||||
def test_paper_file_is_not_skipped():
    """A ``.paper`` file is kept even though it is flagged non-downloadable."""
    paper_entry = {".tag": "file", "name": "notes.paper", "is_downloadable": False}
    skipped, unsupported_ext = should_skip_file(paper_entry)
    assert skipped is False
    assert unsupported_ext is None
|
||||
|
||||
|
||||
def test_non_downloadable_item_is_skipped():
    """A non-downloadable file is skipped without an unsupported extension."""
    locked_entry = {".tag": "file", "name": "locked.gdoc", "is_downloadable": False}
    skipped, unsupported_ext = should_skip_file(locked_entry)
    assert skipped is True
    assert unsupported_ext is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extension-based skips (require ETL service context)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename",
    [
        "archive.zip", "backup.tar", "data.gz", "stuff.rar", "pack.7z",
        "program.exe", "lib.dll", "module.so", "image.dmg", "disk.iso",
        "movie.mov", "clip.avi", "video.mkv", "film.wmv", "stream.flv",
        "favicon.ico", "raw.cr2", "photo.nef", "image.arw", "pic.dng",
        "design.psd", "vector.ai", "mockup.sketch", "proto.fig",
        "font.ttf", "font.otf", "font.woff", "font.woff2",
        "model.stl", "scene.fbx", "mesh.blend",
        "local.db", "data.sqlite", "access.mdb",
    ],
)
def test_non_parseable_extensions_are_skipped(filename, mocker):
    """With ETL_SERVICE set to DOCLING, each listed file is skipped with an extension."""
    mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
    skipped, unsupported_ext = should_skip_file({".tag": "file", "name": filename})
    assert skipped is True, f"{filename} should be skipped"
    assert unsupported_ext is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename",
    [
        "report.pdf", "document.docx", "sheet.xlsx", "slides.pptx", "readme.txt",
        "data.csv", "page.html", "notes.md", "config.json", "feed.xml",
    ],
)
def test_parseable_documents_are_not_skipped(filename, mocker):
    """Files in plaintext/direct_convert/universal document sets are never skipped."""
    etl_services = ("DOCLING", "LLAMACLOUD", "UNSTRUCTURED")
    for service in etl_services:
        # Re-patch the configured ETL service for each iteration.
        mocker.patch("app.config.config.ETL_SERVICE", service)
        entry = {".tag": "file", "name": filename}
        skipped, unsupported_ext = should_skip_file(entry)
        assert skipped is False, f"{filename} should NOT be skipped with {service}"
        assert unsupported_ext is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename",
    ["photo.jpg", "image.jpeg", "screenshot.png", "scan.bmp", "page.tiff", "doc.tif"],
)
def test_universal_images_are_not_skipped(filename, mocker):
    """Images supported by all parsers are never skipped."""
    etl_services = ("DOCLING", "LLAMACLOUD", "UNSTRUCTURED")
    for service in etl_services:
        # Re-patch the configured ETL service for each iteration.
        mocker.patch("app.config.config.ETL_SERVICE", service)
        entry = {".tag": "file", "name": filename}
        skipped, unsupported_ext = should_skip_file(entry)
        assert skipped is False, f"{filename} should NOT be skipped with {service}"
        assert unsupported_ext is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename,service,expected_skip",
    [
        ("old.doc", "DOCLING", True),
        ("old.doc", "LLAMACLOUD", False),
        ("old.doc", "UNSTRUCTURED", False),
        ("legacy.xls", "DOCLING", True),
        ("legacy.xls", "LLAMACLOUD", False),
        ("legacy.xls", "UNSTRUCTURED", False),
        ("deck.ppt", "DOCLING", True),
        ("deck.ppt", "LLAMACLOUD", False),
        ("deck.ppt", "UNSTRUCTURED", False),
        ("icon.svg", "DOCLING", True),
        ("icon.svg", "LLAMACLOUD", False),
        ("anim.gif", "DOCLING", True),
        ("anim.gif", "LLAMACLOUD", False),
        ("photo.webp", "DOCLING", False),
        ("photo.webp", "LLAMACLOUD", False),
        ("photo.webp", "UNSTRUCTURED", True),
        ("live.heic", "DOCLING", True),
        ("live.heic", "UNSTRUCTURED", False),
        ("macro.docm", "DOCLING", True),
        ("macro.docm", "LLAMACLOUD", False),
        ("mail.eml", "DOCLING", True),
        ("mail.eml", "UNSTRUCTURED", False),
    ],
)
def test_parser_specific_extensions(filename, service, expected_skip, mocker):
    """Skip decisions for extensions whose support depends on the ETL service."""
    mocker.patch("app.config.config.ETL_SERVICE", service)
    skipped, unsupported_ext = should_skip_file({".tag": "file", "name": filename})

    assert skipped is expected_skip, (
        f"{filename} with {service}: expected skip={expected_skip}"
    )
    # An extension is reported exactly when the file is expected to be skipped.
    assert (unsupported_ext is not None) == expected_skip
|
||||
|
||||
|
||||
def test_returns_unsupported_extension(mocker):
    """With ETL_SERVICE patched to DOCLING, ``old.doc`` is skipped and ``.doc`` is returned."""
    mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")

    skipped, unsupported_ext = should_skip_file({".tag": "file", "name": "old.doc"})

    assert skipped is True
    assert unsupported_ext == ".doc"
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
"""Test that Dropbox re-auth preserves folder_cursors in connector config."""
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_reauth_preserves_folder_cursors():
    """G1: merging fresh credentials keeps ``cursor`` and ``folder_cursors`` from the old config."""
    # NOTE(review): the merge below is built inline in this test; no production
    # re-auth code is called here -- confirm the real handler does the same merge.
    folder_cursors = {"/docs": "cursor-docs-123", "/photos": "cursor-photos-456"}
    previous = dict(
        access_token="old-token-enc",
        refresh_token="old-refresh-enc",
        cursor="old-cursor-abc",
        folder_cursors=folder_cursors,
        _token_encrypted=True,
        auth_expired=True,
    )
    refreshed = dict(
        access_token="new-token-enc",
        refresh_token="new-refresh-enc",
        token_type="bearer",
        expires_in=14400,
        expires_at="2026-04-06T16:00:00+00:00",
        _token_encrypted=True,
    )

    # Fresh credentials first, then the sync state carried over from the old config.
    merged = dict(
        refreshed,
        cursor=previous.get("cursor"),
        folder_cursors=previous.get("folder_cursors"),
        auth_expired=False,
    )

    assert merged["access_token"] == "new-token-enc"
    assert merged["cursor"] == "old-cursor-abc"
    assert merged["folder_cursors"] == {
        "/docs": "cursor-docs-123",
        "/photos": "cursor-photos-456",
    }
    assert merged["auth_expired"] is False
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
"""Tests for Google Drive file type filtering."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.connectors.google_drive.file_types import should_skip_by_extension
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename",
    ["malware.exe", "archive.zip", "video.mov", "font.woff2", "model.blend"],
)
def test_unsupported_extensions_are_skipped_regardless_of_service(filename, mocker):
    """Each listed filename is reported as skipped under all three ETL_SERVICE values."""
    for etl_service in ("DOCLING", "LLAMACLOUD", "UNSTRUCTURED"):
        mocker.patch("app.config.config.ETL_SERVICE", etl_service)
        skipped, _unused_ext = should_skip_by_extension(filename)
        assert skipped is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename",
    [
        "report.pdf", "doc.docx", "sheet.xlsx", "slides.pptx",
        "readme.txt", "data.csv", "photo.png", "notes.md",
    ],
)
def test_universal_extensions_are_not_skipped(filename, mocker):
    """Each listed filename is kept, with no extension returned, under every ETL service."""
    for etl_service in ("DOCLING", "LLAMACLOUD", "UNSTRUCTURED"):
        mocker.patch("app.config.config.ETL_SERVICE", etl_service)
        skipped, unsupported_ext = should_skip_by_extension(filename)
        assert skipped is False, f"{filename} should NOT be skipped with {etl_service}"
        assert unsupported_ext is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename,service,expected_skip",
    [
        ("macro.docm", "DOCLING", True),
        ("macro.docm", "LLAMACLOUD", False),
        ("mail.eml", "DOCLING", True),
        ("mail.eml", "UNSTRUCTURED", False),
        ("photo.gif", "DOCLING", True),
        ("photo.gif", "LLAMACLOUD", False),
        ("photo.heic", "UNSTRUCTURED", False),
        ("photo.heic", "DOCLING", True),
    ],
)
def test_parser_specific_extensions(filename, service, expected_skip, mocker):
    """The skip decision follows the patched ETL_SERVICE; skipped files report an extension."""
    mocker.patch("app.config.config.ETL_SERVICE", service)

    skipped, unsupported_ext = should_skip_by_extension(filename)

    assert skipped is expected_skip, (
        f"{filename} with {service}: expected skip={expected_skip}"
    )
    if not expected_skip:
        assert unsupported_ext is None
    else:
        assert unsupported_ext is not None, "unsupported extension should be returned"
|
||||
|
||||
|
||||
def test_returns_unsupported_extension(mocker):
    """With ETL_SERVICE patched to DOCLING, ``macro.docm`` is skipped and ``.docm`` is returned."""
    mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")

    skipped, unsupported_ext = should_skip_by_extension("macro.docm")

    assert skipped is True
    assert unsupported_ext == ".docm"
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
"""Tests for OneDrive file type filtering."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.connectors.onedrive.file_types import should_skip_file
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Structural skips (independent of ETL service)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_folder_is_skipped():
    """An item dict with a ``folder`` key is skipped and no extension is reported."""
    skipped, unsupported_ext = should_skip_file({"folder": {}, "name": "My Folder"})

    assert skipped is True
    assert unsupported_ext is None
|
||||
|
||||
|
||||
def test_remote_item_is_skipped():
    """An item dict with a ``remoteItem`` key is skipped and no extension is reported."""
    skipped, unsupported_ext = should_skip_file({"remoteItem": {}, "name": "shared.docx"})

    assert skipped is True
    assert unsupported_ext is None
|
||||
|
||||
|
||||
def test_package_is_skipped():
    """An item dict with a ``package`` key is skipped and no extension is reported."""
    skipped, unsupported_ext = should_skip_file({"package": {}, "name": "notebook"})

    assert skipped is True
    assert unsupported_ext is None
|
||||
|
||||
|
||||
def test_onenote_is_skipped():
    """A file whose mimeType is ``application/msonenote`` is skipped without an extension."""
    onenote_item = {"name": "notes", "file": {"mimeType": "application/msonenote"}}

    skipped, unsupported_ext = should_skip_file(onenote_item)

    assert skipped is True
    assert unsupported_ext is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extension-based skips (require ETL service context)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename",
    ["malware.exe", "archive.zip", "video.mov", "font.woff2", "model.blend"],
)
def test_unsupported_extensions_are_skipped(filename, mocker):
    """With ETL_SERVICE patched to DOCLING, each listed file is skipped and an extension is returned."""
    mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")

    skipped, unsupported_ext = should_skip_file(
        {"name": filename, "file": {"mimeType": "application/octet-stream"}}
    )

    assert skipped is True, f"{filename} should be skipped"
    assert unsupported_ext is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename",
    [
        "report.pdf", "doc.docx", "sheet.xlsx", "slides.pptx",
        "readme.txt", "data.csv", "photo.png", "notes.md",
    ],
)
def test_universal_files_are_not_skipped(filename, mocker):
    """Each listed filename is kept, with no extension returned, under every ETL service."""
    for etl_service in ("DOCLING", "LLAMACLOUD", "UNSTRUCTURED"):
        mocker.patch("app.config.config.ETL_SERVICE", etl_service)
        skipped, unsupported_ext = should_skip_file(
            {"name": filename, "file": {"mimeType": "application/octet-stream"}}
        )
        assert skipped is False, f"{filename} should NOT be skipped with {etl_service}"
        assert unsupported_ext is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename,service,expected_skip",
    [
        ("macro.docm", "DOCLING", True),
        ("macro.docm", "LLAMACLOUD", False),
        ("mail.eml", "DOCLING", True),
        ("mail.eml", "UNSTRUCTURED", False),
        ("photo.heic", "UNSTRUCTURED", False),
        ("photo.heic", "DOCLING", True),
    ],
)
def test_parser_specific_extensions(filename, service, expected_skip, mocker):
    """The skip decision follows the patched ETL_SERVICE; skipped files report an extension."""
    mocker.patch("app.config.config.ETL_SERVICE", service)

    skipped, unsupported_ext = should_skip_file(
        {"name": filename, "file": {"mimeType": "application/octet-stream"}}
    )

    assert skipped is expected_skip, (
        f"{filename} with {service}: expected skip={expected_skip}"
    )
    ext_reported = unsupported_ext is not None
    if expected_skip:
        assert ext_reported
    else:
        assert not ext_reported
|
||||
|
||||
|
||||
def test_returns_unsupported_extension(mocker):
    """With ETL_SERVICE patched to DOCLING, ``mail.eml`` is skipped and ``.eml`` is returned."""
    mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")

    skipped, unsupported_ext = should_skip_file(
        {"name": "mail.eml", "file": {"mimeType": "application/octet-stream"}}
    )

    assert skipped is True
    assert unsupported_ext == ".eml"
|
||||
27
surfsense_backend/tests/unit/etl_pipeline/conftest.py
Normal file
27
surfsense_backend/tests/unit/etl_pipeline/conftest.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""Pre-register the etl_pipeline package to avoid circular imports during unit tests."""
|
||||
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
_BACKEND = Path(__file__).resolve().parents[3]
|
||||
|
||||
|
||||
def _stub_package(dotted: str, fs_dir: Path) -> None:
    """Register an empty stand-in package for *dotted* unless one is already loaded.

    A new stub gets ``__path__`` pointing at *fs_dir* and ``__package__`` set to
    *dotted*. Whether or not a stub was created, the module registered under
    *dotted* is also bound as an attribute on its parent package when that
    parent is present in ``sys.modules``.
    """
    if dotted not in sys.modules:
        stub = types.ModuleType(dotted)
        stub.__path__ = [str(fs_dir)]
        stub.__package__ = dotted
        sys.modules[dotted] = stub

    # Top-level names have no parent to attach to.
    if "." not in dotted:
        return

    parent_name, _, leaf = dotted.rpartition(".")
    parent_module = sys.modules.get(parent_name)
    if parent_module is not None:
        setattr(parent_module, leaf, sys.modules[dotted])
|
||||
|
||||
|
||||
# Register stubs parent-first so each child can be attached to its parent package.
for _dotted, _fs_dir in (
    ("app", _BACKEND / "app"),
    ("app.etl_pipeline", _BACKEND / "app" / "etl_pipeline"),
    ("app.etl_pipeline.parsers", _BACKEND / "app" / "etl_pipeline" / "parsers"),
):
    _stub_package(_dotted, _fs_dir)
|
||||
|
|
@ -0,0 +1,461 @@
|
|||
"""Tests for EtlPipelineService -- the unified ETL pipeline public interface."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.etl_pipeline.etl_document import EtlRequest
|
||||
from app.etl_pipeline.etl_pipeline_service import EtlPipelineService
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
async def test_extract_txt_file_returns_markdown(tmp_path):
    """Tracer bullet: the text of a .txt file comes back unchanged, tagged PLAINTEXT."""
    source = tmp_path / "hello.txt"
    body = "Hello, world!"
    source.write_text(body, encoding="utf-8")

    req = EtlRequest(file_path=str(source), filename="hello.txt")
    outcome = await EtlPipelineService().extract(req)

    assert outcome.markdown_content == body
    assert outcome.etl_service == "PLAINTEXT"
    assert outcome.content_type == "plaintext"
|
||||
|
||||
|
||||
async def test_extract_md_file(tmp_path):
    """The text of a .md file comes back unchanged, tagged PLAINTEXT."""
    source = tmp_path / "readme.md"
    body = "# Title\n\nBody text."
    source.write_text(body, encoding="utf-8")

    req = EtlRequest(file_path=str(source), filename="readme.md")
    outcome = await EtlPipelineService().extract(req)

    assert outcome.markdown_content == body
    assert outcome.etl_service == "PLAINTEXT"
    assert outcome.content_type == "plaintext"
|
||||
|
||||
|
||||
async def test_extract_markdown_file(tmp_path):
    """The text of a .markdown file comes back unchanged, tagged PLAINTEXT."""
    source = tmp_path / "notes.markdown"
    body = "Some notes."
    source.write_text(body, encoding="utf-8")

    req = EtlRequest(file_path=str(source), filename="notes.markdown")
    outcome = await EtlPipelineService().extract(req)

    assert outcome.markdown_content == body
    assert outcome.etl_service == "PLAINTEXT"
|
||||
|
||||
|
||||
async def test_extract_python_file(tmp_path):
    """The text of a .py source file comes back unchanged, tagged PLAINTEXT."""
    source = tmp_path / "script.py"
    body = "print('hello')"
    source.write_text(body, encoding="utf-8")

    req = EtlRequest(file_path=str(source), filename="script.py")
    outcome = await EtlPipelineService().extract(req)

    assert outcome.markdown_content == body
    assert outcome.etl_service == "PLAINTEXT"
    assert outcome.content_type == "plaintext"
|
||||
|
||||
|
||||
async def test_extract_js_file(tmp_path):
    """The text of a .js source file comes back unchanged, tagged PLAINTEXT."""
    source = tmp_path / "app.js"
    body = "console.log('hi');"
    source.write_text(body, encoding="utf-8")

    req = EtlRequest(file_path=str(source), filename="app.js")
    outcome = await EtlPipelineService().extract(req)

    assert outcome.markdown_content == body
    assert outcome.etl_service == "PLAINTEXT"
|
||||
|
||||
|
||||
async def test_extract_csv_returns_markdown_table(tmp_path):
    """A .csv file yields markdown table rows and is tagged DIRECT_CONVERT."""
    source = tmp_path / "data.csv"
    source.write_text("name,age\nAlice,30\nBob,25\n", encoding="utf-8")

    req = EtlRequest(file_path=str(source), filename="data.csv")
    outcome = await EtlPipelineService().extract(req)

    rendered = outcome.markdown_content
    assert "| name | age |" in rendered
    assert "| Alice | 30 |" in rendered
    assert outcome.etl_service == "DIRECT_CONVERT"
    assert outcome.content_type == "direct_convert"
|
||||
|
||||
|
||||
async def test_extract_tsv_returns_markdown_table(tmp_path):
    """A .tsv file yields a markdown table header row and is tagged DIRECT_CONVERT."""
    source = tmp_path / "data.tsv"
    source.write_text("x\ty\n1\t2\n", encoding="utf-8")

    req = EtlRequest(file_path=str(source), filename="data.tsv")
    outcome = await EtlPipelineService().extract(req)

    assert "| x | y |" in outcome.markdown_content
    assert outcome.etl_service == "DIRECT_CONVERT"
|
||||
|
||||
|
||||
async def test_extract_html_returns_markdown(tmp_path):
    """The heading and paragraph text of an .html file survive conversion (DIRECT_CONVERT)."""
    source = tmp_path / "page.html"
    source.write_text("<h1>Title</h1><p>Body</p>", encoding="utf-8")

    req = EtlRequest(file_path=str(source), filename="page.html")
    outcome = await EtlPipelineService().extract(req)

    rendered = outcome.markdown_content
    assert "Title" in rendered
    assert "Body" in rendered
    assert outcome.etl_service == "DIRECT_CONVERT"
|
||||
|
||||
|
||||
async def test_extract_mp3_returns_transcription(tmp_path, mocker):
    """An .mp3 file yields the text returned by the patched ``atranscription`` call."""
    recording = tmp_path / "recording.mp3"
    recording.write_bytes(b"\x00" * 100)

    # Speech-to-text settings the audio path reads from config.
    mocker.patch("app.config.config.STT_SERVICE", "openai/whisper-1")
    mocker.patch("app.config.config.STT_SERVICE_API_KEY", "fake-key")
    mocker.patch("app.config.config.STT_SERVICE_API_BASE", None)

    transcribe_stub = mocker.patch(
        "app.etl_pipeline.parsers.audio.atranscription",
        return_value={"text": "Hello from audio"},
    )

    req = EtlRequest(file_path=str(recording), filename="recording.mp3")
    outcome = await EtlPipelineService().extract(req)

    assert "Hello from audio" in outcome.markdown_content
    assert outcome.etl_service == "AUDIO"
    assert outcome.content_type == "audio"
    transcribe_stub.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 7 - DOCLING document parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_extract_pdf_with_docling(tmp_path, mocker):
    """With ETL_SERVICE=DOCLING, a .pdf returns the content produced by the stubbed Docling service."""
    source = tmp_path / "report.pdf"
    source.write_bytes(b"%PDF-1.4 fake")

    mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")

    docling_stub = mocker.AsyncMock()
    docling_stub.process_document.return_value = {"content": "# Parsed PDF"}
    mocker.patch(
        "app.services.docling_service.create_docling_service",
        return_value=docling_stub,
    )

    req = EtlRequest(file_path=str(source), filename="report.pdf")
    outcome = await EtlPipelineService().extract(req)

    assert outcome.markdown_content == "# Parsed PDF"
    assert outcome.etl_service == "DOCLING"
    assert outcome.content_type == "document"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 8 - UNSTRUCTURED document parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_extract_pdf_with_unstructured(tmp_path, mocker):
    """With ETL_SERVICE=UNSTRUCTURED, a .pdf returns the text of every document the loader yields."""
    source = tmp_path / "report.pdf"
    source.write_bytes(b"%PDF-1.4 fake")

    mocker.patch("app.config.config.ETL_SERVICE", "UNSTRUCTURED")

    class _LoadedPage:
        """Stand-in for a loaded document exposing only ``page_content``."""

        def __init__(self, text):
            self.page_content = text

    loader_stub = mocker.AsyncMock()
    loader_stub.aload.return_value = [
        _LoadedPage(text) for text in ("Page 1 content", "Page 2 content")
    ]
    mocker.patch(
        "langchain_unstructured.UnstructuredLoader",
        return_value=loader_stub,
    )

    req = EtlRequest(file_path=str(source), filename="report.pdf")
    outcome = await EtlPipelineService().extract(req)

    rendered = outcome.markdown_content
    assert "Page 1 content" in rendered
    assert "Page 2 content" in rendered
    assert outcome.etl_service == "UNSTRUCTURED"
    assert outcome.content_type == "document"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 9 - LLAMACLOUD document parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_extract_pdf_with_llamacloud(tmp_path, mocker):
    """With ETL_SERVICE=LLAMACLOUD, a .pdf returns the markdown from the stubbed LlamaParse job."""
    source = tmp_path / "report.pdf"
    source.write_bytes(b"%PDF-1.4 fake content " * 10)

    mocker.patch("app.config.config.ETL_SERVICE", "LLAMACLOUD")
    mocker.patch("app.config.config.LLAMA_CLOUD_API_KEY", "fake-key", create=True)

    class _MarkdownDoc:
        """Stand-in for one markdown document of a parse result."""

        text = "# LlamaCloud parsed"

    class _JobResult:
        """Stand-in parse result: no pages, one markdown document."""

        pages = []

        def get_markdown_documents(self, split_by_page=True):
            return [_MarkdownDoc()]

    parser_stub = mocker.AsyncMock()
    parser_stub.aparse.return_value = _JobResult()
    mocker.patch("llama_cloud_services.LlamaParse", return_value=parser_stub)
    mocker.patch(
        "llama_cloud_services.parse.utils.ResultType",
        mocker.MagicMock(MD="md"),
    )

    req = EtlRequest(
        file_path=str(source),
        filename="report.pdf",
        estimated_pages=5,
    )
    outcome = await EtlPipelineService().extract(req)

    assert outcome.markdown_content == "# LlamaCloud parsed"
    assert outcome.etl_service == "LLAMACLOUD"
    assert outcome.content_type == "document"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 10 - allowlisted document extension (.docx) routes to document ETL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_unknown_extension_uses_document_etl(tmp_path, mocker):
    """A .docx file is handled by the document ETL path (stubbed Docling here)."""
    source = tmp_path / "doc.docx"
    source.write_bytes(b"PK fake docx")

    mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")

    docling_stub = mocker.AsyncMock()
    docling_stub.process_document.return_value = {"content": "Docx content"}
    mocker.patch(
        "app.services.docling_service.create_docling_service",
        return_value=docling_stub,
    )

    req = EtlRequest(file_path=str(source), filename="doc.docx")
    outcome = await EtlPipelineService().extract(req)

    assert outcome.markdown_content == "Docx content"
    assert outcome.content_type == "document"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 11 - EtlRequest validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_etl_request_requires_filename():
    """Constructing an EtlRequest with an empty filename raises ValueError."""
    empty_name_kwargs = {"file_path": "/tmp/some.txt", "filename": ""}

    with pytest.raises(ValueError, match="filename must not be empty"):
        EtlRequest(**empty_name_kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 12 - unknown ETL_SERVICE raises EtlServiceUnavailableError
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_unknown_etl_service_raises(tmp_path, mocker):
    """Extracting a .pdf with ETL_SERVICE patched to an unrecognised value raises EtlServiceUnavailableError."""
    from app.etl_pipeline.exceptions import EtlServiceUnavailableError

    source = tmp_path / "report.pdf"
    source.write_bytes(b"%PDF fake")
    mocker.patch("app.config.config.ETL_SERVICE", "NONEXISTENT")

    req = EtlRequest(file_path=str(source), filename="report.pdf")
    with pytest.raises(EtlServiceUnavailableError, match="Unknown ETL_SERVICE"):
        await EtlPipelineService().extract(req)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 13 - unsupported file types are rejected before reaching any parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_unknown_extension_classified_as_unsupported():
    """``random.xyz`` is classified as UNSUPPORTED (allowlist behaviour)."""
    from app.etl_pipeline.file_classifier import FileCategory, classify_file

    category = classify_file("random.xyz")

    assert category == FileCategory.UNSUPPORTED
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename",
    [
        "malware.exe", "archive.zip", "video.mov", "font.woff2",
        "model.blend", "data.parquet", "package.deb", "firmware.bin",
    ],
)
def test_unsupported_extensions_classified_correctly(filename):
    """Each listed filename is classified as UNSUPPORTED."""
    from app.etl_pipeline.file_classifier import FileCategory, classify_file

    category = classify_file(filename)

    assert category == FileCategory.UNSUPPORTED
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename,expected",
    [
        ("report.pdf", "document"),
        ("doc.docx", "document"),
        ("slides.pptx", "document"),
        ("sheet.xlsx", "document"),
        ("photo.png", "document"),
        ("photo.jpg", "document"),
        ("book.epub", "document"),
        ("letter.odt", "document"),
        ("readme.md", "plaintext"),
        ("data.csv", "direct_convert"),
    ],
)
def test_parseable_extensions_classified_correctly(filename, expected):
    """Each listed file gets a non-UNSUPPORTED category whose ``value`` equals *expected*."""
    from app.etl_pipeline.file_classifier import FileCategory, classify_file

    category = classify_file(filename)

    assert category != FileCategory.UNSUPPORTED
    assert category.value == expected
|
||||
|
||||
|
||||
async def test_extract_unsupported_file_raises_error(tmp_path):
    """Extracting a .exe file raises EtlUnsupportedFileError mentioning "not supported"."""
    from app.etl_pipeline.exceptions import EtlUnsupportedFileError

    source = tmp_path / "program.exe"
    source.write_bytes(b"\x00" * 10)

    req = EtlRequest(file_path=str(source), filename="program.exe")
    with pytest.raises(EtlUnsupportedFileError, match="not supported"):
        await EtlPipelineService().extract(req)
|
||||
|
||||
|
||||
async def test_extract_zip_raises_unsupported_error(tmp_path):
    """Extracting a .zip archive raises EtlUnsupportedFileError mentioning "not supported"."""
    from app.etl_pipeline.exceptions import EtlUnsupportedFileError

    source = tmp_path / "archive.zip"
    source.write_bytes(b"PK\x03\x04")

    req = EtlRequest(file_path=str(source), filename="archive.zip")
    with pytest.raises(EtlUnsupportedFileError, match="not supported"):
        await EtlPipelineService().extract(req)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 14 - should_skip_for_service (per-parser document filtering)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename,etl_service,expected_skip",
    [
        ("file.eml", "DOCLING", True),
        ("file.eml", "UNSTRUCTURED", False),
        ("file.docm", "LLAMACLOUD", False),
        ("file.docm", "DOCLING", True),
        ("file.txt", "DOCLING", False),
        ("file.csv", "LLAMACLOUD", False),
        ("file.mp3", "UNSTRUCTURED", False),
        ("file.exe", "LLAMACLOUD", True),
        ("file.pdf", "DOCLING", False),
        ("file.webp", "DOCLING", False),
        ("file.webp", "UNSTRUCTURED", True),
        ("file.gif", "LLAMACLOUD", False),
        ("file.gif", "DOCLING", True),
        ("file.heic", "UNSTRUCTURED", False),
        ("file.heic", "DOCLING", True),
        ("file.svg", "LLAMACLOUD", False),
        ("file.svg", "DOCLING", True),
        ("file.p7s", "UNSTRUCTURED", False),
        ("file.p7s", "LLAMACLOUD", True),
    ],
)
def test_should_skip_for_service(filename, etl_service, expected_skip):
    """``should_skip_for_service`` returns the expected boolean for each file/service pair."""
    from app.etl_pipeline.file_classifier import should_skip_for_service

    decision = should_skip_for_service(filename, etl_service)

    assert decision is expected_skip, (
        f"{filename} with {etl_service}: expected skip={expected_skip}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice 14b - ETL pipeline rejects per-parser incompatible documents
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_extract_docm_with_docling_raises_unsupported(tmp_path, mocker):
    """With ETL_SERVICE=DOCLING, extracting a .docm raises EtlUnsupportedFileError."""
    from app.etl_pipeline.exceptions import EtlUnsupportedFileError

    mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")

    source = tmp_path / "macro.docm"
    source.write_bytes(b"\x00" * 10)

    req = EtlRequest(file_path=str(source), filename="macro.docm")
    with pytest.raises(EtlUnsupportedFileError, match="not supported by DOCLING"):
        await EtlPipelineService().extract(req)
|
||||
|
||||
|
||||
async def test_extract_eml_with_docling_raises_unsupported(tmp_path, mocker):
    """With ETL_SERVICE=DOCLING, extracting a .eml raises EtlUnsupportedFileError."""
    from app.etl_pipeline.exceptions import EtlUnsupportedFileError

    mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")

    source = tmp_path / "mail.eml"
    source.write_bytes(b"From: test@example.com")

    req = EtlRequest(file_path=str(source), filename="mail.eml")
    with pytest.raises(EtlUnsupportedFileError, match="not supported by DOCLING"):
        await EtlPipelineService().extract(req)
|
||||
0
surfsense_backend/tests/unit/services/__init__.py
Normal file
0
surfsense_backend/tests/unit/services/__init__.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
"""Test that DoclingService does NOT restrict allowed_formats, letting Docling
|
||||
accept all its supported formats (PDF, DOCX, PPTX, XLSX, IMAGE, etc.)."""
|
||||
|
||||
from enum import Enum
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeInputFormat(Enum):
    """Minimal enum passed as ``InputFormat`` to the mocked ``docling.datamodel.base_models``."""

    # Only the five members below exist on this fake.
    PDF = "pdf"
    IMAGE = "image"
    DOCX = "docx"
    PPTX = "pptx"
    XLSX = "xlsx"
|
||||
|
||||
|
||||
def test_docling_service_does_not_restrict_allowed_formats():
    """DoclingService must build DocumentConverter without ``allowed_formats``.

    The ``docling`` modules are replaced by mocks in ``sys.modules`` and
    ``app.services.docling_service`` is reloaded under that patch. The test
    then inspects the keyword arguments of the recorded ``DocumentConverter``
    call and checks that PDF is still configured in ``format_options``.
    """
    from importlib import reload

    converter_cls = MagicMock()
    pdf_backend = MagicMock()
    pipeline_options_cls = MagicMock(return_value=MagicMock())
    pdf_format_option_cls = MagicMock()

    fake_docling_modules = {
        "docling": MagicMock(),
        "docling.backend": MagicMock(),
        "docling.backend.pypdfium2_backend": MagicMock(
            PyPdfiumDocumentBackend=pdf_backend
        ),
        "docling.datamodel": MagicMock(),
        "docling.datamodel.base_models": MagicMock(InputFormat=_FakeInputFormat),
        "docling.datamodel.pipeline_options": MagicMock(
            PdfPipelineOptions=pipeline_options_cls
        ),
        "docling.document_converter": MagicMock(
            DocumentConverter=converter_cls,
            PdfFormatOption=pdf_format_option_cls,
        ),
    }

    # NOTE(review): the module is reloaded only while the mocks are installed and
    # is not reloaded again afterwards -- confirm this cannot leak into later tests.
    with patch.dict("sys.modules", fake_docling_modules):
        import app.services.docling_service as docling_module

        reload(docling_module)
        docling_module.DoclingService()

    recorded_call = converter_cls.call_args
    assert recorded_call is not None, "DocumentConverter was never called"

    _positional, keyword_args = recorded_call
    assert "allowed_formats" not in keyword_args, (
        "allowed_formats should not be passed — let Docling accept all formats. "
        f"Got: {keyword_args.get('allowed_formats')}"
    )
    configured_formats = keyword_args.get("format_options", {})
    assert _FakeInputFormat.PDF in configured_formats, (
        "format_options should still configure PDF pipeline options"
    )
|
||||
0
surfsense_backend/tests/unit/utils/__init__.py
Normal file
0
surfsense_backend/tests/unit/utils/__init__.py
Normal file
154
surfsense_backend/tests/unit/utils/test_file_extensions.py
Normal file
154
surfsense_backend/tests/unit/utils/test_file_extensions.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
"""Tests for the DOCUMENT_EXTENSIONS allowlist module."""
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_pdf_is_supported_document():
    """``report.pdf`` is reported as a supported document extension."""
    from app.utils.file_extensions import is_supported_document_extension

    verdict = is_supported_document_extension("report.pdf")

    assert verdict is True
|
||||
|
||||
|
||||
def test_exe_is_not_supported_document():
    """``malware.exe`` is reported as not a supported document extension."""
    from app.utils.file_extensions import is_supported_document_extension

    verdict = is_supported_document_extension("malware.exe")

    assert verdict is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename",
    (
        "report.pdf", "doc.docx", "old.doc", "sheet.xlsx", "legacy.xls",
        "slides.pptx", "deck.ppt", "macro.docm", "macro.xlsm", "macro.pptm",
        "photo.png", "photo.jpg", "photo.jpeg", "scan.bmp", "scan.tiff",
        "scan.tif", "photo.webp", "anim.gif", "iphone.heic", "manual.rtf",
        "book.epub", "letter.odt", "data.ods", "presentation.odp",
        "inbox.eml", "outlook.msg", "korean.hwpx", "korean.hwp",
        "template.dot", "template.dotm", "template.pot", "template.potx",
        "binary.xlsb", "workspace.xlw", "vector.svg", "signature.p7s",
    ),
)
def test_document_extensions_are_supported(filename):
    """Each listed filename must be accepted by the document allowlist."""
    from app.utils.file_extensions import is_supported_document_extension

    accepted = is_supported_document_extension(filename)
    assert accepted is True, f"{filename} should be supported"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename",
    (
        "malware.exe", "archive.zip", "video.mov", "font.woff2",
        "model.blend", "random.xyz", "data.parquet", "package.deb",
    ),
)
def test_non_document_extensions_are_not_supported(filename):
    """Each listed filename must be rejected by the document allowlist."""
    from app.utils.file_extensions import is_supported_document_extension

    accepted = is_supported_document_extension(filename)
    assert accepted is False, f"{filename} should NOT be supported"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-parser extension sets
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_union_equals_all_three_sets():
    """DOCUMENT_EXTENSIONS equals the union of the three per-parser sets."""
    from app.utils import file_extensions as ext

    per_parser_union = (
        ext.DOCLING_DOCUMENT_EXTENSIONS
        | ext.LLAMAPARSE_DOCUMENT_EXTENSIONS
        | ext.UNSTRUCTURED_DOCUMENT_EXTENSIONS
    )
    assert per_parser_union == ext.DOCUMENT_EXTENSIONS
|
||||
|
||||
|
||||
def test_get_extensions_for_docling():
    """The DOCLING set includes its expected members and excludes the others."""
    from app.utils.file_extensions import get_document_extensions_for_service

    docling_exts = get_document_extensions_for_service("DOCLING")

    for present in (".pdf", ".webp", ".docx"):
        assert present in docling_exts
    for absent in (".eml", ".docm", ".gif", ".heic"):
        assert absent not in docling_exts
|
||||
|
||||
|
||||
def test_get_extensions_for_llamacloud():
    """The LLAMACLOUD set includes its expected members and excludes the others."""
    from app.utils.file_extensions import get_document_extensions_for_service

    llamacloud_exts = get_document_extensions_for_service("LLAMACLOUD")

    for present in (".docm", ".gif", ".svg", ".hwp"):
        assert present in llamacloud_exts
    for absent in (".eml", ".heic"):
        assert absent not in llamacloud_exts
|
||||
|
||||
|
||||
def test_get_extensions_for_unstructured():
    """The UNSTRUCTURED set includes its expected members and excludes the others."""
    from app.utils.file_extensions import get_document_extensions_for_service

    unstructured_exts = get_document_extensions_for_service("UNSTRUCTURED")

    for present in (".eml", ".heic", ".p7s"):
        assert present in unstructured_exts
    for absent in (".docm", ".gif", ".svg"):
        assert absent not in unstructured_exts
|
||||
|
||||
|
||||
def test_get_extensions_for_none_returns_union():
    """Passing ``None`` as the service yields the full DOCUMENT_EXTENSIONS set."""
    from app.utils import file_extensions as ext

    result = ext.get_document_extensions_for_service(None)
    assert result == ext.DOCUMENT_EXTENSIONS
|
||||
|
|
@ -19,6 +19,9 @@ files:
|
|||
- "!scripts"
|
||||
- "!release"
|
||||
extraResources:
|
||||
- from: assets/
|
||||
to: assets/
|
||||
filter: ["*.ico", "*.png", "*.icns"]
|
||||
- from: ../surfsense_web/.next/standalone/surfsense_web/
|
||||
to: standalone/
|
||||
filter:
|
||||
|
|
@ -58,7 +61,7 @@ win:
|
|||
icon: assets/icon.ico
|
||||
target:
|
||||
- target: nsis
|
||||
arch: [x64, arm64]
|
||||
arch: [x64]
|
||||
nsis:
|
||||
oneClick: false
|
||||
perMachine: false
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
"description": "SurfSense Desktop App",
|
||||
"main": "dist/main.js",
|
||||
"scripts": {
|
||||
"dev": "concurrently -k \"pnpm --dir ../surfsense_web dev\" \"wait-on http://localhost:3000 && electron .\"",
|
||||
"dev": "pnpm build && concurrently -k \"pnpm --dir ../surfsense_web dev\" \"wait-on http://localhost:3000 && electron .\"",
|
||||
"build": "node scripts/build-electron.mjs",
|
||||
"pack:dir": "pnpm build && electron-builder --dir --config electron-builder.yml",
|
||||
"dist": "pnpm build && electron-builder --config electron-builder.yml",
|
||||
|
|
|
|||
|
|
@ -32,4 +32,13 @@ export const IPC_CHANNELS = {
|
|||
FOLDER_SYNC_ACK_EVENTS: 'folder-sync:ack-events',
|
||||
BROWSE_FILES: 'browse:files',
|
||||
READ_LOCAL_FILES: 'browse:read-local-files',
|
||||
// Auth token sync across windows
|
||||
GET_AUTH_TOKENS: 'auth:get-tokens',
|
||||
SET_AUTH_TOKENS: 'auth:set-tokens',
|
||||
// Keyboard shortcut configuration
|
||||
GET_SHORTCUTS: 'shortcuts:get',
|
||||
SET_SHORTCUTS: 'shortcuts:set',
|
||||
// Active search space
|
||||
GET_ACTIVE_SEARCH_SPACE: 'search-space:get-active',
|
||||
SET_ACTIVE_SEARCH_SPACE: 'search-space:set-active',
|
||||
} as const;
|
||||
|
|
|
|||
|
|
@ -20,6 +20,13 @@ import {
|
|||
browseFiles,
|
||||
readLocalFiles,
|
||||
} from '../modules/folder-watcher';
|
||||
import { getShortcuts, setShortcuts, type ShortcutConfig } from '../modules/shortcuts';
|
||||
import { getActiveSearchSpaceId, setActiveSearchSpaceId } from '../modules/active-search-space';
|
||||
import { reregisterQuickAsk } from '../modules/quick-ask';
|
||||
import { reregisterAutocomplete } from '../modules/autocomplete';
|
||||
import { reregisterGeneralAssist } from '../modules/tray';
|
||||
|
||||
let authTokens: { bearer: string; refresh: string } | null = null;
|
||||
|
||||
export function registerIpcHandlers(): void {
|
||||
ipcMain.on(IPC_CHANNELS.OPEN_EXTERNAL, (_event, url: string) => {
|
||||
|
|
@ -89,4 +96,28 @@ export function registerIpcHandlers(): void {
|
|||
ipcMain.handle(IPC_CHANNELS.READ_LOCAL_FILES, (_event, paths: string[]) =>
|
||||
readLocalFiles(paths)
|
||||
);
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.SET_AUTH_TOKENS, (_event, tokens: { bearer: string; refresh: string }) => {
|
||||
authTokens = tokens;
|
||||
});
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.GET_AUTH_TOKENS, () => {
|
||||
return authTokens;
|
||||
});
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.GET_SHORTCUTS, () => getShortcuts());
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.GET_ACTIVE_SEARCH_SPACE, () => getActiveSearchSpaceId());
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.SET_ACTIVE_SEARCH_SPACE, (_event, id: string) =>
|
||||
setActiveSearchSpaceId(id)
|
||||
);
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.SET_SHORTCUTS, async (_event, config: Partial<ShortcutConfig>) => {
|
||||
const updated = await setShortcuts(config);
|
||||
if (config.generalAssist) await reregisterGeneralAssist();
|
||||
if (config.quickAsk) await reregisterQuickAsk();
|
||||
if (config.autocomplete) await reregisterAutocomplete();
|
||||
return updated;
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import { app, BrowserWindow } from 'electron';
|
||||
|
||||
let isQuitting = false;
|
||||
import { registerGlobalErrorHandlers, showErrorDialog } from './modules/errors';
|
||||
import { startNextServer } from './modules/server';
|
||||
import { createMainWindow } from './modules/window';
|
||||
import { createMainWindow, getMainWindow } from './modules/window';
|
||||
import { setupDeepLinks, handlePendingDeepLink } from './modules/deep-links';
|
||||
import { setupAutoUpdater } from './modules/auto-updater';
|
||||
import { setupMenu } from './modules/menu';
|
||||
|
|
@ -9,6 +11,7 @@ import { registerQuickAsk, unregisterQuickAsk } from './modules/quick-ask';
|
|||
import { registerAutocomplete, unregisterAutocomplete } from './modules/autocomplete';
|
||||
import { registerFolderWatcher, unregisterFolderWatcher } from './modules/folder-watcher';
|
||||
import { registerIpcHandlers } from './ipc/handlers';
|
||||
import { createTray, destroyTray } from './modules/tray';
|
||||
|
||||
registerGlobalErrorHandlers();
|
||||
|
||||
|
|
@ -28,29 +31,48 @@ app.whenReady().then(async () => {
|
|||
return;
|
||||
}
|
||||
|
||||
createMainWindow('/dashboard');
|
||||
registerQuickAsk();
|
||||
registerAutocomplete();
|
||||
await createTray();
|
||||
|
||||
const win = createMainWindow('/dashboard');
|
||||
|
||||
// Minimize to tray instead of closing the app
|
||||
win.on('close', (e) => {
|
||||
if (!isQuitting) {
|
||||
e.preventDefault();
|
||||
win.hide();
|
||||
}
|
||||
});
|
||||
|
||||
await registerQuickAsk();
|
||||
await registerAutocomplete();
|
||||
registerFolderWatcher();
|
||||
setupAutoUpdater();
|
||||
|
||||
handlePendingDeepLink();
|
||||
|
||||
app.on('activate', () => {
|
||||
if (BrowserWindow.getAllWindows().length === 0) {
|
||||
const mw = getMainWindow();
|
||||
if (!mw || mw.isDestroyed()) {
|
||||
createMainWindow('/dashboard');
|
||||
} else {
|
||||
mw.show();
|
||||
mw.focus();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Keep running in the background — the tray "Quit" calls app.exit()
|
||||
app.on('window-all-closed', () => {
|
||||
if (process.platform !== 'darwin') {
|
||||
app.quit();
|
||||
}
|
||||
// Do nothing: the app stays alive in the tray
|
||||
});
|
||||
|
||||
app.on('before-quit', () => {
|
||||
isQuitting = true;
|
||||
});
|
||||
|
||||
app.on('will-quit', () => {
|
||||
unregisterQuickAsk();
|
||||
unregisterAutocomplete();
|
||||
unregisterFolderWatcher();
|
||||
destroyTray();
|
||||
});
|
||||
|
|
|
|||
24
surfsense_desktop/src/modules/active-search-space.ts
Normal file
24
surfsense_desktop/src/modules/active-search-space.ts
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
const STORE_KEY = 'activeSearchSpaceId';
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
let store: any = null;
|
||||
|
||||
/**
 * Return the module-wide electron-store instance, creating it on first use.
 *
 * `electron-store` is loaded through a dynamic `import()` and the resulting
 * instance is cached in the module-level `store` variable.
 */
async function getStore() {
  if (store) return store;

  const { default: Store } = await import('electron-store');
  store = new Store({
    name: 'active-search-space',
    defaults: { [STORE_KEY]: null as string | null },
  });
  return store;
}
|
||||
|
||||
/**
 * Read the persisted active search space id.
 *
 * @returns the stored id, or `null` when nothing has been stored.
 */
export async function getActiveSearchSpaceId(): Promise<string | null> {
  const persisted = (await getStore()).get(STORE_KEY) as string | null;
  return persisted ?? null;
}
|
||||
|
||||
/**
 * Persist `id` as the active search space.
 *
 * @param id search space identifier to store.
 */
export async function setActiveSearchSpaceId(id: string): Promise<void> {
  const backing = await getStore();
  backing.set(STORE_KEY, id);
}
|
||||
|
|
@ -2,16 +2,15 @@ import { clipboard, globalShortcut, ipcMain, screen } from 'electron';
|
|||
import { IPC_CHANNELS } from '../../ipc/channels';
|
||||
import { getFrontmostApp, getWindowTitle, hasAccessibilityPermission, simulatePaste } from '../platform';
|
||||
import { hasScreenRecordingPermission, requestAccessibility, requestScreenRecording } from '../permissions';
|
||||
import { getMainWindow } from '../window';
|
||||
import { captureScreen } from './screenshot';
|
||||
import { createSuggestionWindow, destroySuggestion, getSuggestionWindow } from './suggestion-window';
|
||||
import { getShortcuts } from '../shortcuts';
|
||||
import { getActiveSearchSpaceId } from '../active-search-space';
|
||||
|
||||
const SHORTCUT = 'CommandOrControl+Shift+Space';
|
||||
|
||||
let currentShortcut = '';
|
||||
let autocompleteEnabled = true;
|
||||
let savedClipboard = '';
|
||||
let sourceApp = '';
|
||||
let lastSearchSpaceId: string | null = null;
|
||||
|
||||
function isSurfSenseWindow(): boolean {
|
||||
const app = getFrontmostApp();
|
||||
|
|
@ -37,21 +36,11 @@ async function triggerAutocomplete(): Promise<void> {
|
|||
return;
|
||||
}
|
||||
|
||||
const mainWin = getMainWindow();
|
||||
if (mainWin && !mainWin.isDestroyed()) {
|
||||
const mainUrl = mainWin.webContents.getURL();
|
||||
const match = mainUrl.match(/\/dashboard\/(\d+)/);
|
||||
if (match) {
|
||||
lastSearchSpaceId = match[1];
|
||||
}
|
||||
}
|
||||
|
||||
if (!lastSearchSpaceId) {
|
||||
console.warn('[autocomplete] No active search space. Open a search space first.');
|
||||
const searchSpaceId = await getActiveSearchSpaceId();
|
||||
if (!searchSpaceId) {
|
||||
console.warn('[autocomplete] No active search space. Select a search space first.');
|
||||
return;
|
||||
}
|
||||
|
||||
const searchSpaceId = lastSearchSpaceId;
|
||||
const cursor = screen.getCursorScreenPoint();
|
||||
const win = createSuggestionWindow(cursor.x, cursor.y);
|
||||
|
||||
|
|
@ -91,7 +80,12 @@ async function acceptAndInject(text: string): Promise<void> {
|
|||
}
|
||||
}
|
||||
|
||||
let ipcRegistered = false;
|
||||
|
||||
function registerIpcHandlers(): void {
|
||||
if (ipcRegistered) return;
|
||||
ipcRegistered = true;
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.ACCEPT_SUGGESTION, async (_event, text: string) => {
|
||||
await acceptAndInject(text);
|
||||
});
|
||||
|
|
@ -107,26 +101,39 @@ function registerIpcHandlers(): void {
|
|||
ipcMain.handle(IPC_CHANNELS.GET_AUTOCOMPLETE_ENABLED, () => autocompleteEnabled);
|
||||
}
|
||||
|
||||
export function registerAutocomplete(): void {
|
||||
registerIpcHandlers();
|
||||
function autocompleteHandler(): void {
|
||||
const sw = getSuggestionWindow();
|
||||
if (sw && !sw.isDestroyed()) {
|
||||
destroySuggestion();
|
||||
return;
|
||||
}
|
||||
triggerAutocomplete();
|
||||
}
|
||||
|
||||
const ok = globalShortcut.register(SHORTCUT, () => {
|
||||
const sw = getSuggestionWindow();
|
||||
if (sw && !sw.isDestroyed()) {
|
||||
destroySuggestion();
|
||||
return;
|
||||
}
|
||||
triggerAutocomplete();
|
||||
});
|
||||
async function registerShortcut(): Promise<void> {
|
||||
const shortcuts = await getShortcuts();
|
||||
currentShortcut = shortcuts.autocomplete;
|
||||
|
||||
const ok = globalShortcut.register(currentShortcut, autocompleteHandler);
|
||||
|
||||
if (!ok) {
|
||||
console.error(`[autocomplete] Failed to register shortcut ${SHORTCUT}`);
|
||||
console.error(`[autocomplete] Failed to register shortcut ${currentShortcut}`);
|
||||
} else {
|
||||
console.log(`[autocomplete] Registered shortcut ${SHORTCUT}`);
|
||||
console.log(`[autocomplete] Registered shortcut ${currentShortcut}`);
|
||||
}
|
||||
}
|
||||
|
||||
export async function registerAutocomplete(): Promise<void> {
|
||||
registerIpcHandlers();
|
||||
await registerShortcut();
|
||||
}
|
||||
|
||||
export function unregisterAutocomplete(): void {
|
||||
globalShortcut.unregister(SHORTCUT);
|
||||
if (currentShortcut) globalShortcut.unregister(currentShortcut);
|
||||
destroySuggestion();
|
||||
}
|
||||
|
||||
export async function reregisterAutocomplete(): Promise<void> {
|
||||
unregisterAutocomplete();
|
||||
await registerShortcut();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,16 +1,20 @@
|
|||
import { execSync } from 'child_process';
|
||||
import { systemPreferences } from 'electron';
|
||||
|
||||
const EXEC_OPTS = { windowsHide: true } as const;
|
||||
|
||||
export function getFrontmostApp(): string {
|
||||
try {
|
||||
if (process.platform === 'darwin') {
|
||||
return execSync(
|
||||
'osascript -e \'tell application "System Events" to get name of first application process whose frontmost is true\''
|
||||
'osascript -e \'tell application "System Events" to get name of first application process whose frontmost is true\'',
|
||||
EXEC_OPTS,
|
||||
).toString().trim();
|
||||
}
|
||||
if (process.platform === 'win32') {
|
||||
return execSync(
|
||||
'powershell -command "Add-Type \'using System; using System.Runtime.InteropServices; public class W { [DllImport(\\\"user32.dll\\\")] public static extern IntPtr GetForegroundWindow(); }\'; (Get-Process | Where-Object { $_.MainWindowHandle -eq [W]::GetForegroundWindow() }).ProcessName"'
|
||||
'powershell -NoProfile -NonInteractive -command "Add-Type \'using System; using System.Runtime.InteropServices; public class W { [DllImport(\\\"user32.dll\\\")] public static extern IntPtr GetForegroundWindow(); }\'; (Get-Process | Where-Object { $_.MainWindowHandle -eq [W]::GetForegroundWindow() }).ProcessName"',
|
||||
EXEC_OPTS,
|
||||
).toString().trim();
|
||||
}
|
||||
} catch {
|
||||
|
|
@ -21,9 +25,23 @@ export function getFrontmostApp(): string {
|
|||
|
||||
export function simulatePaste(): void {
|
||||
if (process.platform === 'darwin') {
|
||||
execSync('osascript -e \'tell application "System Events" to keystroke "v" using command down\'');
|
||||
execSync('osascript -e \'tell application "System Events" to keystroke "v" using command down\'', EXEC_OPTS);
|
||||
} else if (process.platform === 'win32') {
|
||||
execSync('powershell -command "Add-Type -AssemblyName System.Windows.Forms; [System.Windows.Forms.SendKeys]::SendWait(\'^v\')"');
|
||||
execSync('powershell -NoProfile -NonInteractive -command "Add-Type -AssemblyName System.Windows.Forms; [System.Windows.Forms.SendKeys]::SendWait(\'^v\')"', EXEC_OPTS);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Send the platform "copy" keystroke to the frontmost application.
 *
 * macOS goes through `osascript`, Windows through PowerShell `SendKeys`.
 * On any other platform nothing is executed and `true` is still returned.
 *
 * @returns `true` when no command threw, `false` when the command failed
 *          (the error is logged).
 */
export function simulateCopy(): boolean {
  try {
    switch (process.platform) {
      case 'darwin':
        execSync('osascript -e \'tell application "System Events" to keystroke "c" using command down\'', EXEC_OPTS);
        break;
      case 'win32':
        execSync('powershell -NoProfile -NonInteractive -command "Add-Type -AssemblyName System.Windows.Forms; [System.Windows.Forms.SendKeys]::SendWait(\'^c\')"', EXEC_OPTS);
        break;
      default:
        break;
    }
  } catch (err) {
    console.error('[simulateCopy] Failed:', err);
    return false;
  }
  return true;
}
|
||||
|
||||
|
|
@ -36,12 +54,14 @@ export function getWindowTitle(): string {
|
|||
try {
|
||||
if (process.platform === 'darwin') {
|
||||
return execSync(
|
||||
'osascript -e \'tell application "System Events" to get title of front window of first application process whose frontmost is true\''
|
||||
'osascript -e \'tell application "System Events" to get title of front window of first application process whose frontmost is true\'',
|
||||
EXEC_OPTS,
|
||||
).toString().trim();
|
||||
}
|
||||
if (process.platform === 'win32') {
|
||||
return execSync(
|
||||
'powershell -command "(Get-Process | Where-Object { $_.MainWindowHandle -eq (Add-Type -MemberDefinition \'[DllImport(\\\"user32.dll\\\")] public static extern IntPtr GetForegroundWindow();\' -Name W -PassThru)::GetForegroundWindow() }).MainWindowTitle"'
|
||||
'powershell -NoProfile -NonInteractive -command "(Get-Process | Where-Object { $_.MainWindowHandle -eq (Add-Type -MemberDefinition \'[DllImport(\\\"user32.dll\\\")] public static extern IntPtr GetForegroundWindow();\' -Name W -PassThru)::GetForegroundWindow() }).MainWindowTitle"',
|
||||
EXEC_OPTS,
|
||||
).toString().trim();
|
||||
}
|
||||
} catch {
|
||||
|
|
|
|||
|
|
@ -1,13 +1,16 @@
|
|||
import { BrowserWindow, clipboard, globalShortcut, ipcMain, screen, shell } from 'electron';
|
||||
import path from 'path';
|
||||
import { IPC_CHANNELS } from '../ipc/channels';
|
||||
import { checkAccessibilityPermission, getFrontmostApp, simulatePaste } from './platform';
|
||||
import { checkAccessibilityPermission, getFrontmostApp, simulateCopy, simulatePaste } from './platform';
|
||||
import { getServerPort } from './server';
|
||||
import { getShortcuts } from './shortcuts';
|
||||
import { getActiveSearchSpaceId } from './active-search-space';
|
||||
|
||||
const SHORTCUT = 'CommandOrControl+Option+S';
|
||||
let currentShortcut = '';
|
||||
let quickAskWindow: BrowserWindow | null = null;
|
||||
let pendingText = '';
|
||||
let pendingMode = '';
|
||||
let pendingSearchSpaceId: string | null = null;
|
||||
let sourceApp = '';
|
||||
let savedClipboard = '';
|
||||
|
||||
|
|
@ -52,7 +55,9 @@ function createQuickAskWindow(x: number, y: number): BrowserWindow {
|
|||
skipTaskbar: true,
|
||||
});
|
||||
|
||||
quickAskWindow.loadURL(`http://localhost:${getServerPort()}/dashboard`);
|
||||
const spaceId = pendingSearchSpaceId;
|
||||
const route = spaceId ? `/dashboard/${spaceId}/new-chat` : '/dashboard';
|
||||
quickAskWindow.loadURL(`http://localhost:${getServerPort()}${route}`);
|
||||
|
||||
quickAskWindow.once('ready-to-show', () => {
|
||||
quickAskWindow?.show();
|
||||
|
|
@ -77,29 +82,53 @@ function createQuickAskWindow(x: number, y: number): BrowserWindow {
|
|||
return quickAskWindow;
|
||||
}
|
||||
|
||||
export function registerQuickAsk(): void {
|
||||
const ok = globalShortcut.register(SHORTCUT, () => {
|
||||
if (quickAskWindow && !quickAskWindow.isDestroyed()) {
|
||||
destroyQuickAsk();
|
||||
return;
|
||||
}
|
||||
async function openQuickAsk(text: string): Promise<void> {
|
||||
pendingText = text;
|
||||
pendingSearchSpaceId = await getActiveSearchSpaceId();
|
||||
const cursor = screen.getCursorScreenPoint();
|
||||
const pos = clampToScreen(cursor.x, cursor.y, 450, 750);
|
||||
createQuickAskWindow(pos.x, pos.y);
|
||||
}
|
||||
|
||||
sourceApp = getFrontmostApp();
|
||||
savedClipboard = clipboard.readText();
|
||||
async function quickAskHandler(): Promise<void> {
|
||||
console.log('[quick-ask] Handler triggered');
|
||||
|
||||
const text = savedClipboard.trim();
|
||||
if (!text) return;
|
||||
|
||||
pendingText = text;
|
||||
const cursor = screen.getCursorScreenPoint();
|
||||
const pos = clampToScreen(cursor.x, cursor.y, 450, 750);
|
||||
createQuickAskWindow(pos.x, pos.y);
|
||||
});
|
||||
|
||||
if (!ok) {
|
||||
console.log(`Quick-ask: failed to register ${SHORTCUT}`);
|
||||
if (quickAskWindow && !quickAskWindow.isDestroyed()) {
|
||||
console.log('[quick-ask] Window already open, closing');
|
||||
destroyQuickAsk();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!checkAccessibilityPermission()) {
|
||||
console.log('[quick-ask] Accessibility permission denied');
|
||||
return;
|
||||
}
|
||||
|
||||
savedClipboard = clipboard.readText();
|
||||
console.log('[quick-ask] Saved clipboard length:', savedClipboard.length);
|
||||
|
||||
const copyOk = simulateCopy();
|
||||
console.log('[quick-ask] simulateCopy result:', copyOk);
|
||||
|
||||
await new Promise((r) => setTimeout(r, 300));
|
||||
|
||||
const afterCopy = clipboard.readText();
|
||||
const selected = afterCopy.trim();
|
||||
console.log('[quick-ask] Clipboard after copy length:', afterCopy.length, 'changed:', afterCopy !== savedClipboard);
|
||||
|
||||
const text = selected || savedClipboard.trim();
|
||||
|
||||
sourceApp = getFrontmostApp();
|
||||
console.log('[quick-ask] Source app:', sourceApp, '| Opening Quick Assist with', text.length, 'chars', selected ? '(selected)' : text ? '(clipboard fallback)' : '(empty)');
|
||||
openQuickAsk(text);
|
||||
}
|
||||
|
||||
let ipcRegistered = false;
|
||||
|
||||
function registerIpcHandlers(): void {
|
||||
if (ipcRegistered) return;
|
||||
ipcRegistered = true;
|
||||
|
||||
ipcMain.handle(IPC_CHANNELS.QUICK_ASK_TEXT, () => {
|
||||
const text = pendingText;
|
||||
pendingText = '';
|
||||
|
|
@ -136,6 +165,24 @@ export function registerQuickAsk(): void {
|
|||
});
|
||||
}
|
||||
|
||||
export function unregisterQuickAsk(): void {
|
||||
globalShortcut.unregister(SHORTCUT);
|
||||
async function registerShortcut(): Promise<void> {
|
||||
const shortcuts = await getShortcuts();
|
||||
currentShortcut = shortcuts.quickAsk;
|
||||
|
||||
const ok = globalShortcut.register(currentShortcut, () => { quickAskHandler(); });
|
||||
console.log(`[quick-ask] Register ${currentShortcut}: ${ok ? 'OK' : 'FAILED'}`);
|
||||
}
|
||||
|
||||
export async function registerQuickAsk(): Promise<void> {
|
||||
registerIpcHandlers();
|
||||
await registerShortcut();
|
||||
}
|
||||
|
||||
export function unregisterQuickAsk(): void {
|
||||
if (currentShortcut) globalShortcut.unregister(currentShortcut);
|
||||
}
|
||||
|
||||
export async function reregisterQuickAsk(): Promise<void> {
|
||||
unregisterQuickAsk();
|
||||
await registerShortcut();
|
||||
}
|
||||
|
|
|
|||
44
surfsense_desktop/src/modules/shortcuts.ts
Normal file
44
surfsense_desktop/src/modules/shortcuts.ts
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
/** Accelerator strings for the app's configurable global shortcuts. */
export interface ShortcutConfig {
  /** Accelerator for the "General Assist" shortcut. */
  generalAssist: string;
  /** Accelerator for the quick-ask shortcut. */
  quickAsk: string;
  /** Accelerator for the autocomplete shortcut. */
  autocomplete: string;
}
|
||||
|
||||
const DEFAULTS: ShortcutConfig = {
|
||||
generalAssist: 'CommandOrControl+Shift+S',
|
||||
quickAsk: 'CommandOrControl+Alt+S',
|
||||
autocomplete: 'CommandOrControl+Shift+Space',
|
||||
};
|
||||
|
||||
const STORE_KEY = 'shortcuts';
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- lazily imported ESM module; matches folder-watcher.ts pattern
|
||||
let store: any = null;
|
||||
|
||||
/**
 * Return the module-wide electron-store instance for shortcuts, creating it
 * on first use.
 *
 * `electron-store` is loaded through a dynamic `import()`; the instance is
 * cached in the module-level `store` variable and seeded with `DEFAULTS`.
 */
async function getStore() {
  if (store) return store;

  const { default: Store } = await import('electron-store');
  store = new Store({
    name: 'keyboard-shortcuts',
    defaults: { [STORE_KEY]: DEFAULTS },
  });
  return store;
}
|
||||
|
||||
/**
 * Load the shortcut configuration.
 *
 * Stored values are layered over `DEFAULTS`, so any key missing from the
 * stored object falls back to its default accelerator.
 */
export async function getShortcuts(): Promise<ShortcutConfig> {
  const backing = await getStore();
  const persisted = backing.get(STORE_KEY) as Partial<ShortcutConfig> | undefined;
  return { ...DEFAULTS, ...persisted };
}
|
||||
|
||||
/**
 * Apply a partial shortcut update and persist the result.
 *
 * The merge order is `DEFAULTS` -> stored values -> `config`, matching
 * `getShortcuts`, so the returned object always carries every key of
 * `ShortcutConfig` even when the stored object is incomplete.
 *
 * @param config keys to change; entries whose value is `undefined` are ignored
 *               so a sparse patch cannot erase an existing binding.
 * @returns the full configuration that was written to the store.
 */
export async function setShortcuts(config: Partial<ShortcutConfig>): Promise<ShortcutConfig> {
  const s = await getStore();
  const stored = s.get(STORE_KEY) as Partial<ShortcutConfig> | undefined;

  const merged: ShortcutConfig = { ...DEFAULTS, ...stored };
  for (const key of Object.keys(config) as (keyof ShortcutConfig)[]) {
    const value = config[key];
    if (value !== undefined) {
      merged[key] = value;
    }
  }

  s.set(STORE_KEY, merged);
  return merged;
}
|
||||
|
||||
/** Return a fresh shallow copy of the default shortcut configuration. */
export function getDefaults(): ShortcutConfig {
  const copy: ShortcutConfig = { ...DEFAULTS };
  return copy;
}
|
||||
77
surfsense_desktop/src/modules/tray.ts
Normal file
77
surfsense_desktop/src/modules/tray.ts
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import { app, globalShortcut, Menu, nativeImage, Tray } from 'electron';
|
||||
import path from 'path';
|
||||
import { getMainWindow, createMainWindow } from './window';
|
||||
import { getShortcuts } from './shortcuts';
|
||||
|
||||
let tray: Tray | null = null;
|
||||
let currentShortcut: string | null = null;
|
||||
|
||||
/**
 * Load the tray icon and scale it to 16x16.
 *
 * Uses `icon.ico` on Windows and `icon.png` elsewhere. Packaged builds read
 * from `<resourcesPath>/assets`; otherwise the icon is resolved relative to
 * this module's directory (`../assets`).
 */
function getTrayIcon(): nativeImage {
  let iconName = 'icon.png';
  if (process.platform === 'win32') {
    iconName = 'icon.ico';
  }

  let assetsDir: string;
  if (app.isPackaged) {
    assetsDir = path.join(process.resourcesPath, 'assets');
  } else {
    assetsDir = path.join(__dirname, '..', 'assets');
  }

  return nativeImage
    .createFromPath(path.join(assetsDir, iconName))
    .resize({ width: 16, height: 16 });
}
|
||||
|
||||
/**
 * Bring the main window to the foreground, creating it if needed.
 *
 * When no live main window exists a new one is created at `/dashboard`.
 * Otherwise the existing window is restored if it is minimized, then shown
 * and focused.
 */
function showMainWindow(): void {
  const win = getMainWindow();
  if (!win || win.isDestroyed()) {
    createMainWindow('/dashboard');
    return;
  }

  // A minimized window must be restored first so the shortcut / tray action
  // actually brings it back on screen.
  if (win.isMinimized()) {
    win.restore();
  }
  win.show();
  win.focus();
}
|
||||
|
||||
/**
 * Bind `accelerator` as the global shortcut that shows the main window.
 *
 * Any previously bound accelerator is released first. An empty accelerator
 * leaves the shortcut unbound. Registration failures are logged, not thrown.
 *
 * @param accelerator accelerator string to register.
 */
function registerShortcut(accelerator: string): void {
  if (currentShortcut) {
    globalShortcut.unregister(currentShortcut);
    currentShortcut = null;
  }

  if (!accelerator) return;

  let registered: boolean;
  try {
    registered = globalShortcut.register(accelerator, showMainWindow);
  } catch (err) {
    console.error(`[tray] Error registering General Assist shortcut:`, err);
    return;
  }

  if (!registered) {
    console.warn(`[tray] Failed to register General Assist shortcut: ${accelerator}`);
    return;
  }
  currentShortcut = accelerator;
}
|
||||
|
||||
/**
 * Create the system tray icon, its context menu, and the General Assist
 * global shortcut.
 *
 * Calling this again while a tray already exists is a no-op.
 */
export async function createTray(): Promise<void> {
  if (tray) return;

  tray = new Tray(getTrayIcon());
  tray.setToolTip('SurfSense');
  tray.on('double-click', showMainWindow);
  tray.setContextMenu(
    Menu.buildFromTemplate([
      { label: 'Open SurfSense', click: showMainWindow },
      { type: 'separator' },
      // NOTE(review): app.exit() presumably skips the before-quit / will-quit
      // events, so any cleanup hooked to them would not run — confirm intended.
      { label: 'Quit', click: () => { app.exit(0); } },
    ]),
  );

  const { generalAssist } = await getShortcuts();
  registerShortcut(generalAssist);
}
|
||||
|
||||
/**
 * Re-read the stored shortcut configuration and rebind the General Assist
 * shortcut to the current accelerator.
 */
export async function reregisterGeneralAssist(): Promise<void> {
  const { generalAssist } = await getShortcuts();
  registerShortcut(generalAssist);
}
|
||||
|
||||
/**
 * Release the General Assist shortcut (if bound) and destroy the tray icon
 * (if present), resetting both module-level references.
 */
export function destroyTray(): void {
  if (currentShortcut) {
    globalShortcut.unregister(currentShortcut);
    currentShortcut = null;
  }

  if (tray) {
    tray.destroy();
  }
  tray = null;
}
|
||||
|
|
@ -2,6 +2,7 @@ import { app, BrowserWindow, shell, session } from 'electron';
|
|||
import path from 'path';
|
||||
import { showErrorDialog } from './errors';
|
||||
import { getServerPort } from './server';
|
||||
import { setActiveSearchSpaceId } from './active-search-space';
|
||||
|
||||
const isDev = !app.isPackaged;
|
||||
const HOSTED_FRONTEND_URL = process.env.HOSTED_FRONTEND_URL as string;
|
||||
|
|
@ -55,6 +56,16 @@ export function createMainWindow(initialPath = '/dashboard'): BrowserWindow {
|
|||
showErrorDialog('Page failed to load', new Error(`${errorDescription} (${errorCode})\n${validatedURL}`));
|
||||
});
|
||||
|
||||
// Auto-sync active search space from URL navigation
|
||||
const syncSearchSpace = (url: string) => {
|
||||
const match = url.match(/\/dashboard\/(\d+)/);
|
||||
if (match) {
|
||||
setActiveSearchSpaceId(match[1]);
|
||||
}
|
||||
};
|
||||
mainWindow.webContents.on('did-navigate', (_event, url) => syncSearchSpace(url));
|
||||
mainWindow.webContents.on('did-navigate-in-page', (_event, url) => syncSearchSpace(url));
|
||||
|
||||
if (isDev) {
|
||||
mainWindow.webContents.openDevTools();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -68,4 +68,19 @@ contextBridge.exposeInMainWorld('electronAPI', {
|
|||
// Browse files via native dialog
|
||||
browseFiles: () => ipcRenderer.invoke(IPC_CHANNELS.BROWSE_FILES),
|
||||
readLocalFiles: (paths: string[]) => ipcRenderer.invoke(IPC_CHANNELS.READ_LOCAL_FILES, paths),
|
||||
|
||||
// Auth token sync across windows
|
||||
getAuthTokens: () => ipcRenderer.invoke(IPC_CHANNELS.GET_AUTH_TOKENS),
|
||||
setAuthTokens: (bearer: string, refresh: string) =>
|
||||
ipcRenderer.invoke(IPC_CHANNELS.SET_AUTH_TOKENS, { bearer, refresh }),
|
||||
|
||||
// Keyboard shortcut configuration
|
||||
getShortcuts: () => ipcRenderer.invoke(IPC_CHANNELS.GET_SHORTCUTS),
|
||||
setShortcuts: (config: Record<string, string>) =>
|
||||
ipcRenderer.invoke(IPC_CHANNELS.SET_SHORTCUTS, config),
|
||||
|
||||
// Active search space
|
||||
getActiveSearchSpace: () => ipcRenderer.invoke(IPC_CHANNELS.GET_ACTIVE_SEARCH_SPACE),
|
||||
setActiveSearchSpace: (id: string) =>
|
||||
ipcRenderer.invoke(IPC_CHANNELS.SET_ACTIVE_SEARCH_SPACE, id),
|
||||
});
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import { OnboardingTour } from "@/components/onboarding-tour";
|
|||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { useFolderSync } from "@/hooks/use-folder-sync";
|
||||
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
|
||||
import { useElectronAPI } from "@/hooks/use-platform";
|
||||
|
||||
export function DashboardClientLayout({
|
||||
children,
|
||||
|
|
@ -139,6 +140,8 @@ export function DashboardClientLayout({
|
|||
refetchPreferences,
|
||||
]);
|
||||
|
||||
const electronAPI = useElectronAPI();
|
||||
|
||||
useEffect(() => {
|
||||
const activeSeacrhSpaceId =
|
||||
typeof search_space_id === "string"
|
||||
|
|
@ -148,7 +151,16 @@ export function DashboardClientLayout({
|
|||
: "";
|
||||
if (!activeSeacrhSpaceId) return;
|
||||
setActiveSearchSpaceIdState(activeSeacrhSpaceId);
|
||||
}, [search_space_id, setActiveSearchSpaceIdState]);
|
||||
|
||||
// Sync to Electron store if stored value is null (first navigation)
|
||||
if (electronAPI?.setActiveSearchSpace) {
|
||||
electronAPI.getActiveSearchSpace?.().then((stored) => {
|
||||
if (!stored) {
|
||||
electronAPI.setActiveSearchSpace!(activeSeacrhSpaceId);
|
||||
}
|
||||
}).catch(() => {});
|
||||
}
|
||||
}, [search_space_id, setActiveSearchSpaceIdState, electronAPI]);
|
||||
|
||||
// Determine if we should show loading
|
||||
const shouldShowLoading =
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import { Button } from "@/components/ui/button";
|
|||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||
import { ToggleGroup, ToggleGroupItem } from "@/components/ui/toggle-group";
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import type { DocumentTypeEnum } from "@/contracts/types/document.types";
|
||||
import { getDocumentTypeIcon, getDocumentTypeLabel } from "./DocumentTypeIcon";
|
||||
|
|
@ -63,109 +64,113 @@ export function DocumentsFilters({
|
|||
return (
|
||||
<div className="flex select-none">
|
||||
<div className="flex items-center gap-2 w-full">
|
||||
{/* Type Filter */}
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
className="h-9 w-9 shrink-0 border-dashed border-sidebar-border text-sidebar-foreground/60 hover:text-sidebar-foreground hover:border-sidebar-border bg-sidebar"
|
||||
>
|
||||
<ListFilter size={14} />
|
||||
{activeTypes.length > 0 && (
|
||||
<span className="absolute -top-1 -right-1 flex h-4 w-4 items-center justify-center rounded-full bg-primary text-[9px] font-medium text-primary-foreground">
|
||||
{activeTypes.length}
|
||||
</span>
|
||||
)}
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="w-56 md:w-52 !p-0 overflow-hidden" align="end">
|
||||
<div>
|
||||
{/* Search input */}
|
||||
<div className="p-2">
|
||||
<div className="relative">
|
||||
<Search className="absolute left-0.5 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground" />
|
||||
<Input
|
||||
placeholder="Search types"
|
||||
value={typeSearchQuery}
|
||||
onChange={(e) => setTypeSearchQuery(e.target.value)}
|
||||
className="h-6 pl-6 text-sm bg-transparent border-0 shadow-none"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{/* Filter + New Folder Toggle Group */}
|
||||
<ToggleGroup type="multiple" variant="outline" value={[]} className="overflow-visible">
|
||||
{onCreateFolder && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<ToggleGroupItem
|
||||
value="folder"
|
||||
className="h-9 w-9 shrink-0 border-sidebar-border text-sidebar-foreground/60 hover:text-sidebar-foreground hover:border-sidebar-border bg-sidebar"
|
||||
onClick={(e) => {
|
||||
e.preventDefault();
|
||||
onCreateFolder();
|
||||
}}
|
||||
>
|
||||
<FolderPlus size={14} />
|
||||
</ToggleGroupItem>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>New folder</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
||||
<div
|
||||
className="max-h-[300px] overflow-y-auto overflow-x-hidden py-1.5 px-1.5"
|
||||
onScroll={handleScroll}
|
||||
style={{
|
||||
maskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`,
|
||||
WebkitMaskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`,
|
||||
}}
|
||||
>
|
||||
{filteredTypes.length === 0 ? (
|
||||
<div className="py-6 text-center text-sm text-muted-foreground">
|
||||
No types found
|
||||
</div>
|
||||
) : (
|
||||
filteredTypes.map((value: DocumentTypeEnum, i) => (
|
||||
<div
|
||||
role="option"
|
||||
aria-selected={activeTypes.includes(value)}
|
||||
tabIndex={0}
|
||||
key={value}
|
||||
className="flex w-full items-center gap-2.5 py-2 px-3 rounded-md hover:bg-neutral-200 dark:hover:bg-neutral-700 transition-colors cursor-pointer text-left"
|
||||
onClick={() => onToggleType(value, !activeTypes.includes(value))}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" || e.key === " ") {
|
||||
e.preventDefault();
|
||||
onToggleType(value, !activeTypes.includes(value));
|
||||
}
|
||||
}}
|
||||
>
|
||||
{/* Icon */}
|
||||
<div className="flex h-7 w-7 shrink-0 items-center justify-center rounded-md bg-muted/50 text-foreground/80">
|
||||
{getDocumentTypeIcon(value, "h-4 w-4")}
|
||||
</div>
|
||||
{/* Text content */}
|
||||
<div className="flex flex-col min-w-0 flex-1 gap-0.5">
|
||||
<span className="text-[13px] font-medium text-foreground truncate leading-tight">
|
||||
{getDocumentTypeLabel(value)}
|
||||
</span>
|
||||
<span className="text-[11px] text-muted-foreground leading-tight">
|
||||
{typeCounts.get(value)} document
|
||||
{(typeCounts.get(value) ?? 0) !== 1 ? "s" : ""}
|
||||
</span>
|
||||
</div>
|
||||
{/* Checkbox */}
|
||||
<Checkbox
|
||||
id={`${id}-${i}`}
|
||||
checked={activeTypes.includes(value)}
|
||||
onCheckedChange={(checked: boolean) => onToggleType(value, !!checked)}
|
||||
className="h-4 w-4 shrink-0 rounded border-muted-foreground/30 data-[state=checked]:bg-primary data-[state=checked]:border-primary"
|
||||
/>
|
||||
</div>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
{activeTypes.length > 0 && (
|
||||
<div className="px-3 pt-1.5 pb-1.5 border-t border-border dark:border-neutral-700">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="w-full h-7 text-[11px] text-muted-foreground hover:text-foreground hover:bg-neutral-200 dark:hover:bg-neutral-700"
|
||||
onClick={() => {
|
||||
activeTypes.forEach((t) => {
|
||||
onToggleType(t, false);
|
||||
});
|
||||
}}
|
||||
<Popover>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<PopoverTrigger asChild>
|
||||
<ToggleGroupItem
|
||||
value="filter"
|
||||
className="relative h-9 w-9 shrink-0 border-sidebar-border text-sidebar-foreground/60 hover:text-sidebar-foreground hover:border-sidebar-border bg-sidebar overflow-visible"
|
||||
>
|
||||
Clear filters
|
||||
</Button>
|
||||
<ListFilter size={14} />
|
||||
{activeTypes.length > 0 && (
|
||||
<span className="absolute -top-1 -right-1 flex h-4 w-4 items-center justify-center rounded-full bg-sidebar-border text-[9px] font-medium text-sidebar-foreground">
|
||||
{activeTypes.length}
|
||||
</span>
|
||||
)}
|
||||
</ToggleGroupItem>
|
||||
</PopoverTrigger>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>Filter by type</TooltipContent>
|
||||
</Tooltip>
|
||||
<PopoverContent className="w-56 md:w-52 !p-0 overflow-hidden" align="start">
|
||||
<div>
|
||||
<div className="p-2">
|
||||
<div className="relative">
|
||||
<Search className="absolute left-0.5 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground" />
|
||||
<Input
|
||||
placeholder="Search types"
|
||||
value={typeSearchQuery}
|
||||
onChange={(e) => setTypeSearchQuery(e.target.value)}
|
||||
className="h-6 pl-6 text-sm bg-transparent border-0 shadow-none"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
|
||||
<div
|
||||
className="max-h-[300px] overflow-y-auto overflow-x-hidden py-1.5 px-1.5"
|
||||
onScroll={handleScroll}
|
||||
style={{
|
||||
maskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`,
|
||||
WebkitMaskImage: `linear-gradient(to bottom, ${scrollPos === "top" ? "black" : "transparent"}, black 16px, black calc(100% - 16px), ${scrollPos === "bottom" ? "black" : "transparent"})`,
|
||||
}}
|
||||
>
|
||||
{filteredTypes.length === 0 ? (
|
||||
<div className="py-6 text-center text-sm text-muted-foreground">
|
||||
No types found
|
||||
</div>
|
||||
) : (
|
||||
filteredTypes.map((value: DocumentTypeEnum, i) => (
|
||||
<div
|
||||
role="option"
|
||||
aria-selected={activeTypes.includes(value)}
|
||||
tabIndex={0}
|
||||
key={value}
|
||||
className="flex w-full items-center gap-2.5 py-2 px-3 rounded-md hover:bg-neutral-200 dark:hover:bg-neutral-700 transition-colors cursor-pointer text-left"
|
||||
onClick={() => onToggleType(value, !activeTypes.includes(value))}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" || e.key === " ") {
|
||||
e.preventDefault();
|
||||
onToggleType(value, !activeTypes.includes(value));
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="flex h-7 w-7 shrink-0 items-center justify-center rounded-md bg-muted/50 text-foreground/80">
|
||||
{getDocumentTypeIcon(value, "h-4 w-4")}
|
||||
</div>
|
||||
<div className="flex flex-col min-w-0 flex-1 gap-0.5">
|
||||
<span className="text-[13px] font-medium text-foreground truncate leading-tight">
|
||||
{getDocumentTypeLabel(value)}
|
||||
</span>
|
||||
<span className="text-[11px] text-muted-foreground leading-tight">
|
||||
{typeCounts.get(value)} document
|
||||
{(typeCounts.get(value) ?? 0) !== 1 ? "s" : ""}
|
||||
</span>
|
||||
</div>
|
||||
<Checkbox
|
||||
id={`${id}-${i}`}
|
||||
checked={activeTypes.includes(value)}
|
||||
onCheckedChange={(checked: boolean) => onToggleType(value, !!checked)}
|
||||
className="h-4 w-4 shrink-0 rounded border-muted-foreground/30 data-[state=checked]:bg-primary data-[state=checked]:border-primary"
|
||||
/>
|
||||
</div>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</ToggleGroup>
|
||||
|
||||
{/* Search Input */}
|
||||
<div className="relative flex-1 min-w-0">
|
||||
|
|
@ -197,23 +202,6 @@ export function DocumentsFilters({
|
|||
)}
|
||||
</div>
|
||||
|
||||
{/* New Folder Button */}
|
||||
{onCreateFolder && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
className="h-9 w-9 shrink-0 border-dashed border-sidebar-border text-sidebar-foreground/60 hover:text-sidebar-foreground hover:border-sidebar-border bg-sidebar"
|
||||
onClick={onCreateFolder}
|
||||
>
|
||||
<FolderPlus size={14} />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>New folder</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
||||
{/* Upload Button */}
|
||||
<Button
|
||||
data-joyride="upload-button"
|
||||
|
|
|
|||
|
|
@ -1,30 +1,65 @@
|
|||
"use client";
|
||||
|
||||
import { BrainCog, Rocket, Zap } from "lucide-react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { DEFAULT_SHORTCUTS, ShortcutRecorder } from "@/components/desktop/shortcut-recorder";
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { useElectronAPI } from "@/hooks/use-platform";
|
||||
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
|
||||
import type { SearchSpace } from "@/contracts/types/search-space.types";
|
||||
|
||||
export function DesktopContent() {
|
||||
const [isElectron, setIsElectron] = useState(false);
|
||||
const api = useElectronAPI();
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [enabled, setEnabled] = useState(true);
|
||||
|
||||
const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS);
|
||||
const [shortcutsLoaded, setShortcutsLoaded] = useState(false);
|
||||
|
||||
const [searchSpaces, setSearchSpaces] = useState<SearchSpace[]>([]);
|
||||
const [activeSpaceId, setActiveSpaceId] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (!window.electronAPI) {
|
||||
if (!api) {
|
||||
setLoading(false);
|
||||
setShortcutsLoaded(true);
|
||||
return;
|
||||
}
|
||||
setIsElectron(true);
|
||||
|
||||
window.electronAPI.getAutocompleteEnabled().then((val) => {
|
||||
setEnabled(val);
|
||||
setLoading(false);
|
||||
});
|
||||
}, []);
|
||||
let mounted = true;
|
||||
|
||||
if (!isElectron) {
|
||||
Promise.all([
|
||||
api.getAutocompleteEnabled(),
|
||||
api.getShortcuts?.() ?? Promise.resolve(null),
|
||||
api.getActiveSearchSpace?.() ?? Promise.resolve(null),
|
||||
searchSpacesApiService.getSearchSpaces(),
|
||||
])
|
||||
.then(([autoEnabled, config, spaceId, spaces]) => {
|
||||
if (!mounted) return;
|
||||
setEnabled(autoEnabled);
|
||||
if (config) setShortcuts(config);
|
||||
setActiveSpaceId(spaceId);
|
||||
if (spaces) setSearchSpaces(spaces);
|
||||
setLoading(false);
|
||||
setShortcutsLoaded(true);
|
||||
})
|
||||
.catch(() => {
|
||||
if (!mounted) return;
|
||||
setLoading(false);
|
||||
setShortcutsLoaded(true);
|
||||
});
|
||||
|
||||
return () => {
|
||||
mounted = false;
|
||||
};
|
||||
}, [api]);
|
||||
|
||||
if (!api) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center py-12 text-center">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
|
|
@ -44,14 +79,114 @@ export function DesktopContent() {
|
|||
|
||||
const handleToggle = async (checked: boolean) => {
|
||||
setEnabled(checked);
|
||||
await window.electronAPI!.setAutocompleteEnabled(checked);
|
||||
await api.setAutocompleteEnabled(checked);
|
||||
};
|
||||
|
||||
/**
 * Optimistically applies a new accelerator for one desktop shortcut and
 * forwards the change to the Electron main process.
 *
 * @param key - Which shortcut to change.
 * @param accelerator - The new accelerator string for that shortcut.
 */
const updateShortcut = (key: "generalAssist" | "quickAsk" | "autocomplete", accelerator: string) => {
	// Keep the state updater pure: React may invoke updaters more than once
	// (e.g. in Strict Mode), so the IPC call must not live inside it.
	setShortcuts((prev) => ({ ...prev, [key]: accelerator }));

	// `setShortcuts` is optional on the bridge; Promise.resolve() normalises the
	// "method missing" case (undefined) so both paths share one completion flow.
	Promise.resolve(api.setShortcuts?.({ [key]: accelerator })).then(
		() => {
			// Only report success once the call has actually settled successfully.
			toast.success("Shortcut updated");
		},
		() => {
			toast.error("Failed to update shortcut");
		}
	);
};
|
||||
|
||||
/** Restores a single shortcut to its entry in DEFAULT_SHORTCUTS. */
const resetShortcut = (key: "generalAssist" | "quickAsk" | "autocomplete") =>
	updateShortcut(key, DEFAULT_SHORTCUTS[key]);
|
||||
|
||||
/**
 * Handles a selection in the "Default Search Space" dropdown: updates local
 * state immediately and forwards the choice through the Electron bridge.
 *
 * @param value - The selected search space id (string form, as used by the Select).
 */
const handleSearchSpaceChange = (value: string) => {
	setActiveSpaceId(value);

	// The bridge method is optional; Promise.resolve() covers the undefined case.
	// Handling both outcomes avoids an unhandled rejection and prevents a
	// success toast from being shown for a call that failed.
	Promise.resolve(api.setActiveSearchSpace?.(value)).then(
		() => {
			toast.success("Default search space updated");
		},
		() => {
			toast.error("Failed to update default search space");
		}
	);
};
|
||||
|
||||
return (
|
||||
<div className="space-y-4 md:space-y-6">
|
||||
{/* Default Search Space */}
|
||||
<Card>
|
||||
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
|
||||
<CardTitle className="text-base md:text-lg">Autocomplete</CardTitle>
|
||||
<CardTitle className="text-base md:text-lg">Default Search Space</CardTitle>
|
||||
<CardDescription className="text-xs md:text-sm">
|
||||
Choose which search space General Assist, Quick Assist, and Extreme Assist operate against.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="px-3 md:px-6 pb-3 md:pb-6">
|
||||
{searchSpaces.length > 0 ? (
|
||||
<Select value={activeSpaceId ?? undefined} onValueChange={handleSearchSpaceChange}>
|
||||
<SelectTrigger className="w-full">
|
||||
<SelectValue placeholder="Select a search space" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{searchSpaces.map((space) => (
|
||||
<SelectItem key={space.id} value={String(space.id)}>
|
||||
{space.name}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
) : (
|
||||
<p className="text-sm text-muted-foreground">No search spaces found. Create one first.</p>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
{/* Keyboard Shortcuts */}
|
||||
<Card>
|
||||
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
|
||||
<CardTitle className="text-base md:text-lg">Keyboard Shortcuts</CardTitle>
|
||||
<CardDescription className="text-xs md:text-sm">
|
||||
Customize the global keyboard shortcuts for desktop features.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="px-3 md:px-6 pb-3 md:pb-6">
|
||||
{shortcutsLoaded ? (
|
||||
<div className="flex flex-col gap-3">
|
||||
<ShortcutRecorder
|
||||
value={shortcuts.generalAssist}
|
||||
onChange={(accel) => updateShortcut("generalAssist", accel)}
|
||||
onReset={() => resetShortcut("generalAssist")}
|
||||
defaultValue={DEFAULT_SHORTCUTS.generalAssist}
|
||||
label="General Assist"
|
||||
description="Launch SurfSense instantly from any application"
|
||||
icon={Rocket}
|
||||
/>
|
||||
<ShortcutRecorder
|
||||
value={shortcuts.quickAsk}
|
||||
onChange={(accel) => updateShortcut("quickAsk", accel)}
|
||||
onReset={() => resetShortcut("quickAsk")}
|
||||
defaultValue={DEFAULT_SHORTCUTS.quickAsk}
|
||||
label="Quick Assist"
|
||||
description="Select text anywhere, then ask AI to explain, rewrite, or act on it"
|
||||
icon={Zap}
|
||||
/>
|
||||
<ShortcutRecorder
|
||||
value={shortcuts.autocomplete}
|
||||
onChange={(accel) => updateShortcut("autocomplete", accel)}
|
||||
onReset={() => resetShortcut("autocomplete")}
|
||||
defaultValue={DEFAULT_SHORTCUTS.autocomplete}
|
||||
label="Extreme Assist"
|
||||
description="AI drafts text using your screen context and knowledge base"
|
||||
icon={BrainCog}
|
||||
/>
|
||||
<p className="text-[11px] text-muted-foreground">
|
||||
Click a shortcut and press a new key combination to change it.
|
||||
</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex justify-center py-4">
|
||||
<Spinner size="sm" />
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
{/* Extreme Assist Toggle */}
|
||||
<Card>
|
||||
<CardHeader className="px-3 md:px-6 pt-3 md:pt-6 pb-2 md:pb-3">
|
||||
<CardTitle className="text-base md:text-lg">Extreme Assist</CardTitle>
|
||||
<CardDescription className="text-xs md:text-sm">
|
||||
Get inline writing suggestions powered by your knowledge base as you type in any app.
|
||||
</CardDescription>
|
||||
|
|
@ -60,17 +195,13 @@ export function DesktopContent() {
|
|||
<div className="flex items-center justify-between rounded-lg border p-4">
|
||||
<div className="space-y-0.5">
|
||||
<Label htmlFor="autocomplete-toggle" className="text-sm font-medium cursor-pointer">
|
||||
Enable autocomplete
|
||||
Enable Extreme Assist
|
||||
</Label>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Show suggestions while typing in other applications.
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
id="autocomplete-toggle"
|
||||
checked={enabled}
|
||||
onCheckedChange={handleToggle}
|
||||
/>
|
||||
<Switch id="autocomplete-toggle" checked={enabled} onCheckedChange={handleToggle} />
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import { useEffect, useState } from "react";
|
||||
import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms";
|
||||
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
|
||||
import { getBearerToken, redirectToLogin } from "@/lib/auth-utils";
|
||||
import { ensureTokensFromElectron, getBearerToken, redirectToLogin } from "@/lib/auth-utils";
|
||||
import { queryClient } from "@/lib/query-client/client";
|
||||
|
||||
interface DashboardLayoutProps {
|
||||
|
|
@ -17,15 +17,20 @@ export default function DashboardLayout({ children }: DashboardLayoutProps) {
|
|||
useGlobalLoadingEffect(isCheckingAuth);
|
||||
|
||||
useEffect(() => {
|
||||
// Check if user is authenticated
|
||||
const token = getBearerToken();
|
||||
if (!token) {
|
||||
// Save current path and redirect to login
|
||||
redirectToLogin();
|
||||
return;
|
||||
async function checkAuth() {
|
||||
let token = getBearerToken();
|
||||
if (!token) {
|
||||
const synced = await ensureTokensFromElectron();
|
||||
if (synced) token = getBearerToken();
|
||||
}
|
||||
if (!token) {
|
||||
redirectToLogin();
|
||||
return;
|
||||
}
|
||||
queryClient.invalidateQueries({ queryKey: [...USER_QUERY_KEY] });
|
||||
setIsCheckingAuth(false);
|
||||
}
|
||||
queryClient.invalidateQueries({ queryKey: [...USER_QUERY_KEY] });
|
||||
setIsCheckingAuth(false);
|
||||
checkAuth();
|
||||
}, []);
|
||||
|
||||
// Return null while loading - the global provider handles the loading UI
|
||||
|
|
|
|||
284
surfsense_web/app/desktop/login/page.tsx
Normal file
284
surfsense_web/app/desktop/login/page.tsx
Normal file
|
|
@ -0,0 +1,284 @@
|
|||
"use client";
|
||||
|
||||
import { IconBrandGoogleFilled } from "@tabler/icons-react";
|
||||
import { useAtom } from "jotai";
|
||||
import { BrainCog, Eye, EyeOff, Rocket, Zap } from "lucide-react";
|
||||
import Image from "next/image";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { loginMutationAtom } from "@/atoms/auth/auth-mutation.atoms";
|
||||
import { DEFAULT_SHORTCUTS, ShortcutRecorder } from "@/components/desktop/shortcut-recorder";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { useElectronAPI } from "@/hooks/use-platform";
|
||||
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
|
||||
import { setBearerToken } from "@/lib/auth-utils";
|
||||
import { AUTH_TYPE, BACKEND_URL } from "@/lib/env-config";
|
||||
|
||||
const isGoogleAuth = AUTH_TYPE === "GOOGLE";
|
||||
|
||||
export default function DesktopLoginPage() {
|
||||
const router = useRouter();
|
||||
const api = useElectronAPI();
|
||||
const [{ mutateAsync: login, isPending: isLoggingIn }] = useAtom(loginMutationAtom);
|
||||
|
||||
const [email, setEmail] = useState("");
|
||||
const [password, setPassword] = useState("");
|
||||
const [showPassword, setShowPassword] = useState(false);
|
||||
const [loginError, setLoginError] = useState<string | null>(null);
|
||||
|
||||
const [shortcuts, setShortcuts] = useState(DEFAULT_SHORTCUTS);
|
||||
const [shortcutsLoaded, setShortcutsLoaded] = useState(false);
|
||||
|
||||
// Load the stored shortcut configuration once the Electron bridge is known.
// Whatever happens (no bridge, empty config, rejected call) the recorder UI is
// unblocked by flipping `shortcutsLoaded`.
useEffect(() => {
	const markLoaded = () => setShortcutsLoaded(true);

	if (!api?.getShortcuts) {
		markLoaded();
		return;
	}

	api
		.getShortcuts()
		.then((config) => {
			// A falsy config leaves the DEFAULT_SHORTCUTS initial state in place.
			if (config) setShortcuts(config);
			markLoaded();
		})
		.catch(() => markLoaded());
}, [api]);
|
||||
|
||||
/**
 * Optimistically applies a new accelerator for one desktop shortcut and
 * forwards the change through the Electron bridge (when available).
 *
 * @param key - Which shortcut to change.
 * @param accelerator - The new accelerator string for that shortcut.
 */
const updateShortcut = useCallback(
	(key: "generalAssist" | "quickAsk" | "autocomplete", accelerator: string) => {
		// Keep the state updater pure: React may invoke updaters more than once
		// (e.g. in Strict Mode), so the IPC call must not live inside it.
		setShortcuts((prev) => ({ ...prev, [key]: accelerator }));

		// Both `api` and `setShortcuts` may be absent; Promise.resolve() normalises
		// that case (undefined) so both paths share one completion flow.
		Promise.resolve(api?.setShortcuts?.({ [key]: accelerator })).then(
			() => {
				// Only report success once the call has actually settled successfully.
				toast.success("Shortcut updated");
			},
			() => {
				toast.error("Failed to update shortcut");
			}
		);
	},
	[api]
);
|
||||
|
||||
/** Restores a single shortcut to its entry in DEFAULT_SHORTCUTS. */
const resetShortcut = useCallback(
	(key: "generalAssist" | "quickAsk" | "autocomplete") =>
		updateShortcut(key, DEFAULT_SHORTCUTS[key]),
	[updateShortcut]
);
|
||||
|
||||
/** Navigates the window to the backend's Google authorize-redirect endpoint. */
const handleGoogleLogin = () => {
	const authorizeUrl = `${BACKEND_URL}/auth/google/authorize-redirect`;
	window.location.assign(authorizeUrl);
};
|
||||
|
||||
/**
 * Best-effort: if the Electron store has no active search space yet, store the
 * id of the first search space returned by the API. Any failure is swallowed.
 */
const autoSetSearchSpace = async () => {
	try {
		const alreadyStored = await api?.getActiveSearchSpace?.();
		if (alreadyStored) return;

		const [firstSpace] = (await searchSpacesApiService.getSearchSpaces()) ?? [];
		if (!firstSpace) return;

		await api?.setActiveSearchSpace?.(String(firstSpace.id));
	} catch {
		// non-critical — dashboard-sync will catch it later
	}
};
|
||||
|
||||
/**
 * Submit handler for the email/password form.
 *
 * On success it stores the bearer token, seeds the active search space and then
 * navigates to the auth callback route; on failure it surfaces the error
 * message above the form.
 *
 * @param e - The form submit event (default submission is suppressed).
 */
const handleLocalLogin = async (e: React.FormEvent) => {
	e.preventDefault();
	setLoginError(null);

	try {
		const data = await login({
			username: email,
			password,
			grant_type: "password",
		});

		if (typeof window !== "undefined") {
			sessionStorage.setItem("login_success_tracked", "true");
		}

		setBearerToken(data.access_token);
		await autoSetSearchSpace();

		// Encode the token so characters such as "+", "&" or "=" cannot corrupt
		// the query string; for URL-safe tokens this is a no-op.
		const callbackUrl = `/auth/callback?token=${encodeURIComponent(data.access_token)}`;

		// NOTE(review): the 300 ms delay is kept from the original; its purpose is
		// not visible here — confirm before changing.
		setTimeout(() => {
			router.push(callbackUrl);
		}, 300);
	} catch (err) {
		setLoginError(
			err instanceof Error ? err.message : "Login failed. Please check your credentials."
		);
	}
};
|
||||
|
||||
return (
|
||||
<div className="relative flex min-h-svh items-center justify-center bg-background p-4 sm:p-6">
|
||||
{/* Subtle radial glow */}
|
||||
<div className="pointer-events-none fixed inset-0 overflow-hidden">
|
||||
<div
|
||||
className="absolute -top-1/2 left-1/2 size-[800px] -translate-x-1/2 rounded-full opacity-[0.03]"
|
||||
style={{
|
||||
background: "radial-gradient(circle, hsl(var(--primary)) 0%, transparent 70%)",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="relative flex w-full max-w-md flex-col overflow-hidden rounded-xl border bg-card shadow-lg">
|
||||
{/* Header */}
|
||||
<div className="flex flex-col items-center px-6 pt-6 pb-2 text-center">
|
||||
<Image
|
||||
src="/icon-128.svg"
|
||||
className="select-none dark:invert size-12 rounded-lg mb-3"
|
||||
alt="SurfSense"
|
||||
width={48}
|
||||
height={48}
|
||||
priority
|
||||
/>
|
||||
<h1 className="text-lg font-semibold tracking-tight">
|
||||
Welcome to SurfSense Desktop
|
||||
</h1>
|
||||
<p className="mt-1 text-sm text-muted-foreground">
|
||||
Configure shortcuts, then sign in to get started.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Scrollable content */}
|
||||
<div className="flex-1 overflow-y-auto px-6 py-4">
|
||||
<div className="flex flex-col gap-5">
|
||||
{/* ---- Shortcuts ---- */}
|
||||
{shortcutsLoaded ? (
|
||||
<div className="flex flex-col gap-2">
|
||||
<p className="text-xs font-medium uppercase tracking-wider text-muted-foreground">
|
||||
Keyboard Shortcuts
|
||||
</p>
|
||||
<div className="flex flex-col gap-1.5">
|
||||
<ShortcutRecorder
|
||||
value={shortcuts.generalAssist}
|
||||
onChange={(accel) => updateShortcut("generalAssist", accel)}
|
||||
onReset={() => resetShortcut("generalAssist")}
|
||||
defaultValue={DEFAULT_SHORTCUTS.generalAssist}
|
||||
label="General Assist"
|
||||
description="Launch SurfSense instantly from any application"
|
||||
icon={Rocket}
|
||||
/>
|
||||
<ShortcutRecorder
|
||||
value={shortcuts.quickAsk}
|
||||
onChange={(accel) => updateShortcut("quickAsk", accel)}
|
||||
onReset={() => resetShortcut("quickAsk")}
|
||||
defaultValue={DEFAULT_SHORTCUTS.quickAsk}
|
||||
label="Quick Assist"
|
||||
description="Select text anywhere, then ask AI to explain, rewrite, or act on it"
|
||||
icon={Zap}
|
||||
/>
|
||||
<ShortcutRecorder
|
||||
value={shortcuts.autocomplete}
|
||||
onChange={(accel) => updateShortcut("autocomplete", accel)}
|
||||
onReset={() => resetShortcut("autocomplete")}
|
||||
defaultValue={DEFAULT_SHORTCUTS.autocomplete}
|
||||
label="Extreme Assist"
|
||||
description="AI drafts text using your screen context and knowledge base"
|
||||
icon={BrainCog}
|
||||
/>
|
||||
</div>
|
||||
<p className="text-[11px] text-muted-foreground text-center mt-1">
|
||||
Click a shortcut and press a new key combination to change it.
|
||||
</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex justify-center py-6">
|
||||
<Spinner size="sm" />
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Separator />
|
||||
|
||||
{/* ---- Auth ---- */}
|
||||
<div className="flex flex-col gap-3">
|
||||
<p className="text-xs font-medium uppercase tracking-wider text-muted-foreground">
|
||||
Sign In
|
||||
</p>
|
||||
|
||||
{isGoogleAuth ? (
|
||||
<Button variant="outline" className="w-full gap-2 h-10" onClick={handleGoogleLogin}>
|
||||
<IconBrandGoogleFilled className="size-4" />
|
||||
Continue with Google
|
||||
</Button>
|
||||
) : (
|
||||
<form onSubmit={handleLocalLogin} className="flex flex-col gap-3">
|
||||
{loginError && (
|
||||
<div className="rounded-md border border-destructive/20 bg-destructive/10 px-3 py-2 text-sm text-destructive">
|
||||
{loginError}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex flex-col gap-1.5">
|
||||
<Label htmlFor="email" className="text-xs">
|
||||
Email
|
||||
</Label>
|
||||
<Input
|
||||
id="email"
|
||||
type="email"
|
||||
placeholder="you@example.com"
|
||||
required
|
||||
value={email}
|
||||
onChange={(e) => setEmail(e.target.value)}
|
||||
disabled={isLoggingIn}
|
||||
autoFocus
|
||||
className="h-9"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col gap-1.5">
|
||||
<Label htmlFor="password" className="text-xs">
|
||||
Password
|
||||
</Label>
|
||||
<div className="relative">
|
||||
<Input
|
||||
id="password"
|
||||
type={showPassword ? "text" : "password"}
|
||||
placeholder="Enter your password"
|
||||
required
|
||||
value={password}
|
||||
onChange={(e) => setPassword(e.target.value)}
|
||||
disabled={isLoggingIn}
|
||||
className="h-9 pr-9"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowPassword((v) => !v)}
|
||||
className="absolute inset-y-0 right-0 flex items-center pr-2.5 text-muted-foreground hover:text-foreground"
|
||||
tabIndex={-1}
|
||||
>
|
||||
{showPassword ? (
|
||||
<EyeOff className="size-3.5" />
|
||||
) : (
|
||||
<Eye className="size-3.5" />
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Button type="submit" disabled={isLoggingIn} className="h-9 mt-1">
|
||||
{isLoggingIn ? (
|
||||
<>
|
||||
<Spinner size="sm" className="text-primary-foreground" />
|
||||
Signing in…
|
||||
</>
|
||||
) : (
|
||||
"Sign in"
|
||||
)}
|
||||
</Button>
|
||||
</form>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,10 +1,11 @@
|
|||
"use client";
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect, useState } from "react";
|
||||
import { Logo } from "@/components/Logo";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { useElectronAPI } from "@/hooks/use-platform";
|
||||
|
||||
type PermissionStatus = "authorized" | "denied" | "not determined" | "restricted" | "limited";
|
||||
|
||||
|
|
@ -17,7 +18,8 @@ const STEPS = [
|
|||
{
|
||||
id: "screen-recording",
|
||||
title: "Screen Recording",
|
||||
description: "Lets SurfSense capture your screen to understand context and provide smart writing suggestions.",
|
||||
description:
|
||||
"Lets SurfSense capture your screen to understand context and provide smart writing suggestions.",
|
||||
action: "requestScreenRecording",
|
||||
field: "screenRecording" as const,
|
||||
},
|
||||
|
|
@ -57,19 +59,18 @@ function StatusBadge({ status }: { status: PermissionStatus }) {
|
|||
|
||||
export default function DesktopPermissionsPage() {
|
||||
const router = useRouter();
|
||||
const api = useElectronAPI();
|
||||
const [permissions, setPermissions] = useState<PermissionsStatus | null>(null);
|
||||
const [isElectron, setIsElectron] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (!window.electronAPI) return;
|
||||
setIsElectron(true);
|
||||
if (!api) return;
|
||||
|
||||
let interval: ReturnType<typeof setInterval> | null = null;
|
||||
|
||||
const isResolved = (s: string) => s === "authorized" || s === "restricted";
|
||||
|
||||
const poll = async () => {
|
||||
const status = await window.electronAPI!.getPermissionsStatus();
|
||||
const status = await api.getPermissionsStatus();
|
||||
setPermissions(status);
|
||||
|
||||
if (isResolved(status.accessibility) && isResolved(status.screenRecording)) {
|
||||
|
|
@ -79,10 +80,12 @@ export default function DesktopPermissionsPage() {
|
|||
|
||||
poll();
|
||||
interval = setInterval(poll, 2000);
|
||||
return () => { if (interval) clearInterval(interval); };
|
||||
}, []);
|
||||
return () => {
|
||||
if (interval) clearInterval(interval);
|
||||
};
|
||||
}, [api]);
|
||||
|
||||
if (!isElectron) {
|
||||
if (!api) {
|
||||
return (
|
||||
<div className="h-screen flex items-center justify-center bg-background">
|
||||
<p className="text-muted-foreground">This page is only available in the desktop app.</p>
|
||||
|
|
@ -98,19 +101,20 @@ export default function DesktopPermissionsPage() {
|
|||
);
|
||||
}
|
||||
|
||||
const allGranted = permissions.accessibility === "authorized" && permissions.screenRecording === "authorized";
|
||||
const allGranted =
|
||||
permissions.accessibility === "authorized" && permissions.screenRecording === "authorized";
|
||||
|
||||
const handleRequest = async (action: string) => {
|
||||
if (action === "requestScreenRecording") {
|
||||
await window.electronAPI!.requestScreenRecording();
|
||||
await api.requestScreenRecording();
|
||||
} else if (action === "requestAccessibility") {
|
||||
await window.electronAPI!.requestAccessibility();
|
||||
await api.requestAccessibility();
|
||||
}
|
||||
};
|
||||
|
||||
const handleContinue = () => {
|
||||
if (allGranted) {
|
||||
window.electronAPI!.restartApp();
|
||||
api.restartApp();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -175,7 +179,8 @@ export default function DesktopPermissionsPage() {
|
|||
</p>
|
||||
)}
|
||||
<p className="text-xs text-muted-foreground">
|
||||
If SurfSense doesn't appear in the list, click <strong>+</strong> and select it from Applications.
|
||||
If SurfSense doesn't appear in the list, click <strong>+</strong> and
|
||||
select it from Applications.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
|
@ -201,6 +206,7 @@ export default function DesktopPermissionsPage() {
|
|||
Grant permissions to continue
|
||||
</Button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleSkip}
|
||||
className="block mx-auto text-xs text-muted-foreground hover:text-foreground transition-colors"
|
||||
>
|
||||
|
|
|
|||
|
|
@ -4,10 +4,6 @@ export const metadata = {
|
|||
title: "SurfSense Suggestion",
|
||||
};
|
||||
|
||||
export default function SuggestionLayout({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
export default function SuggestionLayout({ children }: { children: React.ReactNode }) {
|
||||
return <div className="suggestion-body">{children}</div>;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { getBearerToken } from "@/lib/auth-utils";
|
||||
import { useElectronAPI } from "@/hooks/use-platform";
|
||||
import { ensureTokensFromElectron, getBearerToken } from "@/lib/auth-utils";
|
||||
|
||||
type SSEEvent =
|
||||
| { type: "text-delta"; id: string; delta: string }
|
||||
|
|
@ -9,7 +10,18 @@ type SSEEvent =
|
|||
| { type: "text-end"; id: string }
|
||||
| { type: "start"; messageId: string }
|
||||
| { type: "finish" }
|
||||
| { type: "error"; errorText: string };
|
||||
| { type: "error"; errorText: string }
|
||||
| {
|
||||
type: "data-thinking-step";
|
||||
data: { id: string; title: string; status: string; items: string[] };
|
||||
};
|
||||
|
||||
interface AgentStep {
|
||||
id: string;
|
||||
title: string;
|
||||
status: string;
|
||||
items: string[];
|
||||
}
|
||||
|
||||
function friendlyError(raw: string | number): string {
|
||||
if (typeof raw === "number") {
|
||||
|
|
@ -33,27 +45,52 @@ function friendlyError(raw: string | number): string {
|
|||
|
||||
const AUTO_DISMISS_MS = 3000;
|
||||
|
||||
function StepIcon({ status }: { status: string }) {
|
||||
if (status === "complete") {
|
||||
return (
|
||||
<svg
|
||||
className="step-icon step-icon-done"
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
aria-label="Step complete"
|
||||
>
|
||||
<circle cx="8" cy="8" r="7" stroke="#4ade80" strokeWidth="1.5" />
|
||||
<path
|
||||
d="M5 8.5l2 2 4-4.5"
|
||||
stroke="#4ade80"
|
||||
strokeWidth="1.5"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
return <span className="step-spinner" />;
|
||||
}
|
||||
|
||||
export default function SuggestionPage() {
|
||||
const api = useElectronAPI();
|
||||
const [suggestion, setSuggestion] = useState("");
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [isDesktop, setIsDesktop] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [steps, setSteps] = useState<AgentStep[]>([]);
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
|
||||
const isDesktop = !!api?.onAutocompleteContext;
|
||||
|
||||
useEffect(() => {
|
||||
if (!window.electronAPI?.onAutocompleteContext) {
|
||||
setIsDesktop(false);
|
||||
if (!api?.onAutocompleteContext) {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, []);
|
||||
}, [api]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!error) return;
|
||||
const timer = setTimeout(() => {
|
||||
window.electronAPI?.dismissSuggestion?.();
|
||||
api?.dismissSuggestion?.();
|
||||
}, AUTO_DISMISS_MS);
|
||||
return () => clearTimeout(timer);
|
||||
}, [error]);
|
||||
}, [error, api]);
|
||||
|
||||
const fetchSuggestion = useCallback(
|
||||
async (screenshot: string, searchSpaceId: string, appName?: string, windowTitle?: string) => {
|
||||
|
|
@ -64,35 +101,36 @@ export default function SuggestionPage() {
|
|||
setIsLoading(true);
|
||||
setSuggestion("");
|
||||
setError(null);
|
||||
setSteps([]);
|
||||
|
||||
const token = getBearerToken();
|
||||
let token = getBearerToken();
|
||||
if (!token) {
|
||||
await ensureTokensFromElectron();
|
||||
token = getBearerToken();
|
||||
}
|
||||
if (!token) {
|
||||
setError(friendlyError("not authenticated"));
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const backendUrl =
|
||||
process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
`${backendUrl}/api/v1/autocomplete/vision/stream`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
screenshot,
|
||||
search_space_id: parseInt(searchSpaceId, 10),
|
||||
app_name: appName || "",
|
||||
window_title: windowTitle || "",
|
||||
}),
|
||||
signal: controller.signal,
|
||||
const response = await fetch(`${backendUrl}/api/v1/autocomplete/vision/stream`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
);
|
||||
body: JSON.stringify({
|
||||
screenshot,
|
||||
search_space_id: parseInt(searchSpaceId, 10),
|
||||
app_name: appName || "",
|
||||
window_title: windowTitle || "",
|
||||
}),
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
setError(friendlyError(response.status));
|
||||
|
|
@ -131,10 +169,19 @@ export default function SuggestionPage() {
|
|||
setSuggestion((prev) => prev + parsed.delta);
|
||||
} else if (parsed.type === "error") {
|
||||
setError(friendlyError(parsed.errorText));
|
||||
} else if (parsed.type === "data-thinking-step") {
|
||||
const { id, title, status, items } = parsed.data;
|
||||
setSteps((prev) => {
|
||||
const existing = prev.findIndex((s) => s.id === id);
|
||||
if (existing >= 0) {
|
||||
const updated = [...prev];
|
||||
updated[existing] = { id, title, status, items };
|
||||
return updated;
|
||||
}
|
||||
return [...prev, { id, title, status, items }];
|
||||
});
|
||||
}
|
||||
} catch {
|
||||
continue;
|
||||
}
|
||||
} catch {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -145,13 +192,13 @@ export default function SuggestionPage() {
|
|||
setIsLoading(false);
|
||||
}
|
||||
},
|
||||
[],
|
||||
[]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (!window.electronAPI?.onAutocompleteContext) return;
|
||||
if (!api?.onAutocompleteContext) return;
|
||||
|
||||
const cleanup = window.electronAPI.onAutocompleteContext((data) => {
|
||||
const cleanup = api.onAutocompleteContext((data) => {
|
||||
const searchSpaceId = data.searchSpaceId || "1";
|
||||
if (data.screenshot) {
|
||||
fetchSuggestion(data.screenshot, searchSpaceId, data.appName, data.windowTitle);
|
||||
|
|
@ -159,7 +206,7 @@ export default function SuggestionPage() {
|
|||
});
|
||||
|
||||
return cleanup;
|
||||
}, [fetchSuggestion]);
|
||||
}, [fetchSuggestion, api]);
|
||||
|
||||
if (!isDesktop) {
|
||||
return (
|
||||
|
|
@ -179,13 +226,33 @@ export default function SuggestionPage() {
|
|||
);
|
||||
}
|
||||
|
||||
if (isLoading && !suggestion) {
|
||||
const showLoading = isLoading && !suggestion;
|
||||
|
||||
if (showLoading) {
|
||||
return (
|
||||
<div className="suggestion-tooltip">
|
||||
<div className="suggestion-loading">
|
||||
<span className="suggestion-dot" />
|
||||
<span className="suggestion-dot" />
|
||||
<span className="suggestion-dot" />
|
||||
<div className="agent-activity">
|
||||
{steps.length === 0 && (
|
||||
<div className="activity-initial">
|
||||
<span className="step-spinner" />
|
||||
<span className="activity-label">Preparing…</span>
|
||||
</div>
|
||||
)}
|
||||
{steps.length > 0 && (
|
||||
<div className="activity-steps">
|
||||
{steps.map((step) => (
|
||||
<div key={step.id} className="activity-step">
|
||||
<StepIcon status={step.status} />
|
||||
<span className="step-label">
|
||||
{step.title}
|
||||
{step.items.length > 0 && (
|
||||
<span className="step-detail"> · {step.items[0]}</span>
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
|
@ -193,12 +260,12 @@ export default function SuggestionPage() {
|
|||
|
||||
const handleAccept = () => {
|
||||
if (suggestion) {
|
||||
window.electronAPI?.acceptSuggestion?.(suggestion);
|
||||
api?.acceptSuggestion?.(suggestion);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDismiss = () => {
|
||||
window.electronAPI?.dismissSuggestion?.();
|
||||
api?.dismissSuggestion?.();
|
||||
};
|
||||
|
||||
if (!suggestion) return null;
|
||||
|
|
@ -207,10 +274,18 @@ export default function SuggestionPage() {
|
|||
<div className="suggestion-tooltip">
|
||||
<p className="suggestion-text">{suggestion}</p>
|
||||
<div className="suggestion-actions">
|
||||
<button className="suggestion-btn suggestion-btn-accept" onClick={handleAccept}>
|
||||
<button
|
||||
type="button"
|
||||
className="suggestion-btn suggestion-btn-accept"
|
||||
onClick={handleAccept}
|
||||
>
|
||||
Accept
|
||||
</button>
|
||||
<button className="suggestion-btn suggestion-btn-dismiss" onClick={handleDismiss}>
|
||||
<button
|
||||
type="button"
|
||||
className="suggestion-btn suggestion-btn-dismiss"
|
||||
onClick={handleDismiss}
|
||||
>
|
||||
Dismiss
|
||||
</button>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -1,121 +1,193 @@
|
|||
html:has(.suggestion-body),
|
||||
body:has(.suggestion-body) {
|
||||
margin: 0 !important;
|
||||
padding: 0 !important;
|
||||
background: transparent !important;
|
||||
overflow: hidden !important;
|
||||
height: auto !important;
|
||||
width: 100% !important;
|
||||
margin: 0 !important;
|
||||
padding: 0 !important;
|
||||
background: transparent !important;
|
||||
overflow: hidden !important;
|
||||
height: auto !important;
|
||||
width: 100% !important;
|
||||
}
|
||||
|
||||
.suggestion-body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
background: transparent;
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
-webkit-font-smoothing: antialiased;
|
||||
user-select: none;
|
||||
-webkit-app-region: no-drag;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
background: transparent;
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
-webkit-font-smoothing: antialiased;
|
||||
user-select: none;
|
||||
-webkit-app-region: no-drag;
|
||||
}
|
||||
|
||||
.suggestion-tooltip {
|
||||
background: #1e1e1e;
|
||||
border: 1px solid #3c3c3c;
|
||||
border-radius: 8px;
|
||||
padding: 8px 12px;
|
||||
margin: 4px;
|
||||
max-width: 400px;
|
||||
box-shadow: 0 4px 16px rgba(0, 0, 0, 0.5);
|
||||
box-sizing: border-box;
|
||||
background: #1e1e1e;
|
||||
border: 1px solid #3c3c3c;
|
||||
border-radius: 8px;
|
||||
padding: 8px 12px;
|
||||
margin: 4px;
|
||||
max-width: 400px;
|
||||
/* MAX_HEIGHT in suggestion-window.ts is 400px. Subtract 8px for margin
|
||||
(4px * 2) so the tooltip + margin fits within the Electron window.
|
||||
box-sizing: border-box ensures padding + border are included. */
|
||||
max-height: 392px;
|
||||
box-shadow: 0 4px 16px rgba(0, 0, 0, 0.5);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.suggestion-text {
|
||||
color: #d4d4d4;
|
||||
font-size: 13px;
|
||||
line-height: 1.45;
|
||||
margin: 0 0 6px 0;
|
||||
word-wrap: break-word;
|
||||
white-space: pre-wrap;
|
||||
color: #d4d4d4;
|
||||
font-size: 13px;
|
||||
line-height: 1.45;
|
||||
margin: 0 0 6px 0;
|
||||
word-wrap: break-word;
|
||||
white-space: pre-wrap;
|
||||
overflow-y: auto;
|
||||
flex: 1 1 auto;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.suggestion-text::-webkit-scrollbar {
|
||||
width: 5px;
|
||||
}
|
||||
|
||||
.suggestion-text::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.suggestion-text::-webkit-scrollbar-thumb {
|
||||
background: #555;
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.suggestion-text::-webkit-scrollbar-thumb:hover {
|
||||
background: #777;
|
||||
}
|
||||
|
||||
.suggestion-actions {
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
gap: 4px;
|
||||
border-top: 1px solid #2a2a2a;
|
||||
padding-top: 6px;
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
gap: 4px;
|
||||
border-top: 1px solid #2a2a2a;
|
||||
padding-top: 6px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.suggestion-btn {
|
||||
padding: 2px 8px;
|
||||
border-radius: 3px;
|
||||
border: 1px solid #3c3c3c;
|
||||
font-family: inherit;
|
||||
font-size: 10px;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
line-height: 16px;
|
||||
transition: background 0.15s, border-color 0.15s;
|
||||
padding: 2px 8px;
|
||||
border-radius: 3px;
|
||||
border: 1px solid #3c3c3c;
|
||||
font-family: inherit;
|
||||
font-size: 10px;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
line-height: 16px;
|
||||
transition:
|
||||
background 0.15s,
|
||||
border-color 0.15s;
|
||||
}
|
||||
|
||||
.suggestion-btn-accept {
|
||||
background: #2563eb;
|
||||
border-color: #3b82f6;
|
||||
color: #fff;
|
||||
background: #2563eb;
|
||||
border-color: #3b82f6;
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
.suggestion-btn-accept:hover {
|
||||
background: #1d4ed8;
|
||||
background: #1d4ed8;
|
||||
}
|
||||
|
||||
.suggestion-btn-dismiss {
|
||||
background: #2a2a2a;
|
||||
color: #999;
|
||||
background: #2a2a2a;
|
||||
color: #999;
|
||||
}
|
||||
|
||||
.suggestion-btn-dismiss:hover {
|
||||
background: #333;
|
||||
color: #ccc;
|
||||
background: #333;
|
||||
color: #ccc;
|
||||
}
|
||||
|
||||
.suggestion-error {
|
||||
border-color: #5c2626;
|
||||
border-color: #5c2626;
|
||||
}
|
||||
|
||||
.suggestion-error-text {
|
||||
color: #f48771;
|
||||
font-size: 12px;
|
||||
color: #f48771;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.suggestion-loading {
|
||||
display: flex;
|
||||
gap: 5px;
|
||||
padding: 2px 0;
|
||||
justify-content: center;
|
||||
/* --- Agent activity indicator --- */
|
||||
|
||||
.agent-activity {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
overflow-y: auto;
|
||||
max-height: 340px;
|
||||
}
|
||||
|
||||
.suggestion-dot {
|
||||
width: 4px;
|
||||
height: 4px;
|
||||
border-radius: 50%;
|
||||
background: #666;
|
||||
animation: suggestion-pulse 1.2s infinite ease-in-out;
|
||||
.activity-initial {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 2px 0;
|
||||
}
|
||||
|
||||
.suggestion-dot:nth-child(2) {
|
||||
animation-delay: 0.15s;
|
||||
.activity-label {
|
||||
color: #a1a1aa;
|
||||
font-size: 12px;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
|
||||
.suggestion-dot:nth-child(3) {
|
||||
animation-delay: 0.3s;
|
||||
.activity-steps {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 3px;
|
||||
}
|
||||
|
||||
@keyframes suggestion-pulse {
|
||||
0%, 80%, 100% {
|
||||
opacity: 0.3;
|
||||
transform: scale(0.8);
|
||||
}
|
||||
40% {
|
||||
opacity: 1;
|
||||
transform: scale(1.1);
|
||||
}
|
||||
.activity-step {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
min-height: 18px;
|
||||
}
|
||||
|
||||
.step-label {
|
||||
color: #d4d4d4;
|
||||
font-size: 12px;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
|
||||
.step-detail {
|
||||
color: #71717a;
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
/* Spinner (in_progress) */
|
||||
.step-spinner {
|
||||
width: 14px;
|
||||
height: 14px;
|
||||
flex-shrink: 0;
|
||||
border: 1.5px solid #3f3f46;
|
||||
border-top-color: #a78bfa;
|
||||
border-radius: 50%;
|
||||
animation: step-spin 0.7s linear infinite;
|
||||
}
|
||||
|
||||
/* Checkmark icon (complete) */
|
||||
.step-icon {
|
||||
width: 14px;
|
||||
height: 14px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
@keyframes step-spin {
|
||||
to {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import { ZeroProvider } from "@/components/providers/ZeroProvider";
|
|||
import { ThemeProvider } from "@/components/theme/theme-provider";
|
||||
import { Toaster } from "@/components/ui/sonner";
|
||||
import { LocaleProvider } from "@/contexts/LocaleContext";
|
||||
import { PlatformProvider } from "@/contexts/platform-context";
|
||||
import { ReactQueryClientProvider } from "@/lib/query-client/query-client.provider";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
|
|
@ -139,15 +140,17 @@ export default function RootLayout({
|
|||
disableTransitionOnChange
|
||||
defaultTheme="system"
|
||||
>
|
||||
<RootProvider>
|
||||
<ReactQueryClientProvider>
|
||||
<ZeroProvider>
|
||||
<GlobalLoadingProvider>{children}</GlobalLoadingProvider>
|
||||
</ZeroProvider>
|
||||
</ReactQueryClientProvider>
|
||||
<Toaster />
|
||||
<AnnouncementToastProvider />
|
||||
</RootProvider>
|
||||
<PlatformProvider>
|
||||
<RootProvider>
|
||||
<ReactQueryClientProvider>
|
||||
<ZeroProvider>
|
||||
<GlobalLoadingProvider>{children}</GlobalLoadingProvider>
|
||||
</ZeroProvider>
|
||||
</ReactQueryClientProvider>
|
||||
<Toaster />
|
||||
<AnnouncementToastProvider />
|
||||
</RootProvider>
|
||||
</PlatformProvider>
|
||||
</ThemeProvider>
|
||||
</I18nProvider>
|
||||
</LocaleProvider>
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import { useEffect } from "react";
|
||||
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
|
||||
import { getAndClearRedirectPath, setBearerToken, setRefreshToken } from "@/lib/auth-utils";
|
||||
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
|
||||
import { trackLoginSuccess } from "@/lib/posthog/events";
|
||||
|
||||
interface TokenHandlerProps {
|
||||
|
|
@ -29,52 +30,54 @@ const TokenHandler = ({
|
|||
useGlobalLoadingEffect(true);
|
||||
|
||||
useEffect(() => {
|
||||
// Only run on client-side
|
||||
if (typeof window === "undefined") return;
|
||||
|
||||
// Read tokens from URL at mount time — no subscription needed.
|
||||
// TokenHandler only runs once after an auth redirect, so a stale read
|
||||
// is impossible and useSearchParams() would add a pointless subscription.
|
||||
// (Vercel Best Practice: rerender-defer-reads 5.2)
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
const token = params.get(tokenParamName);
|
||||
const refreshToken = params.get("refresh_token");
|
||||
const run = async () => {
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
const token = params.get(tokenParamName);
|
||||
const refreshToken = params.get("refresh_token");
|
||||
|
||||
if (token) {
|
||||
try {
|
||||
// Track login success for OAuth flows (e.g., Google)
|
||||
// Local login already tracks success before redirecting here
|
||||
const alreadyTracked = sessionStorage.getItem("login_success_tracked");
|
||||
if (!alreadyTracked) {
|
||||
// This is an OAuth flow (Google login) - track success
|
||||
trackLoginSuccess("google");
|
||||
if (token) {
|
||||
try {
|
||||
const alreadyTracked = sessionStorage.getItem("login_success_tracked");
|
||||
if (!alreadyTracked) {
|
||||
trackLoginSuccess("google");
|
||||
}
|
||||
sessionStorage.removeItem("login_success_tracked");
|
||||
|
||||
localStorage.setItem(storageKey, token);
|
||||
setBearerToken(token);
|
||||
|
||||
if (refreshToken) {
|
||||
setRefreshToken(refreshToken);
|
||||
}
|
||||
|
||||
// Auto-set active search space in desktop if not already set
|
||||
if (window.electronAPI?.getActiveSearchSpace) {
|
||||
try {
|
||||
const stored = await window.electronAPI.getActiveSearchSpace();
|
||||
if (!stored) {
|
||||
const spaces = await searchSpacesApiService.getSearchSpaces();
|
||||
if (spaces?.length) {
|
||||
await window.electronAPI.setActiveSearchSpace?.(String(spaces[0].id));
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// non-critical
|
||||
}
|
||||
}
|
||||
|
||||
const savedRedirectPath = getAndClearRedirectPath();
|
||||
const finalRedirectPath = savedRedirectPath || redirectPath;
|
||||
window.location.href = finalRedirectPath;
|
||||
} catch (error) {
|
||||
console.error("Error storing token in localStorage:", error);
|
||||
window.location.href = redirectPath;
|
||||
}
|
||||
// Clear the flag for future logins
|
||||
sessionStorage.removeItem("login_success_tracked");
|
||||
|
||||
// Store access token in localStorage using both methods for compatibility
|
||||
localStorage.setItem(storageKey, token);
|
||||
setBearerToken(token);
|
||||
|
||||
// Store refresh token if provided
|
||||
if (refreshToken) {
|
||||
setRefreshToken(refreshToken);
|
||||
}
|
||||
|
||||
// Check if there's a saved redirect path from before the auth flow
|
||||
const savedRedirectPath = getAndClearRedirectPath();
|
||||
|
||||
// Use the saved path if available, otherwise use the default redirectPath
|
||||
const finalRedirectPath = savedRedirectPath || redirectPath;
|
||||
|
||||
// Redirect to the appropriate path
|
||||
window.location.href = finalRedirectPath;
|
||||
} catch (error) {
|
||||
console.error("Error storing token in localStorage:", error);
|
||||
// Even if there's an error, try to redirect to the default path
|
||||
window.location.href = redirectPath;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
run();
|
||||
}, [tokenParamName, storageKey, redirectPath]);
|
||||
|
||||
// Return null - the global provider handles the loading UI
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import {
|
|||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import { logout } from "@/lib/auth-utils";
|
||||
import { getLoginPath, logout } from "@/lib/auth-utils";
|
||||
import { resetUser, trackLogout } from "@/lib/posthog/events";
|
||||
|
||||
export function UserDropdown({
|
||||
|
|
@ -33,22 +33,19 @@ export function UserDropdown({
|
|||
if (isLoggingOut) return;
|
||||
setIsLoggingOut(true);
|
||||
try {
|
||||
// Track logout event and reset PostHog identity
|
||||
trackLogout();
|
||||
resetUser();
|
||||
|
||||
// Revoke refresh token on server and clear all tokens from localStorage
|
||||
await logout();
|
||||
|
||||
if (typeof window !== "undefined") {
|
||||
window.location.href = "/";
|
||||
window.location.href = getLoginPath();
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error during logout:", error);
|
||||
// Even if there's an error, try to clear tokens and redirect
|
||||
await logout();
|
||||
if (typeof window !== "undefined") {
|
||||
window.location.href = "/";
|
||||
window.location.href = getLoginPath();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -87,6 +87,7 @@ import {
|
|||
} from "@/components/ui/drawer";
|
||||
import { useComments } from "@/hooks/use-comments";
|
||||
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||
import { useElectronAPI } from "@/hooks/use-platform";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
// Dynamically import video presentation tool to avoid loading Babel and Remotion in main bundle
|
||||
|
|
@ -463,16 +464,15 @@ export const AssistantMessage: FC = () => {
|
|||
const AssistantActionBar: FC = () => {
|
||||
const isLast = useAuiState((s) => s.message.isLast);
|
||||
const aui = useAui();
|
||||
const [quickAskMode, setQuickAskMode] = useState("");
|
||||
const api = useElectronAPI();
|
||||
const [isQuickAssist, setIsQuickAssist] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isLast || !window.electronAPI?.getQuickAskMode) return;
|
||||
window.electronAPI.getQuickAskMode().then((mode) => {
|
||||
if (mode) setQuickAskMode(mode);
|
||||
if (!api?.getQuickAskMode) return;
|
||||
api.getQuickAskMode().then((mode) => {
|
||||
if (mode) setIsQuickAssist(true);
|
||||
});
|
||||
}, [isLast]);
|
||||
|
||||
const isTransform = isLast && !!window.electronAPI?.replaceText && quickAskMode === "transform";
|
||||
}, [api]);
|
||||
|
||||
return (
|
||||
<ActionBarPrimitive.Root
|
||||
|
|
@ -482,7 +482,7 @@ const AssistantActionBar: FC = () => {
|
|||
className="aui-assistant-action-bar-root -ml-1 col-start-3 row-start-2 flex gap-1 text-muted-foreground md:data-floating:absolute md:data-floating:rounded-md md:data-floating:p-1 [&>button]:opacity-100 md:[&>button]:opacity-[var(--aui-button-opacity,1)]"
|
||||
>
|
||||
<ActionBarPrimitive.Copy asChild>
|
||||
<TooltipIconButton tooltip="Copy">
|
||||
<TooltipIconButton tooltip="Copy to clipboard">
|
||||
<AuiIf condition={({ message }) => message.isCopied}>
|
||||
<CheckIcon />
|
||||
</AuiIf>
|
||||
|
|
@ -492,29 +492,27 @@ const AssistantActionBar: FC = () => {
|
|||
</TooltipIconButton>
|
||||
</ActionBarPrimitive.Copy>
|
||||
<ActionBarPrimitive.ExportMarkdown asChild>
|
||||
<TooltipIconButton tooltip="Download">
|
||||
<TooltipIconButton tooltip="Download as Markdown">
|
||||
<DownloadIcon />
|
||||
</TooltipIconButton>
|
||||
</ActionBarPrimitive.ExportMarkdown>
|
||||
{isLast && (
|
||||
<ActionBarPrimitive.Reload asChild>
|
||||
<TooltipIconButton tooltip="Refresh">
|
||||
<TooltipIconButton tooltip="Regenerate response">
|
||||
<RefreshCwIcon />
|
||||
</TooltipIconButton>
|
||||
</ActionBarPrimitive.Reload>
|
||||
)}
|
||||
{isTransform && (
|
||||
<button
|
||||
type="button"
|
||||
{isQuickAssist && (
|
||||
<TooltipIconButton
|
||||
tooltip="Paste back into source app"
|
||||
onClick={() => {
|
||||
const text = aui.message().getCopyText();
|
||||
window.electronAPI?.replaceText(text);
|
||||
api?.replaceText(text);
|
||||
}}
|
||||
className="ml-1 inline-flex items-center gap-1.5 rounded-md bg-primary px-3 py-1.5 text-xs font-medium text-primary-foreground transition-colors hover:bg-primary/90"
|
||||
>
|
||||
<ClipboardPaste className="size-3.5" />
|
||||
Paste back
|
||||
</button>
|
||||
<ClipboardPaste />
|
||||
</TooltipIconButton>
|
||||
)}
|
||||
</ActionBarPrimitive.Root>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -216,7 +216,7 @@ export const ConnectorIndicator = forwardRef<ConnectorIndicatorHandle, Connector
|
|||
onPointerDownOutside={(e) => {
|
||||
if (pickerOpen) e.preventDefault();
|
||||
}}
|
||||
className="max-w-3xl w-[95vw] sm:w-full h-[75vh] sm:h-[85vh] flex flex-col p-0 gap-0 overflow-hidden border border-border ring-0 dark:ring-0 bg-muted dark:bg-muted text-foreground [&>button]:right-4 sm:[&>button]:right-12 [&>button]:top-6 sm:[&>button]:top-10 [&>button]:opacity-80 hover:[&>button]:opacity-100 [&>button_svg]:size-5 select-none"
|
||||
className="max-w-3xl w-[95vw] sm:w-full h-[75vh] sm:h-[85vh] flex flex-col p-0 gap-0 overflow-hidden border border-border ring-0 dark:ring-0 bg-muted dark:bg-muted text-foreground [&>button]:right-4 sm:[&>button]:right-12 [&>button]:top-6 sm:[&>button]:top-10 [&>button]:opacity-80 [&>button]:hover:opacity-100 [&>button]:hover:bg-foreground/10 [&>button>svg]:size-5 select-none"
|
||||
>
|
||||
<DialogTitle className="sr-only">Manage Connectors</DialogTitle>
|
||||
{/* YouTube Crawler View - shown when adding YouTube videos */}
|
||||
|
|
|
|||
|
|
@ -144,18 +144,14 @@ export const ConnectorConnectView: FC<ConnectorConnectViewProps> = ({
|
|||
type="button"
|
||||
onClick={handleFormSubmit}
|
||||
disabled={isSubmitting}
|
||||
className="text-xs sm:text-sm min-w-[140px] disabled:opacity-50 disabled:cursor-not-allowed disabled:pointer-events-none"
|
||||
className="relative text-xs sm:text-sm min-w-[140px] disabled:opacity-50 disabled:cursor-not-allowed disabled:pointer-events-none"
|
||||
>
|
||||
{isSubmitting ? (
|
||||
<>
|
||||
<Spinner size="sm" className="mr-2" />
|
||||
Connecting
|
||||
</>
|
||||
) : connectorType === "MCP_CONNECTOR" ? (
|
||||
"Connect"
|
||||
) : (
|
||||
`Connect ${getConnectorTypeDisplay(connectorType)}`
|
||||
)}
|
||||
<span className={isSubmitting ? "opacity-0" : ""}>
|
||||
{connectorType === "MCP_CONNECTOR"
|
||||
? "Connect"
|
||||
: `Connect ${getConnectorTypeDisplay(connectorType)}`}
|
||||
</span>
|
||||
{isSubmitting && <Spinner size="sm" className="absolute" />}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -369,16 +369,10 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
size="sm"
|
||||
onClick={handleDisconnectConfirm}
|
||||
disabled={isDisconnecting}
|
||||
className="text-xs sm:text-sm flex-1 sm:flex-initial h-10 sm:h-auto py-2 sm:py-2"
|
||||
className="relative text-xs sm:text-sm flex-1 sm:flex-initial h-10 sm:h-auto py-2 sm:py-2"
|
||||
>
|
||||
{isDisconnecting ? (
|
||||
<>
|
||||
<Spinner size="sm" className="mr-2" />
|
||||
Disconnecting
|
||||
</>
|
||||
) : (
|
||||
"Confirm Disconnect"
|
||||
)}
|
||||
<span className={isDisconnecting ? "opacity-0" : ""}>Confirm Disconnect</span>
|
||||
{isDisconnecting && <Spinner size="sm" className="absolute" />}
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
|
|
@ -415,16 +409,10 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
<Button
|
||||
onClick={onSave}
|
||||
disabled={isSaving || isDisconnecting}
|
||||
className="text-xs sm:text-sm flex-1 sm:flex-initial h-12 sm:h-auto py-3 sm:py-2"
|
||||
className="relative text-xs sm:text-sm flex-1 sm:flex-initial h-12 sm:h-auto py-3 sm:py-2"
|
||||
>
|
||||
{isSaving ? (
|
||||
<>
|
||||
<Spinner size="sm" className="mr-2" />
|
||||
Saving
|
||||
</>
|
||||
) : (
|
||||
"Save Changes"
|
||||
)}
|
||||
<span className={isSaving ? "opacity-0" : ""}>Save Changes</span>
|
||||
{isSaving && <Spinner size="sm" className="absolute" />}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import { Cable } from "lucide-react";
|
||||
import { Search, Unplug } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { getDocumentTypeLabel } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
|
@ -134,9 +134,17 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
|
|||
const hasActiveConnectors =
|
||||
filteredOAuthConnectorTypes.length > 0 || filteredNonOAuthConnectors.length > 0;
|
||||
|
||||
const hasFilteredResults = hasActiveConnectors || standaloneDocuments.length > 0;
|
||||
|
||||
return (
|
||||
<TabsContent value="active" className="m-0">
|
||||
{hasSources ? (
|
||||
{hasSources && !hasFilteredResults && searchQuery ? (
|
||||
<div className="flex flex-col items-center justify-center py-20 text-center">
|
||||
<Search className="size-8 text-muted-foreground mb-3" />
|
||||
<p className="text-sm text-muted-foreground">No connectors found</p>
|
||||
<p className="text-xs text-muted-foreground/60 mt-1">Try a different search term</p>
|
||||
</div>
|
||||
) : hasSources ? (
|
||||
<div className="space-y-6">
|
||||
{/* Active Connectors Section */}
|
||||
{hasActiveConnectors && (
|
||||
|
|
@ -302,7 +310,7 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
|
|||
) : (
|
||||
<div className="flex flex-col items-center justify-center py-20 text-center">
|
||||
<div className="flex h-16 w-16 items-center justify-center rounded-full bg-muted mb-4">
|
||||
<Cable className="size-8 text-muted-foreground" />
|
||||
<Unplug className="size-8 text-muted-foreground" />
|
||||
</div>
|
||||
<h4 className="text-lg font-semibold">No active sources</h4>
|
||||
<p className="text-sm text-muted-foreground mt-1 max-w-[280px]">
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
"use client";
|
||||
|
||||
import { Search } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { EnumConnectorName } from "@/contracts/enums/connector";
|
||||
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||
import { usePlatform } from "@/hooks/use-platform";
|
||||
import { isSelfHosted } from "@/lib/env-config";
|
||||
import { ConnectorCard } from "../components/connector-card";
|
||||
import {
|
||||
|
|
@ -74,9 +76,8 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
|
|||
onManage,
|
||||
onViewAccountsList,
|
||||
}) => {
|
||||
// Check if self-hosted mode (for showing self-hosted only connectors)
|
||||
const selfHosted = isSelfHosted();
|
||||
const isDesktop = typeof window !== "undefined" && !!window.electronAPI;
|
||||
const { isDesktop } = usePlatform();
|
||||
|
||||
const matchesSearch = (title: string, description: string) =>
|
||||
title.toLowerCase().includes(searchQuery.toLowerCase()) ||
|
||||
|
|
@ -287,6 +288,18 @@ export const AllConnectorsTab: FC<AllConnectorsTabProps> = ({
|
|||
moreIntegrationsOther.length > 0 ||
|
||||
moreIntegrationsCrawlers.length > 0;
|
||||
|
||||
const hasAnyResults = hasDocumentFileConnectors || hasMoreIntegrations;
|
||||
|
||||
if (!hasAnyResults && searchQuery) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center py-20 text-center">
|
||||
<Search className="size-8 text-muted-foreground mb-3" />
|
||||
<p className="text-sm text-muted-foreground">No connectors found</p>
|
||||
<p className="text-xs text-muted-foreground/60 mt-1">Try a different search term</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-8">
|
||||
{/* Document/Files Connectors */}
|
||||
|
|
|
|||
|
|
@ -173,9 +173,7 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
|||
<Plus className="size-3 text-primary" />
|
||||
)}
|
||||
</div>
|
||||
<span className="text-xs sm:text-sm font-medium">
|
||||
{isConnecting ? "Connecting" : buttonText}
|
||||
</span>
|
||||
<span className="text-xs sm:text-sm font-medium">{buttonText}</span>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -335,16 +335,10 @@ export const YouTubeCrawlerView: FC<YouTubeCrawlerViewProps> = ({ searchSpaceId,
|
|||
<Button
|
||||
onClick={handleSubmit}
|
||||
disabled={isSubmitting || isFetchingPlaylist || videoTags.length === 0}
|
||||
className="text-xs sm:text-sm min-w-[140px] disabled:opacity-50 disabled:cursor-not-allowed disabled:pointer-events-none"
|
||||
className="relative text-xs sm:text-sm min-w-[140px] disabled:opacity-50 disabled:cursor-not-allowed disabled:pointer-events-none"
|
||||
>
|
||||
{isSubmitting ? (
|
||||
<>
|
||||
<Spinner size="sm" className="mr-2" />
|
||||
{t("processing")}
|
||||
</>
|
||||
) : (
|
||||
t("submit")
|
||||
)}
|
||||
<span className={isSubmitting ? "opacity-0" : ""}>{t("submit")}</span>
|
||||
{isSubmitting && <Spinner size="sm" className="absolute" />}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -125,18 +125,16 @@ const DocumentUploadPopupContent: FC<{
|
|||
onPointerDownOutside={(e) => e.preventDefault()}
|
||||
onInteractOutside={(e) => e.preventDefault()}
|
||||
onEscapeKeyDown={(e) => e.preventDefault()}
|
||||
className="select-none max-w-2xl w-[95vw] sm:w-[640px] h-[min(440px,75dvh)] sm:h-[min(500px,80vh)] flex flex-col p-0 gap-0 overflow-hidden border border-border ring-0 bg-muted dark:bg-muted text-foreground [&>button]:right-3 sm:[&>button]:right-6 [&>button]:top-3 sm:[&>button]:top-5 [&>button]:opacity-80 hover:[&>button]:opacity-100 [&>button]:z-[100] [&>button_svg]:size-4 sm:[&>button_svg]:size-5"
|
||||
className="select-none max-w-2xl w-[95vw] sm:w-[640px] h-[min(440px,75dvh)] sm:h-[min(520px,80vh)] flex flex-col p-0 gap-0 overflow-hidden border border-border ring-0 bg-muted dark:bg-muted text-foreground [&>button]:right-3 sm:[&>button]:right-6 [&>button]:top-5 sm:[&>button]:top-8 [&>button]:opacity-80 [&>button]:hover:opacity-100 [&>button]:hover:bg-foreground/10 [&>button]:z-[100] [&>button>svg]:size-4 sm:[&>button>svg]:size-5"
|
||||
>
|
||||
<DialogTitle className="sr-only">Upload Document</DialogTitle>
|
||||
|
||||
<div className="flex-1 min-h-0 overflow-y-auto overscroll-contain">
|
||||
<div className="sticky top-0 z-20 bg-muted px-4 sm:px-6 pt-4 sm:pt-5 pb-10">
|
||||
<div className="sticky top-0 z-20 bg-muted px-4 sm:px-6 pt-6 sm:pt-8 pb-10">
|
||||
<div className="flex items-center gap-2 mb-1 pr-8 sm:pr-0">
|
||||
<h2 className="text-base sm:text-lg font-semibold tracking-tight">
|
||||
Upload Documents
|
||||
</h2>
|
||||
<h2 className="text-xl sm:text-3xl font-semibold tracking-tight">Upload Documents</h2>
|
||||
</div>
|
||||
<p className="text-xs sm:text-sm text-muted-foreground line-clamp-1">
|
||||
<p className="text-xs sm:text-base text-muted-foreground/80 line-clamp-1">
|
||||
Upload and sync your documents to your search space
|
||||
</p>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -3,10 +3,10 @@
|
|||
import type { ImageMessagePartComponent } from "@assistant-ui/react";
|
||||
import { cva, type VariantProps } from "class-variance-authority";
|
||||
import { ImageIcon, ImageOffIcon } from "lucide-react";
|
||||
import NextImage from "next/image";
|
||||
import { memo, type PropsWithChildren, useEffect, useRef, useState } from "react";
|
||||
import { createPortal } from "react-dom";
|
||||
import { cn } from "@/lib/utils";
|
||||
import NextImage from 'next/image';
|
||||
|
||||
const imageVariants = cva("aui-image-root relative overflow-hidden rounded-lg", {
|
||||
variants: {
|
||||
|
|
@ -88,23 +88,23 @@ function ImagePreview({
|
|||
<ImageOffIcon className="size-8 text-muted-foreground" />
|
||||
</div>
|
||||
) : isDataOrBlobUrl(src) ? (
|
||||
// biome-ignore lint/performance/noImgElement: data/blob URLs need plain img
|
||||
<img
|
||||
ref={imgRef}
|
||||
src={src}
|
||||
alt={alt}
|
||||
className={cn("block h-auto w-full object-contain", !loaded && "invisible", className)}
|
||||
onLoad={(e) => {
|
||||
if (typeof src === "string") setLoadedSrc(src);
|
||||
onLoad?.(e);
|
||||
}}
|
||||
onError={(e) => {
|
||||
if (typeof src === "string") setErrorSrc(src);
|
||||
onError?.(e);
|
||||
}}
|
||||
{...props}
|
||||
/>
|
||||
) : (
|
||||
// biome-ignore lint/performance/noImgElement: data/blob URLs need plain img
|
||||
<img
|
||||
ref={imgRef}
|
||||
src={src}
|
||||
alt={alt}
|
||||
className={cn("block h-auto w-full object-contain", !loaded && "invisible", className)}
|
||||
onLoad={(e) => {
|
||||
if (typeof src === "string") setLoadedSrc(src);
|
||||
onLoad?.(e);
|
||||
}}
|
||||
onError={(e) => {
|
||||
if (typeof src === "string") setErrorSrc(src);
|
||||
onError?.(e);
|
||||
}}
|
||||
{...props}
|
||||
/>
|
||||
) : (
|
||||
// biome-ignore lint/performance/noImgElement: intentional for dynamic external URLs
|
||||
// <img
|
||||
// ref={imgRef}
|
||||
|
|
@ -122,22 +122,22 @@ function ImagePreview({
|
|||
// {...props}
|
||||
// />
|
||||
<NextImage
|
||||
fill
|
||||
src={src || ""}
|
||||
alt={alt}
|
||||
sizes="(max-width: 768px) 100vw, (max-width: 1200px) 80vw, 60vw"
|
||||
className={cn("block object-contain", !loaded && "invisible", className)}
|
||||
onLoad={() => {
|
||||
if (typeof src === "string") setLoadedSrc(src);
|
||||
onLoad?.();
|
||||
}}
|
||||
onError={() => {
|
||||
if (typeof src === "string") setErrorSrc(src);
|
||||
onError?.();
|
||||
}}
|
||||
unoptimized={false}
|
||||
{...props}
|
||||
/>
|
||||
fill
|
||||
src={src || ""}
|
||||
alt={alt}
|
||||
sizes="(max-width: 768px) 100vw, (max-width: 1200px) 80vw, 60vw"
|
||||
className={cn("block object-contain", !loaded && "invisible", className)}
|
||||
onLoad={() => {
|
||||
if (typeof src === "string") setLoadedSrc(src);
|
||||
onLoad?.();
|
||||
}}
|
||||
onError={() => {
|
||||
if (typeof src === "string") setErrorSrc(src);
|
||||
onError?.();
|
||||
}}
|
||||
unoptimized={false}
|
||||
{...props}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
|
@ -162,8 +162,8 @@ type ImageZoomProps = PropsWithChildren<{
|
|||
alt?: string;
|
||||
}>;
|
||||
function isDataOrBlobUrl(src: string | undefined): boolean {
|
||||
if (!src || typeof src !== "string") return false;
|
||||
return src.startsWith("data:") || src.startsWith("blob:");
|
||||
if (!src || typeof src !== "string") return false;
|
||||
return src.startsWith("data:") || src.startsWith("blob:");
|
||||
}
|
||||
function ImageZoom({ src, alt = "Image preview", children }: ImageZoomProps) {
|
||||
const [isMounted, setIsMounted] = useState(false);
|
||||
|
|
@ -216,38 +216,38 @@ function ImageZoom({ src, alt = "Image preview", children }: ImageZoomProps) {
|
|||
>
|
||||
{/** biome-ignore lint/performance/noImgElement: <explanation> */}
|
||||
{isDataOrBlobUrl(src) ? (
|
||||
// biome-ignore lint/performance/noImgElement: data/blob URLs need plain img
|
||||
<img
|
||||
data-slot="image-zoom-content"
|
||||
src={src}
|
||||
alt={alt}
|
||||
className="aui-image-zoom-content fade-in zoom-in-95 max-h-[90vh] max-w-[90vw] animate-in object-contain duration-200"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleClose();
|
||||
}}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") {
|
||||
e.stopPropagation();
|
||||
handleClose();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
) : (
|
||||
// biome-ignore lint/performance/noImgElement: data/blob URLs need plain img
|
||||
<img
|
||||
data-slot="image-zoom-content"
|
||||
src={src}
|
||||
alt={alt}
|
||||
className="aui-image-zoom-content fade-in zoom-in-95 max-h-[90vh] max-w-[90vw] animate-in object-contain duration-200"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleClose();
|
||||
}}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") {
|
||||
e.stopPropagation();
|
||||
handleClose();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
) : (
|
||||
<NextImage
|
||||
data-slot="image-zoom-content"
|
||||
fill
|
||||
src={src}
|
||||
alt={alt}
|
||||
sizes="90vw"
|
||||
className="aui-image-zoom-content fade-in zoom-in-95 object-contain duration-200"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleClose();
|
||||
}}
|
||||
unoptimized={false}
|
||||
/>
|
||||
)}
|
||||
data-slot="image-zoom-content"
|
||||
fill
|
||||
src={src}
|
||||
alt={alt}
|
||||
sizes="90vw"
|
||||
className="aui-image-zoom-content fade-in zoom-in-95 object-contain duration-200"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleClose();
|
||||
}}
|
||||
unoptimized={false}
|
||||
/>
|
||||
)}
|
||||
</button>,
|
||||
document.body
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ export interface MentionedDocument {
|
|||
export interface InlineMentionEditorRef {
|
||||
focus: () => void;
|
||||
clear: () => void;
|
||||
setText: (text: string) => void;
|
||||
getText: () => string;
|
||||
getMentionedDocuments: () => MentionedDocument[];
|
||||
insertDocumentChip: (doc: Pick<Document, "id" | "title" | "document_type">) => void;
|
||||
|
|
@ -397,6 +398,19 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
}
|
||||
}, []);
|
||||
|
||||
// Replace editor content with plain text and place cursor at end
|
||||
const setText = useCallback(
|
||||
(text: string) => {
|
||||
if (!editorRef.current) return;
|
||||
editorRef.current.innerText = text;
|
||||
const empty = text.length === 0;
|
||||
setIsEmpty(empty);
|
||||
onChange?.(text, Array.from(mentionedDocs.values()));
|
||||
focusAtEnd();
|
||||
},
|
||||
[focusAtEnd, onChange, mentionedDocs]
|
||||
);
|
||||
|
||||
const setDocumentChipStatus = useCallback(
|
||||
(
|
||||
docId: number,
|
||||
|
|
@ -469,6 +483,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
|
|||
useImperativeHandle(ref, () => ({
|
||||
focus: () => editorRef.current?.focus(),
|
||||
clear,
|
||||
setText,
|
||||
getText,
|
||||
getMentionedDocuments,
|
||||
insertDocumentChip,
|
||||
|
|
|
|||
|
|
@ -241,9 +241,7 @@ const ThreadListItemComponent = memo(function ThreadListItemComponent({
|
|||
<MessageSquareIcon className="size-4 shrink-0 text-muted-foreground" />
|
||||
<div className="flex-1 min-w-0">
|
||||
<p className="truncate text-sm font-medium">{thread.title || "New Chat"}</p>
|
||||
<p className="truncate text-xs text-muted-foreground">
|
||||
{relativeTime}
|
||||
</p>
|
||||
<p className="truncate text-xs text-muted-foreground">{relativeTime}</p>
|
||||
</div>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
|
|
|
|||
|
|
@ -89,17 +89,10 @@ import type { Document } from "@/contracts/types/document.types";
|
|||
import { useBatchCommentsPreload } from "@/hooks/use-comments";
|
||||
import { useCommentsSync } from "@/hooks/use-comments-sync";
|
||||
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||
import { useElectronAPI } from "@/hooks/use-platform";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
/** Placeholder texts that cycle in new chats when input is empty */
|
||||
const CYCLING_PLACEHOLDERS = [
|
||||
"Ask SurfSense anything or @mention docs",
|
||||
"Generate a podcast from my vacation ideas in Notion",
|
||||
"Sum up last week's meeting notes from Drive in a bulleted list",
|
||||
"Give me a brief overview of the most urgent tickets in Jira and Linear",
|
||||
"Briefly, what are today's top ten important emails and calendar events?",
|
||||
"Check if this week's Slack messages reference any GitHub issues",
|
||||
];
|
||||
const COMPOSER_PLACEHOLDER = "Ask anything · Type / for prompts · Type @ to mention docs";
|
||||
|
||||
export const Thread: FC = () => {
|
||||
return <ThreadContent />;
|
||||
|
|
@ -362,45 +355,23 @@ const Composer: FC = () => {
|
|||
};
|
||||
}, []);
|
||||
|
||||
const electronAPI = useElectronAPI();
|
||||
const [clipboardInitialText, setClipboardInitialText] = useState<string | undefined>();
|
||||
const clipboardLoadedRef = useRef(false);
|
||||
useEffect(() => {
|
||||
if (!window.electronAPI || clipboardLoadedRef.current) return;
|
||||
if (!electronAPI || clipboardLoadedRef.current) return;
|
||||
clipboardLoadedRef.current = true;
|
||||
window.electronAPI.getQuickAskText().then((text) => {
|
||||
electronAPI.getQuickAskText().then((text) => {
|
||||
if (text) {
|
||||
setClipboardInitialText(text);
|
||||
setShowPromptPicker(true);
|
||||
}
|
||||
});
|
||||
}, []);
|
||||
}, [electronAPI]);
|
||||
|
||||
const isThreadEmpty = useAuiState(({ thread }) => thread.isEmpty);
|
||||
const isThreadRunning = useAuiState(({ thread }) => thread.isRunning);
|
||||
|
||||
// Cycling placeholder state - only cycles in new chats
|
||||
const [placeholderIndex, setPlaceholderIndex] = useState(0);
|
||||
|
||||
// Cycle through placeholders every 4 seconds when thread is empty (new chat)
|
||||
useEffect(() => {
|
||||
// Only cycle when thread is empty (new chat)
|
||||
if (!isThreadEmpty) {
|
||||
// Reset to first placeholder when chat becomes active
|
||||
setPlaceholderIndex(0);
|
||||
return;
|
||||
}
|
||||
|
||||
const intervalId = setInterval(() => {
|
||||
setPlaceholderIndex((prev) => (prev + 1) % CYCLING_PLACEHOLDERS.length);
|
||||
}, 6000);
|
||||
|
||||
return () => clearInterval(intervalId);
|
||||
}, [isThreadEmpty]);
|
||||
|
||||
// Compute current placeholder - only cycle in new chats
|
||||
const currentPlaceholder = isThreadEmpty
|
||||
? CYCLING_PLACEHOLDERS[placeholderIndex]
|
||||
: CYCLING_PLACEHOLDERS[0];
|
||||
const currentPlaceholder = COMPOSER_PLACEHOLDER;
|
||||
|
||||
// Live collaboration state
|
||||
const { data: currentUser } = useAtomValue(currentUserAtom);
|
||||
|
|
@ -504,34 +475,28 @@ const Composer: FC = () => {
|
|||
: userText
|
||||
? `${action.prompt}\n\n${userText}`
|
||||
: action.prompt;
|
||||
editorRef.current?.setText(finalPrompt);
|
||||
aui.composer().setText(finalPrompt);
|
||||
aui.composer().send();
|
||||
editorRef.current?.clear();
|
||||
setShowPromptPicker(false);
|
||||
setActionQuery("");
|
||||
setMentionedDocuments([]);
|
||||
setSidebarDocs([]);
|
||||
},
|
||||
[actionQuery, aui, setMentionedDocuments, setSidebarDocs]
|
||||
[actionQuery, aui]
|
||||
);
|
||||
|
||||
const handleQuickAskSelect = useCallback(
|
||||
(action: { name: string; prompt: string; mode: "transform" | "explore" }) => {
|
||||
if (!clipboardInitialText) return;
|
||||
window.electronAPI?.setQuickAskMode(action.mode);
|
||||
electronAPI?.setQuickAskMode(action.mode);
|
||||
const finalPrompt = action.prompt.includes("{selection}")
|
||||
? action.prompt.replace("{selection}", () => clipboardInitialText)
|
||||
: `${action.prompt}\n\n${clipboardInitialText}`;
|
||||
editorRef.current?.setText(finalPrompt);
|
||||
aui.composer().setText(finalPrompt);
|
||||
aui.composer().send();
|
||||
editorRef.current?.clear();
|
||||
setShowPromptPicker(false);
|
||||
setActionQuery("");
|
||||
setClipboardInitialText(undefined);
|
||||
setMentionedDocuments([]);
|
||||
setSidebarDocs([]);
|
||||
},
|
||||
[clipboardInitialText, aui, setMentionedDocuments, setSidebarDocs]
|
||||
[clipboardInitialText, electronAPI, aui]
|
||||
);
|
||||
|
||||
// Keyboard navigation for document/action picker (arrow keys, Enter, Escape)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,8 @@ export const ToolFallback: ToolCallMessagePartComponent = ({
|
|||
);
|
||||
|
||||
const serializedResult = useMemo(
|
||||
() => (result !== undefined && typeof result !== "string" ? JSON.stringify(result, null, 2) : null),
|
||||
() =>
|
||||
result !== undefined && typeof result !== "string" ? JSON.stringify(result, null, 2) : null,
|
||||
[result]
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import { ArrowUp, Send, X } from "lucide-react";
|
||||
import { ArrowUp } from "lucide-react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Popover, PopoverAnchor, PopoverContent } from "@/components/ui/popover";
|
||||
|
|
@ -307,7 +307,6 @@ export function CommentComposer({
|
|||
onClick={onCancel}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
<X className="mr-1 size-4" />
|
||||
Cancel
|
||||
</Button>
|
||||
)}
|
||||
|
|
@ -318,14 +317,7 @@ export function CommentComposer({
|
|||
disabled={!canSubmit}
|
||||
className={cn(!canSubmit && "opacity-50", compact && "size-8 shrink-0 rounded-full")}
|
||||
>
|
||||
{compact ? (
|
||||
<ArrowUp className="size-4" />
|
||||
) : (
|
||||
<>
|
||||
<Send className="mr-1 size-4" />
|
||||
{submitLabel}
|
||||
</>
|
||||
)}
|
||||
{compact ? <ArrowUp className="size-4" /> : submitLabel}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue