Merge upstream/dev

This commit is contained in:
CREDO23 2026-04-27 22:44:40 +02:00
commit 2d962f6dd2
107 changed files with 15033 additions and 2277 deletions

View file

@ -0,0 +1,106 @@
"""129_obsidian_plugin_vault_identity
Revision ID: 129
Revises: 128
Create Date: 2026-04-21
Locks down vault identity for the Obsidian plugin connector:
- Deactivates pre-plugin OBSIDIAN_CONNECTOR rows.
- Partial unique index on ``(user_id, (config->>'vault_id'))`` for the
``/obsidian/connect`` upsert fast path.
- Partial unique index on ``(user_id, (config->>'vault_fingerprint'))``
so two devices observing the same vault content can never produce
two connector rows. Collisions are caught by the route handler and
routed through the merge path.
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "129"
down_revision: str | None = "128"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Deactivate pre-plugin Obsidian connectors and create the plugin indexes.

    Three statements run in order on the migration connection:

    1. Every ``OBSIDIAN_CONNECTOR`` row whose config ``source`` is not
       ``plugin`` gets indexing switched off (``is_indexable``,
       ``periodic_indexing_enabled``, ``next_scheduled_at``) and has
       ``legacy`` / ``deactivated_at`` merged into its config.
    2. A partial unique index on ``(user_id, config->>'vault_id')``.
    3. A partial unique index on ``(user_id, config->>'vault_fingerprint')``.

    Both indexes only cover plugin-sourced Obsidian rows where the indexed
    key is present, and are created with ``IF NOT EXISTS``.
    """
    deactivate_legacy_sql = """
UPDATE search_source_connectors
SET
is_indexable = false,
periodic_indexing_enabled = false,
next_scheduled_at = NULL,
config = COALESCE(config, '{}'::json)::jsonb
|| jsonb_build_object(
'legacy', true,
'deactivated_at', to_char(
now() AT TIME ZONE 'UTC',
'YYYY-MM-DD"T"HH24:MI:SS"Z"'
)
)
WHERE connector_type = 'OBSIDIAN_CONNECTOR'
AND COALESCE((config::jsonb)->>'source', '') <> 'plugin'
"""
    vault_index_sql = """
CREATE UNIQUE INDEX IF NOT EXISTS
search_source_connectors_obsidian_plugin_vault_uniq
ON search_source_connectors (user_id, ((config->>'vault_id')))
WHERE connector_type = 'OBSIDIAN_CONNECTOR'
AND config->>'source' = 'plugin'
AND config->>'vault_id' IS NOT NULL
"""
    fingerprint_index_sql = """
CREATE UNIQUE INDEX IF NOT EXISTS
search_source_connectors_obsidian_plugin_fingerprint_uniq
ON search_source_connectors (user_id, ((config->>'vault_fingerprint')))
WHERE connector_type = 'OBSIDIAN_CONNECTOR'
AND config->>'source' = 'plugin'
AND config->>'vault_fingerprint' IS NOT NULL
"""
    bind = op.get_bind()
    # Order matters: legacy rows are flagged before the indexes are built.
    for statement in (deactivate_legacy_sql, vault_index_sql, fingerprint_index_sql):
        bind.execute(sa.text(statement))
def downgrade() -> None:
    """Drop both plugin unique indexes and strip the legacy markers from config.

    Only the ``legacy`` and ``deactivated_at`` config keys are removed; the
    indexing columns that ``upgrade`` switched off are not touched here.
    """
    bind = op.get_bind()
    for index_name in (
        "search_source_connectors_obsidian_plugin_fingerprint_uniq",
        "search_source_connectors_obsidian_plugin_vault_uniq",
    ):
        bind.execute(sa.text(f"DROP INDEX IF EXISTS {index_name}"))

    strip_legacy_sql = """
UPDATE search_source_connectors
SET config = (config::jsonb - 'legacy' - 'deactivated_at')::json
WHERE connector_type = 'OBSIDIAN_CONNECTOR'
AND (config::jsonb) ? 'legacy'
"""
    bind.execute(sa.text(strip_legacy_sql))

View file

@ -109,37 +109,6 @@ def _sanitize_path_segment(value: str) -> str:
return segment
def _infer_text_file_extension(user_text: str) -> str:
    """Guess a file extension from format names or extensions in *user_text*.

    The text is lower-cased and checked against an ordered table; the first
    row with a matching token wins. A token only matches when it is not
    embedded in a longer alphanumeric run, so ``ini`` matches
    ``settings.ini`` or ``an ini file`` but not ``initial`` or
    ``definition``. Dot-prefixed tokens (``.py``) may follow any character
    but must not be followed by a letter or digit.

    Args:
        user_text: Free-form user request text.

    Returns:
        The inferred extension including the leading dot; ``".md"`` when
        nothing matches.
    """
    import re  # function-scope import keeps the block self-contained

    lowered = user_text.lower()
    # Order is significant: earlier rows take precedence over later ones.
    keyword_table = (
        (".json", ("json",)),
        (".yaml", ("yaml", "yml")),
        (".csv", ("csv",)),
        (".py", ("python", ".py")),
        (".ts", ("typescript", ".ts", ".tsx")),
        (".js", ("javascript", ".js", ".mjs", ".cjs")),
        (".html", ("html",)),
        (".css", ("css",)),
        (".sql", ("sql",)),
        (".toml", ("toml",)),
        (".ini", ("ini",)),
        (".xml", ("xml",)),
        (".md", ("markdown", ".md", "readme")),
    )
    for extension, tokens in keyword_table:
        for token in tokens:
            # A leading "." is itself the left boundary (e.g. "script.py").
            lead = "" if token.startswith(".") else r"(?<![a-z0-9])"
            if re.search(f"{lead}{re.escape(token)}(?![a-z0-9])", lowered):
                return extension
    return ".md"
def _normalize_directory(value: str) -> str:
raw = value.strip().replace("\\", "/")
raw = raw.strip("/")
@ -193,7 +162,6 @@ def _fallback_path(
suggested_path: str | None = None,
user_text: str,
) -> str:
default_extension = _infer_text_file_extension(user_text)
inferred_dir = _infer_directory_from_user_text(user_text)
sanitized_filename = ""
@ -202,9 +170,9 @@ def _fallback_path(
if sanitized_filename.lower().endswith(".txt"):
sanitized_filename = f"{sanitized_filename[:-4]}.md"
if not sanitized_filename:
sanitized_filename = f"notes{default_extension}"
sanitized_filename = "notes.md"
elif "." not in sanitized_filename:
sanitized_filename = f"{sanitized_filename}{default_extension}"
sanitized_filename = f"{sanitized_filename}.md"
normalized_suggested_path = (
_normalize_file_path(suggested_path) if suggested_path else ""

View file

@ -7,6 +7,7 @@ This middleware customizes prompts and persists write/edit operations for
from __future__ import annotations
import asyncio
import json
import logging
import re
import secrets
@ -141,6 +142,31 @@ IMPORTANT:
content.
"""
SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION = """Moves or renames a file or folder.
Use absolute paths for both source and destination.
Notes:
- In local-folder mode, paths should use mount prefixes (e.g., /<mount>/foo.txt).
- Rename is a special case of move (same folder, different filename).
- Cross-mount moves are not supported.
"""
SURFSENSE_LIST_TREE_TOOL_DESCRIPTION = """Lists files/folders recursively in a single bounded call.
Use this in desktop local-folder mode to discover nested files at scale.
Args:
- path: absolute mount-prefixed path (e.g., /<mount>/src) or "/" for mount roots.
- max_depth: recursion depth limit (default 8).
- page_size: maximum number of entries returned (max 1000).
- include_files/include_dirs: filter returned entry types.
Returns JSON with:
- entries: [{path, is_dir, size, modified_at, depth}]
- truncated: true when additional entries were omitted due to page_size
"""
SURFSENSE_GLOB_TOOL_DESCRIPTION = """Find files matching a glob pattern.
Supports standard glob patterns: `*`, `**`, `?`.
@ -222,11 +248,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
)
if filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
system_prompt += (
"\n- move_file: move or rename files/folders in local-folder mode."
"\n- list_tree: recursively list nested local paths in one bounded response."
"\n\n## Local Folder Mode"
"\n\nThis chat is running in desktop local-folder mode."
" Keep all file operations local. Do not use save_document."
" Always use mount-prefixed absolute paths like /<folder>/file.ext."
" If you are unsure which mounts are available, call ls('/') first."
" For big trees: use list_tree, then grep, then read_file."
)
super().__init__(
@ -237,6 +266,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
"read_file": SURFSENSE_READ_FILE_TOOL_DESCRIPTION,
"write_file": SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION,
"edit_file": SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION,
"move_file": SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION,
"list_tree": SURFSENSE_LIST_TREE_TOOL_DESCRIPTION,
"glob": SURFSENSE_GLOB_TOOL_DESCRIPTION,
"grep": SURFSENSE_GREP_TOOL_DESCRIPTION,
},
@ -244,6 +275,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
max_execute_timeout=self._MAX_EXECUTE_TIMEOUT,
)
self.tools = [t for t in self.tools if t.name != "execute"]
if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
self.tools.append(self._create_move_file_tool())
self.tools.append(self._create_list_tree_tool())
if self._should_persist_documents():
self.tools.append(self._create_save_document_tool())
if self._sandbox_available:
@ -776,35 +810,97 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
"""Only cloud mode persists file content to Document/Chunk tables."""
return self._filesystem_mode == FilesystemMode.CLOUD
def _default_mount_prefix(self, runtime: ToolRuntime[None, FilesystemState]) -> str:
    """Return ``/<default mount>`` for a multi-root backend, else ``""``."""
    active_backend = self._get_backend(runtime)
    if not isinstance(active_backend, MultiRootLocalFolderBackend):
        return ""
    return f"/{active_backend.default_mount()}"
@staticmethod
def _normalize_absolute_path(candidate: str) -> str:
    """Turn *candidate* into a forward-slash path that starts with ``/``.

    Surrounding whitespace is trimmed, backslashes become ``/`` and runs of
    slashes collapse to one. Empty input yields ``"/"``. A trailing slash is
    kept as-is.
    """
    unified = candidate.strip().replace("\\", "/")
    collapsed = re.sub(r"/+", "/", unified)
    if collapsed == "":
        return "/"
    return collapsed if collapsed.startswith("/") else f"/{collapsed}"
@staticmethod
def _extract_mount_from_path(path: str, mounts: tuple[str, ...]) -> str | None:
    """Return the first path segment when it names one of *mounts*.

    Leading slashes are ignored. ``None`` is returned for an empty path or
    when the first segment is not a known mount.
    """
    trimmed = path.lstrip("/")
    if trimmed == "":
        return None
    head = trimmed.split("/", 1)[0]
    return head if head in mounts else None
@staticmethod
def _local_parent_path(path: str) -> str:
    """Return the parent of *path* as an absolute path; ``"/"`` at top level.

    Leading slashes are ignored before splitting, so a single-segment path
    (with or without a leading slash) has ``"/"`` as its parent.
    """
    trimmed = path.lstrip("/")
    head, separator, _leaf = trimmed.rpartition("/")
    if not separator:
        return "/"
    parent = head.strip("/")
    return f"/{parent}" if parent else "/"
@staticmethod
def _path_exists_under_mount(
    backend: MultiRootLocalFolderBackend,
    mount: str,
    local_path: str,
) -> bool:
    """Report whether ``/<mount><local_path>`` can be listed by *backend*.

    Issues a minimal ``list_tree`` probe (depth 0, one entry) and treats any
    truthy ``error`` in the response as "does not exist".
    """
    probe = backend.list_tree(
        f"/{mount}{local_path}",
        max_depth=0,
        page_size=1,
        include_files=True,
        include_dirs=True,
    )
    if probe.get("error"):
        return False
    return True
def _normalize_local_mount_path(
self, candidate: str, runtime: ToolRuntime[None, FilesystemState]
self,
candidate: str,
runtime: ToolRuntime[None, FilesystemState],
) -> str:
normalized = self._normalize_absolute_path(candidate)
backend = self._get_backend(runtime)
mount_prefix = self._default_mount_prefix(runtime)
normalized_candidate = re.sub(r"/+", "/", candidate.strip().replace("\\", "/"))
if not mount_prefix or not isinstance(backend, MultiRootLocalFolderBackend):
if normalized_candidate.startswith("/"):
return normalized_candidate
return f"/{normalized_candidate.lstrip('/')}"
if not isinstance(backend, MultiRootLocalFolderBackend):
return normalized
mount_names = set(backend.list_mounts())
if normalized_candidate.startswith("/"):
first_segment = normalized_candidate.lstrip("/").split("/", 1)[0]
if first_segment in mount_names:
return normalized_candidate
return f"{mount_prefix}{normalized_candidate}"
mounts = backend.list_mounts()
explicit_mount = self._extract_mount_from_path(normalized, mounts)
if explicit_mount:
return normalized
relative = normalized_candidate.lstrip("/")
first_segment = relative.split("/", 1)[0]
if first_segment in mount_names:
return f"/{relative}"
return f"{mount_prefix}/{relative}"
if len(mounts) == 1:
return f"/{mounts[0]}{normalized}"
suggested_mount: str | None = None
contract = runtime.state.get("file_operation_contract") or {}
suggested_path = contract.get("suggested_path")
if isinstance(suggested_path, str) and suggested_path.strip():
normalized_suggested = self._normalize_absolute_path(suggested_path)
suggested_mount = self._extract_mount_from_path(normalized_suggested, mounts)
matching_mounts = [
mount
for mount in mounts
if self._path_exists_under_mount(backend, mount, normalized)
]
if len(matching_mounts) == 1:
return f"/{matching_mounts[0]}{normalized}"
parent_path = self._local_parent_path(normalized)
if parent_path != "/":
parent_matching_mounts = [
mount
for mount in mounts
if self._path_exists_under_mount(backend, mount, parent_path)
]
if len(parent_matching_mounts) == 1:
return f"/{parent_matching_mounts[0]}{normalized}"
if suggested_mount:
return f"/{suggested_mount}{normalized}"
return f"/{backend.default_mount()}{normalized}"
def _get_contract_suggested_path(
self, runtime: ToolRuntime[None, FilesystemState]
@ -812,14 +908,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
contract = runtime.state.get("file_operation_contract") or {}
suggested = contract.get("suggested_path")
if isinstance(suggested, str) and suggested.strip():
cleaned = suggested.strip()
if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
return self._normalize_local_mount_path(cleaned, runtime)
return cleaned
if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
mount_prefix = self._default_mount_prefix(runtime)
if mount_prefix:
return f"{mount_prefix}/notes.md"
return self._normalize_absolute_path(suggested)
return "/notes.md"
def _resolve_write_target_path(
@ -836,6 +925,34 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
return f"/{candidate.lstrip('/')}"
return candidate
def _resolve_move_target_path(
    self,
    file_path: str,
    runtime: ToolRuntime[None, FilesystemState],
) -> str:
    """Resolve a move source/destination argument to an absolute path.

    Blank input returns ``""``. In desktop local-folder mode the path goes
    through ``_normalize_local_mount_path``; otherwise a leading ``/`` is
    added when missing.
    """
    trimmed = file_path.strip()
    if trimmed == "":
        return ""
    if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
        return self._normalize_local_mount_path(trimmed, runtime)
    return trimmed if trimmed.startswith("/") else f"/{trimmed}"
def _resolve_list_target_path(
    self,
    path: str,
    runtime: ToolRuntime[None, FilesystemState],
) -> str:
    """Resolve a listing path argument to an absolute path.

    Blank input and ``"/"`` both resolve to ``"/"``. In desktop local-folder
    mode other paths go through ``_normalize_local_mount_path``; otherwise a
    leading ``/`` is added when missing.
    """
    trimmed = path.strip()
    if trimmed in ("", "/"):
        return "/"
    if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
        return self._normalize_local_mount_path(trimmed, runtime)
    return trimmed if trimmed.startswith("/") else f"/{trimmed}"
@staticmethod
def _is_error_text(value: str) -> bool:
    """Return True when *value* begins with the ``Error:`` marker."""
    marker = "Error:"
    return value[: len(marker)] == marker
@ -930,6 +1047,246 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
)
return None, updated_content
def _create_move_file_tool(self) -> BaseTool:
    """Build the ``move_file`` tool (sync and async) for desktop local-folder mode.

    Both entry points validate the request the same way, delegate to the
    backend's ``move`` / ``amove`` and return either an error string, a
    plain confirmation string, or a ``Command`` carrying the file-state
    update plus a confirmation ``ToolMessage``.
    """
    tool_description = (
        self._custom_tool_descriptions.get("move_file")
        or SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION
    )

    def _validated_targets(source_path, destination_path, runtime):
        """Return ``(backend, source, destination)`` or an error string."""
        if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER:
            return "Error: move_file is only available in desktop local-folder mode."
        if not source_path.strip() or not destination_path.strip():
            return "Error: source_path and destination_path are required."

        backend = self._get_backend(runtime)
        source_target = self._resolve_move_target_path(source_path, runtime)
        destination_target = self._resolve_move_target_path(destination_path, runtime)
        try:
            checked_source = validate_path(source_target)
            checked_destination = validate_path(destination_target)
        except ValueError as exc:
            return f"Error: {exc}"
        return backend, checked_source, checked_destination

    def _to_tool_output(res, checked_source, checked_destination, runtime):
        """Translate a backend result into the tool's return value."""
        if res.error:
            return res.error
        summary = f"Moved '{checked_source}' to '{res.path or checked_destination}'"
        if res.files_update is None:
            return summary
        return Command(
            update={
                "files": res.files_update,
                "messages": [
                    ToolMessage(
                        content=summary,
                        tool_call_id=runtime.tool_call_id,
                    )
                ],
            }
        )

    def sync_move_file(
        source_path: Annotated[
            str,
            "Absolute source path to move from.",
        ],
        destination_path: Annotated[
            str,
            "Absolute destination path to move to.",
        ],
        runtime: ToolRuntime[None, FilesystemState],
        *,
        overwrite: Annotated[
            bool,
            "If True, replace an existing destination file. Defaults to False.",
        ] = False,
    ) -> Command | str:
        prepared = _validated_targets(source_path, destination_path, runtime)
        if isinstance(prepared, str):
            return prepared
        backend, checked_source, checked_destination = prepared
        res: WriteResult = backend.move(
            checked_source,
            checked_destination,
            overwrite=overwrite,
        )
        return _to_tool_output(res, checked_source, checked_destination, runtime)

    async def async_move_file(
        source_path: Annotated[
            str,
            "Absolute source path to move from.",
        ],
        destination_path: Annotated[
            str,
            "Absolute destination path to move to.",
        ],
        runtime: ToolRuntime[None, FilesystemState],
        *,
        overwrite: Annotated[
            bool,
            "If True, replace an existing destination file. Defaults to False.",
        ] = False,
    ) -> Command | str:
        prepared = _validated_targets(source_path, destination_path, runtime)
        if isinstance(prepared, str):
            return prepared
        backend, checked_source, checked_destination = prepared
        res: WriteResult = await backend.amove(
            checked_source,
            checked_destination,
            overwrite=overwrite,
        )
        return _to_tool_output(res, checked_source, checked_destination, runtime)

    return StructuredTool.from_function(
        name="move_file",
        description=tool_description,
        func=sync_move_file,
        coroutine=async_move_file,
    )
def _create_list_tree_tool(self) -> BaseTool:
    """Build the ``list_tree`` tool (sync and async) for desktop local-folder mode.

    Both entry points validate their arguments identically, delegate to the
    backend's ``list_tree`` / ``alist_tree`` and return either the backend's
    error string or the result serialized as ASCII-only JSON.
    """
    tool_description = (
        self._custom_tool_descriptions.get("list_tree")
        or SURFSENSE_LIST_TREE_TOOL_DESCRIPTION
    )

    def _checked_request(runtime, path, max_depth, page_size, include_files, include_dirs):
        """Return ``(backend, validated_path)`` or an error string."""
        if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER:
            return "Error: list_tree is only available in desktop local-folder mode."
        if max_depth < 0:
            return "Error: max_depth must be >= 0."
        if page_size < 1:
            return "Error: page_size must be >= 1."
        if not include_files and not include_dirs:
            return "Error: include_files and include_dirs cannot both be false."

        backend = self._get_backend(runtime)
        target_path = self._resolve_list_target_path(path, runtime)
        try:
            return backend, validate_path(target_path)
        except ValueError as exc:
            return f"Error: {exc}"

    def _render(result):
        """Return the backend's error text if present, else the JSON payload."""
        error = result.get("error") if isinstance(result, dict) else None
        if isinstance(error, str) and error:
            return error
        return json.dumps(result, ensure_ascii=True)

    def sync_list_tree(
        runtime: ToolRuntime[None, FilesystemState],
        *,
        path: Annotated[
            str,
            "Absolute path to list from. Use '/' for mount roots.",
        ] = "/",
        max_depth: Annotated[
            int,
            "Maximum recursion depth to traverse. Defaults to 8.",
        ] = 8,
        page_size: Annotated[
            int,
            "Maximum number of entries to return. Defaults to 500 (max 1000).",
        ] = 500,
        include_files: Annotated[
            bool,
            "Whether file entries should be included.",
        ] = True,
        include_dirs: Annotated[
            bool,
            "Whether directory entries should be included.",
        ] = True,
    ) -> str:
        prepared = _checked_request(
            runtime, path, max_depth, page_size, include_files, include_dirs
        )
        if isinstance(prepared, str):
            return prepared
        backend, validated_path = prepared
        result = backend.list_tree(
            validated_path,
            max_depth=max_depth,
            page_size=page_size,
            include_files=include_files,
            include_dirs=include_dirs,
        )
        return _render(result)

    async def async_list_tree(
        runtime: ToolRuntime[None, FilesystemState],
        *,
        path: Annotated[
            str,
            "Absolute path to list from. Use '/' for mount roots.",
        ] = "/",
        max_depth: Annotated[
            int,
            "Maximum recursion depth to traverse. Defaults to 8.",
        ] = 8,
        page_size: Annotated[
            int,
            "Maximum number of entries to return. Defaults to 500 (max 1000).",
        ] = 500,
        include_files: Annotated[
            bool,
            "Whether file entries should be included.",
        ] = True,
        include_dirs: Annotated[
            bool,
            "Whether directory entries should be included.",
        ] = True,
    ) -> str:
        prepared = _checked_request(
            runtime, path, max_depth, page_size, include_files, include_dirs
        )
        if isinstance(prepared, str):
            return prepared
        backend, validated_path = prepared
        result = await backend.alist_tree(
            validated_path,
            max_depth=max_depth,
            page_size=page_size,
            include_files=include_files,
            include_dirs=include_dirs,
        )
        return _render(result)

    return StructuredTool.from_function(
        name="list_tree",
        description=tool_description,
        func=sync_list_tree,
        coroutine=async_list_tree,
    )
def _create_edit_file_tool(self) -> BaseTool:
"""Create edit_file with DB persistence (skipped for KB documents)."""
tool_description = (

View file

@ -6,7 +6,10 @@ import asyncio
import fnmatch
import os
import threading
from collections import deque
from contextlib import ExitStack
from pathlib import Path
from typing import Any
from deepagents.backends.protocol import (
EditResult,
@ -71,6 +74,44 @@ class LocalFolderBackend:
temp_path.write_text(content, encoding="utf-8")
os.replace(temp_path, path)
def _acquire_path_locks(self, *paths: str) -> ExitStack:
    """Acquire the per-path locks for *paths* and return them as one stack.

    Paths are de-duplicated and locked in sorted order so concurrent callers
    always take overlapping locks in the same sequence. The caller releases
    everything by using the returned stack as a context manager (or by
    calling ``close()``).

    If obtaining or entering any lock raises, the locks already acquired are
    released before the exception propagates, so a failure cannot leave
    earlier paths locked.
    """
    with ExitStack() as stack:
        for path in sorted(set(paths)):
            stack.enter_context(self._lock_for(path))
        # Success: hand ownership of the acquired locks to the caller.
        return stack.pop_all()
@staticmethod
def _clamp_page_size(page_size: int) -> int:
    """Clamp *page_size* into the inclusive range 1..1000."""
    upper_bounded = min(page_size, 1000)
    return max(1, upper_bounded)
def _read_dir_entries(self, directory_path: str) -> list[dict[str, Any]]:
    """List the direct children of *directory_path* as plain dicts.

    Children are ordered directories first, then case-insensitively by
    name. A directory that cannot be iterated yields ``[]``; a child whose
    ``stat`` fails is left out. Each dict carries ``path`` (virtual path
    relative to the backend root), ``is_dir``, ``size`` (0 for non-files),
    ``modified_at`` (mtime as a string) and ``absolute_path``.
    """

    def _dirs_first(candidate: Path) -> tuple[bool, str]:
        return (not candidate.is_dir(), candidate.name.lower())

    try:
        children = sorted(Path(directory_path).iterdir(), key=_dirs_first)
    except OSError:
        return []

    listing: list[dict[str, Any]] = []
    for child in children:
        try:
            child_stat = child.stat()
        except OSError:
            # Skip entries that cannot be stat'ed instead of failing the listing.
            continue
        row = {
            "path": self._to_virtual(child, self._root),
            "is_dir": child.is_dir(),
            "size": child_stat.st_size if child.is_file() else 0,
            "modified_at": str(child_stat.st_mtime),
            "absolute_path": str(child),
        }
        listing.append(row)
    return listing
def ls_info(self, path: str) -> list[FileInfo]:
try:
target = self._resolve_virtual(path, allow_root=True)
@ -139,12 +180,178 @@ class LocalFolderBackend:
"Read and then make an edit, or write to a new path."
)
)
parent = path.parent
if not parent.exists() or not parent.is_dir():
return WriteResult(
error=(
f"Error: parent directory for '{file_path}' does not exist. "
"Create the folder first or write to an existing directory."
)
)
self._write_text_atomic(path, content)
return WriteResult(path=file_path, files_update=None)
async def awrite(self, file_path: str, content: str) -> WriteResult:
    """Run ``write`` in a worker thread and return its result."""
    outcome = await asyncio.to_thread(self.write, file_path, content)
    return outcome
def list_tree(
    self,
    path: str = "/",
    *,
    max_depth: int | None = 8,
    page_size: int = 500,
    include_files: bool = True,
    include_dirs: bool = True,
) -> dict[str, Any]:
    """List entries beneath *path* breadth-first in one bounded response.

    Args:
        path: Virtual path to start from; the backend root is allowed.
        max_depth: Deepest level to report, where direct children of *path*
            are depth 1. Negative values are treated as 0; ``None`` means
            unlimited.
        page_size: Maximum number of entries, clamped by
            ``_clamp_page_size``.
        include_files: Whether file entries are reported.
        include_dirs: Whether directory entries are reported.

    Returns:
        ``{"entries": [...], "truncated": bool}`` on success, where each
        entry has ``path``, ``is_dir``, ``size``, ``modified_at`` and
        ``depth``. ``truncated`` is True only when at least one further
        reportable entry was found after the page was full. When *path* is
        invalid or missing, ``{"error": "Error: ..."}`` is returned. When
        *path* is a file, it is returned as a single depth-0 entry.
    """
    if not include_files and not include_dirs:
        return {"entries": [], "truncated": False}

    normalized_depth = None if max_depth is None else max(0, int(max_depth))
    page_limit = self._clamp_page_size(int(page_size))

    try:
        start = self._resolve_virtual(path, allow_root=True)
    except ValueError:
        return {"error": f"Error: invalid path '{path}'"}
    if not start.exists():
        return {"error": f"Error: path '{path}' not found"}

    if start.is_file():
        if not include_files:
            return {"entries": [], "truncated": False}
        stat_result = start.stat()
        return {
            "entries": [
                {
                    "path": self._to_virtual(start, self._root),
                    "is_dir": False,
                    "size": stat_result.st_size,
                    "modified_at": str(stat_result.st_mtime),
                    "depth": 0,
                }
            ],
            "truncated": False,
        }

    entries: list[dict[str, Any]] = []
    truncated = False
    pending_dirs: deque[tuple[str, int]] = deque()
    # A directory is only read when its children fall within the depth
    # limit, so a depth of 0 performs no directory reads at all.
    if normalized_depth is None or normalized_depth > 0:
        pending_dirs.append((str(start), 0))

    while pending_dirs and not truncated:
        next_dir_path, next_depth = pending_dirs.popleft()
        item_depth = next_depth + 1
        descend = normalized_depth is None or item_depth < normalized_depth
        for item in self._read_dir_entries(next_dir_path):
            is_dir = item["is_dir"]
            if is_dir and descend:
                pending_dirs.append((item["absolute_path"], item_depth))
            if not (include_dirs if is_dir else include_files):
                continue
            if len(entries) >= page_limit:
                # The page is full and this entry would have been reported,
                # so something is genuinely being omitted.
                truncated = True
                break
            entries.append(
                {
                    "path": item["path"],
                    "is_dir": is_dir,
                    "size": 0 if is_dir else item["size"],
                    "modified_at": item["modified_at"],
                    "depth": item_depth,
                }
            )

    return {"entries": entries, "truncated": truncated}
async def alist_tree(
    self,
    path: str = "/",
    *,
    max_depth: int | None = 8,
    page_size: int = 500,
    include_files: bool = True,
    include_dirs: bool = True,
) -> dict[str, Any]:
    """Async variant of ``list_tree``; runs the blocking call in a worker thread."""
    options = {
        "max_depth": max_depth,
        "page_size": page_size,
        "include_files": include_files,
        "include_dirs": include_dirs,
    }
    return await asyncio.to_thread(self.list_tree, path, **options)
def move(
    self,
    source_path: str,
    destination_path: str,
    overwrite: bool = False,
) -> WriteResult:
    """Move or rename *source_path* to *destination_path* inside this root.

    Failures are reported through ``WriteResult.error`` rather than raised:
    unresolvable paths, identical source and destination, a missing source,
    an existing destination without ``overwrite``, an overwrite involving a
    directory, or an ``OSError`` from the underlying rename. Missing parent
    directories of the destination are created. On success the result
    carries the destination's virtual path.
    """
    try:
        source = self._resolve_virtual(source_path)
        destination = self._resolve_virtual(destination_path)
    except ValueError:
        return WriteResult(
            error=(
                f"Error: invalid source '{source_path}' or destination "
                f"'{destination_path}' path"
            )
        )
    if source == destination:
        return WriteResult(error="Error: source and destination paths are the same")

    with self._acquire_path_locks(source_path, destination_path):
        if not source.exists():
            return WriteResult(error=f"Error: source path '{source_path}' not found")

        if destination.exists():
            if not overwrite:
                return WriteResult(
                    error=(
                        f"Error: destination path '{destination_path}' already exists. "
                        "Set overwrite=True to replace files."
                    )
                )
            # Replacing an existing target is limited to plain files.
            if source.is_dir() or destination.is_dir():
                return WriteResult(
                    error=(
                        "Error: overwrite=True is only supported for file-to-file moves."
                    )
                )

        destination.parent.mkdir(parents=True, exist_ok=True)
        try:
            if overwrite:
                os.replace(source, destination)
            else:
                source.rename(destination)
        except OSError as exc:
            return WriteResult(error=f"Error: failed to move '{source_path}': {exc}")

        moved_to = self._to_virtual(destination, self._root)
        return WriteResult(path=moved_to, files_update=None)
async def amove(
    self,
    source_path: str,
    destination_path: str,
    overwrite: bool = False,
) -> WriteResult:
    """Async variant of ``move``; runs the blocking call in a worker thread."""
    call_args = (source_path, destination_path, overwrite)
    return await asyncio.to_thread(self.move, *call_args)
def edit(
self,
file_path: str,

View file

@ -132,6 +132,82 @@ class MultiRootLocalFolderBackend:
async def als_info(self, path: str) -> list[FileInfo]:
return await asyncio.to_thread(self.ls_info, path)
def list_tree(
    self,
    path: str = "/",
    *,
    max_depth: int | None = 8,
    page_size: int = 500,
    include_files: bool = True,
    include_dirs: bool = True,
) -> dict[str, Any]:
    """List a tree across mounts, or the mounts themselves for ``"/"``.

    For ``"/"`` every mount is reported as a depth-0 directory entry (or no
    entries when ``include_dirs`` is false); the response is never marked
    truncated. For any other path the call is delegated to the owning
    mount's backend and each returned path is re-prefixed with the mount
    name. A path that cannot be split into mount and local part, or an
    error from the delegate, is returned as an ``{"error": ...}`` dict.
    """
    if path == "/":
        mount_rows = (
            [
                {
                    "path": f"/{mount}",
                    "is_dir": True,
                    "size": 0,
                    "modified_at": "0",
                    "depth": 0,
                }
                for mount in self._mount_order
            ]
            if include_dirs
            else []
        )
        return {"entries": mount_rows, "truncated": False}

    try:
        mount, local_path = self._split_mount_path(path)
    except ValueError as exc:
        return {"error": f"Error: {exc}"}

    delegated = self._mount_to_backend[mount].list_tree(
        local_path,
        max_depth=max_depth,
        page_size=page_size,
        include_files=include_files,
        include_dirs=include_dirs,
    )
    if delegated.get("error"):
        return delegated

    prefixed = [
        {
            "path": self._prefix_mount_path(mount, self._get_str(row, "path")),
            "is_dir": self._get_bool(row, "is_dir"),
            "size": self._get_int(row, "size"),
            "modified_at": self._get_str(row, "modified_at"),
            "depth": self._get_int(row, "depth"),
        }
        for row in delegated.get("entries", [])
    ]
    return {
        "entries": prefixed,
        "truncated": self._get_bool(delegated, "truncated"),
    }
async def alist_tree(
    self,
    path: str = "/",
    *,
    max_depth: int | None = 8,
    page_size: int = 500,
    include_files: bool = True,
    include_dirs: bool = True,
) -> dict[str, Any]:
    """Async variant of ``list_tree``; runs the blocking call in a worker thread."""
    options = {
        "max_depth": max_depth,
        "page_size": page_size,
        "include_files": include_files,
        "include_dirs": include_dirs,
    }
    return await asyncio.to_thread(self.list_tree, path, **options)
def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:
try:
mount, local_path = self._split_mount_path(file_path)
@ -165,6 +241,48 @@ class MultiRootLocalFolderBackend:
async def awrite(self, file_path: str, content: str) -> WriteResult:
    """Run ``write`` in a worker thread and return its result."""
    outcome = await asyncio.to_thread(self.write, file_path, content)
    return outcome
def move(
    self,
    source_path: str,
    destination_path: str,
    overwrite: bool = False,
) -> WriteResult:
    """Move a path within a single mount by delegating to that mount's backend.

    Both paths are split into ``(mount, local_path)``. A path that cannot be
    split, or a source and destination on different mounts, produces a
    ``WriteResult`` with ``error`` set. Otherwise the mount backend performs
    the move and any returned path is re-prefixed with the mount name.
    """
    try:
        src_mount, src_local = self._split_mount_path(source_path)
        dst_mount, dst_local = self._split_mount_path(destination_path)
    except ValueError as exc:
        return WriteResult(error=f"Error: {exc}")

    same_mount = src_mount == dst_mount
    if not same_mount:
        return WriteResult(
            error=(
                "Error: cross-mount moves are not supported. "
                "Source and destination must be under the same mounted root."
            )
        )

    mount_backend = self._mount_to_backend[src_mount]
    outcome = mount_backend.move(src_local, dst_local, overwrite=overwrite)
    if outcome.path:
        # The delegate reports a mount-local path; expose the mount-prefixed one.
        outcome.path = self._prefix_mount_path(src_mount, outcome.path)
    return outcome
async def amove(
    self,
    source_path: str,
    destination_path: str,
    overwrite: bool = False,
) -> WriteResult:
    """Async variant of ``move``; runs the blocking call in a worker thread."""
    call_args = (source_path, destination_path, overwrite)
    return await asyncio.to_thread(self.move, *call_args)
def edit(
self,
file_path: str,

View file

@ -90,6 +90,7 @@ celery_app = Celery(
"app.tasks.celery_tasks.podcast_tasks",
"app.tasks.celery_tasks.video_presentation_tasks",
"app.tasks.celery_tasks.connector_tasks",
"app.tasks.celery_tasks.obsidian_tasks",
"app.tasks.celery_tasks.schedule_checker_task",
"app.tasks.celery_tasks.document_reindex_tasks",
"app.tasks.celery_tasks.stale_notification_cleanup_task",
@ -144,8 +145,8 @@ celery_app.conf.update(
"index_elasticsearch_documents": {"queue": CONNECTORS_QUEUE},
"index_crawled_urls": {"queue": CONNECTORS_QUEUE},
"index_bookstack_pages": {"queue": CONNECTORS_QUEUE},
"index_obsidian_vault": {"queue": CONNECTORS_QUEUE},
"index_composio_connector": {"queue": CONNECTORS_QUEUE},
"index_obsidian_attachment": {"queue": CONNECTORS_QUEUE},
# Everything else (document processing, podcasts, reindexing,
# schedule checker, cleanup) stays on the default fast queue.
},

View file

@ -1510,6 +1510,31 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
"name",
name="uq_searchspace_user_connector_type_name",
),
# Mirrors migration 129; backs the ``/obsidian/connect`` upsert.
Index(
"search_source_connectors_obsidian_plugin_vault_uniq",
"user_id",
text("(config->>'vault_id')"),
unique=True,
postgresql_where=text(
"connector_type = 'OBSIDIAN_CONNECTOR' "
"AND config->>'source' = 'plugin' "
"AND config->>'vault_id' IS NOT NULL"
),
),
# Cross-device dedup: same vault content from different devices
# cannot produce two connector rows.
Index(
"search_source_connectors_obsidian_plugin_fingerprint_uniq",
"user_id",
text("(config->>'vault_fingerprint')"),
unique=True,
postgresql_where=text(
"connector_type = 'OBSIDIAN_CONNECTOR' "
"AND config->>'source' = 'plugin' "
"AND config->>'vault_fingerprint' IS NOT NULL"
),
),
)
name = Column(String(100), nullable=False, index=True)

View file

@ -37,6 +37,7 @@ from .new_llm_config_routes import router as new_llm_config_router
from .notes_routes import router as notes_router
from .notifications_routes import router as notifications_router
from .notion_add_connector_route import router as notion_add_connector_router
from .obsidian_plugin_routes import router as obsidian_plugin_router
from .onedrive_add_connector_route import router as onedrive_add_connector_router
from .podcasts_routes import router as podcasts_router
from .prompts_routes import router as prompts_router
@ -84,6 +85,7 @@ router.include_router(notion_add_connector_router)
router.include_router(slack_add_connector_router)
router.include_router(teams_add_connector_router)
router.include_router(onedrive_add_connector_router)
router.include_router(obsidian_plugin_router) # Obsidian plugin push API
router.include_router(discord_add_connector_router)
router.include_router(jira_add_connector_router)
router.include_router(confluence_add_connector_router)

View file

@ -0,0 +1,706 @@
"""Obsidian plugin ingestion routes (``/api/v1/obsidian/*``).
Wire surface for the ``surfsense_obsidian/`` plugin. Versioning anchor is
the ``/api/v1/`` URL prefix; additive feature detection rides the
``capabilities`` array on /health and /connect.
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import and_, case, func
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import (
Document,
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
SearchSpace,
User,
get_async_session,
)
from app.schemas.obsidian_plugin import (
ALLOWED_ATTACHMENT_EXTENSIONS,
ATTACHMENT_MIME_TYPES,
ConnectRequest,
ConnectResponse,
DeleteAck,
DeleteAckItem,
DeleteBatchRequest,
HealthResponse,
ManifestResponse,
RenameAck,
RenameAckItem,
RenameBatchRequest,
StatsResponse,
SyncAck,
SyncAckItem,
SyncBatchRequest,
)
from app.services.notification_service import NotificationService
from app.services.obsidian_plugin_indexer import (
delete_note,
get_manifest,
merge_obsidian_connectors,
rename_note,
upsert_note,
)
from app.tasks.celery_tasks.obsidian_tasks import index_obsidian_attachment_task
from app.users import current_active_user
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Every route declared below is mounted under the ``/obsidian`` prefix.
router = APIRouter(prefix="/obsidian", tags=["obsidian-plugin"])
# Plugins feature-gate on these. Add entries, never rename or remove.
OBSIDIAN_CAPABILITIES: list[str] = ["sync", "rename", "delete", "manifest", "stats"]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _build_handshake() -> dict[str, object]:
    """Return the handshake fields spread into the health and connect responses.

    The capability list is copied, so mutating the returned value never
    touches the module-level ``OBSIDIAN_CAPABILITIES`` constant.
    """
    capabilities = [*OBSIDIAN_CAPABILITIES]
    return {"capabilities": capabilities}
def _connector_type_value(connector: SearchSourceConnector) -> str:
    """Return the connector's type as a plain string.

    Enum-like values (anything exposing ``.value``) are unwrapped first;
    every other value is stringified as-is.
    """
    raw_type = connector.connector_type
    return str(getattr(raw_type, "value", raw_type))
async def _start_obsidian_sync_notification(
    session: AsyncSession,
    *,
    user: User,
    connector: SearchSourceConnector,
    total_count: int,
):
    """Open (or reuse) the rolling inbox notification for a plugin sync batch.

    Plugin sync is continuous and batched, so one notification keyed by a
    stable per-connector ``operation_id`` is reused instead of creating a
    new inbox item for every batch.

    Returns the result of ``update_notification`` after the item has been
    switched to ``in_progress`` with the current batch size.
    """
    indexing_notifications = NotificationService.connector_indexing
    display_name = connector.name or "Obsidian"
    seed_metadata = {
        "connector_id": connector.id,
        "connector_name": display_name,
        "connector_type": _connector_type_value(connector),
        "sync_stage": "processing",
        "indexed_count": 0,
        "failed_count": 0,
        "total_count": total_count,
        "source": "obsidian_plugin",
    }
    inbox_item = await indexing_notifications.find_or_create_notification(
        session=session,
        user_id=user.id,
        operation_id=f"obsidian_sync_connector_{connector.id}",
        title=f"Syncing: {display_name}",
        message="Syncing from Obsidian plugin",
        search_space_id=connector.search_space_id,
        initial_metadata=seed_metadata,
    )
    # A reused notification may carry state from an earlier batch, so the
    # stage and batch size are always refreshed.
    progress_fields = {"sync_stage": "processing", "total_count": total_count}
    return await indexing_notifications.update_notification(
        session=session,
        notification=inbox_item,
        status="in_progress",
        metadata_updates=progress_fields,
    )
async def _finish_obsidian_sync_notification(
    session: AsyncSession,
    *,
    notification,
    indexed: int,
    failed: int,
):
    """Close out the rolling Obsidian sync notification for one batch.

    The batch is reported as failed only when nothing was indexed and at
    least one file failed. Partial success and no-op batches are both
    reported as completed, with a message describing the outcome.
    """
    connector_name = notification.notification_metadata.get(
        "connector_name", "Obsidian"
    )
    nothing_succeeded = failed > 0 and indexed == 0
    if nothing_succeeded:
        outcome = "failed"
        title = f"Failed: {connector_name}"
        if failed > 1:
            message = f"Sync failed: {failed} file(s) failed"
        else:
            message = "Sync failed: 1 file failed"
    else:
        outcome = "completed"
        title = f"Ready: {connector_name}"
        if failed > 0:
            message = f"Partially synced: {indexed} file(s) synced, {failed} failed."
        elif indexed == 0:
            message = "Already up to date!"
        elif indexed == 1:
            message = "Now searchable! 1 file synced."
        else:
            message = f"Now searchable! {indexed} files synced."

    # The notification status and the ``sync_stage`` metadata always carry
    # the same value.
    final_metadata = {
        "indexed_count": indexed,
        "failed_count": failed,
        "sync_stage": outcome,
    }
    await NotificationService.connector_indexing.update_notification(
        session=session,
        notification=notification,
        title=title,
        message=message,
        status=outcome,
        metadata_updates=final_metadata,
    )
async def _resolve_vault_connector(
    session: AsyncSession,
    *,
    user: User,
    vault_id: str,
) -> SearchSourceConnector:
    """Return the plugin connector row that owns ``vault_id`` for this user.

    The lookup is delegated to ``_find_by_vault_id`` so the "which row owns
    this vault" predicate (user, OBSIDIAN_CONNECTOR type, ``source ==
    'plugin'``, matching ``vault_id``) is defined in exactly one place and
    cannot drift between ``/connect`` and the other endpoints.

    Raises:
        HTTPException: 404 with code ``VAULT_NOT_REGISTERED`` when no
            matching connector exists.
    """
    connector = await _find_by_vault_id(session, user_id=user.id, vault_id=vault_id)
    if connector is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={
                "code": "VAULT_NOT_REGISTERED",
                "message": (
                    "No Obsidian plugin connector found for this vault. "
                    "Call POST /obsidian/connect first."
                ),
                "vault_id": vault_id,
            },
        )
    return connector
def _queue_obsidian_attachment(
    *, connector_id: int, note_payload: dict, user_id: str
) -> None:
    """Hand one non-markdown Obsidian note to the background ETL/indexing task."""
    task_kwargs = {
        "connector_id": connector_id,
        "payload_data": note_payload,
        "user_id": user_id,
    }
    index_obsidian_attachment_task.delay(**task_kwargs)
async def _ensure_search_space_access(
    session: AsyncSession,
    *,
    user: User,
    search_space_id: int,
) -> SearchSpace:
    """Return the search space only if ``user`` owns it.

    Access is owner-only; shared spaces are a follow-up.

    Raises:
        HTTPException: 403 with code ``SEARCH_SPACE_FORBIDDEN`` when no
            search space with this id is owned by the user.
    """
    ownership_query = select(SearchSpace).where(
        SearchSpace.id == search_space_id,
        SearchSpace.user_id == user.id,
    )
    owned_space = (await session.execute(ownership_query)).scalars().first()
    if owned_space is not None:
        return owned_space
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail={
            "code": "SEARCH_SPACE_FORBIDDEN",
            "message": "You don't own that search space.",
        },
    )
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.get("/health", response_model=HealthResponse)
async def obsidian_health(
    user: User = Depends(current_active_user),
) -> HealthResponse:
    """Serve the API contract handshake; the plugin caches it per onload.

    ``user`` is not read in the body; it is declared so the route depends
    on ``current_active_user``.
    """
    handshake = _build_handshake()
    return HealthResponse(server_time_utc=datetime.now(UTC), **handshake)
async def _find_by_vault_id(
    session: AsyncSession, *, user_id, vault_id: str
) -> SearchSourceConnector | None:
    """Return the user's plugin connector whose config carries ``vault_id``.

    Only OBSIDIAN_CONNECTOR rows with ``config['source'] == 'plugin'`` are
    considered. Returns ``None`` when nothing matches.
    """
    config = SearchSourceConnector.config
    query = select(SearchSourceConnector).where(
        SearchSourceConnector.user_id == user_id,
        SearchSourceConnector.connector_type
        == SearchSourceConnectorType.OBSIDIAN_CONNECTOR,
        config["source"].as_string() == "plugin",
        config["vault_id"].as_string() == vault_id,
    )
    result = await session.execute(query)
    return result.scalars().first()
async def _find_by_fingerprint(
    session: AsyncSession, *, user_id, vault_fingerprint: str
) -> SearchSourceConnector | None:
    """Return the user's plugin connector whose config carries this fingerprint.

    Only OBSIDIAN_CONNECTOR rows with ``config['source'] == 'plugin'`` are
    considered. Returns ``None`` when nothing matches.
    """
    config = SearchSourceConnector.config
    query = select(SearchSourceConnector).where(
        SearchSourceConnector.user_id == user_id,
        SearchSourceConnector.connector_type
        == SearchSourceConnectorType.OBSIDIAN_CONNECTOR,
        config["source"].as_string() == "plugin",
        config["vault_fingerprint"].as_string() == vault_fingerprint,
    )
    result = await session.execute(query)
    return result.scalars().first()
def _build_config(payload: ConnectRequest, *, now_iso: str) -> dict[str, object]:
    """Build the connector ``config`` dict stored for a plugin-registered vault.

    Copies the vault identity fields from ``payload``, tags the row as
    plugin-sourced, and stamps ``last_connect_at`` with ``now_iso``.
    """
    config: dict[str, object] = {
        field: getattr(payload, field)
        for field in ("vault_id", "vault_name", "vault_fingerprint")
    }
    config["source"] = "plugin"
    config["last_connect_at"] = now_iso
    return config
def _display_name(vault_name: str) -> str:
    """Return the connector display name for a vault: ``Obsidian - <vault_name>``."""
    return "Obsidian - {}".format(vault_name)
@router.post("/connect", response_model=ConnectResponse)
async def obsidian_connect(
    payload: ConnectRequest,
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
) -> ConnectResponse:
    """Register a vault, refresh an existing one, or adopt another device's row.

    Resolution order:

    1. ``(user_id, vault_id)`` -> known device, refresh metadata.
    2. ``(user_id, vault_fingerprint)`` -> another device of the same vault,
       caller adopts the surviving ``vault_id``.
    3. Insert a new row.

    Fingerprint collisions on (1) trigger ``merge_obsidian_connectors`` so
    the partial unique index can never produce two live rows for one vault.

    Raises:
        HTTPException: 403 (from ``_ensure_search_space_access``) when the
            caller does not own ``payload.search_space_id``; 409 when the
            insert conflicted but no winning row could be found afterwards.
    """
    await _ensure_search_space_access(
        session, user=user, search_space_id=payload.search_space_id
    )
    now_iso = datetime.now(UTC).isoformat()
    cfg = _build_config(payload, now_iso=now_iso)
    display_name = _display_name(payload.vault_name)

    # --- Step 1: this vault_id is already registered for the user. ---
    existing_by_vid = await _find_by_vault_id(
        session, user_id=user.id, vault_id=payload.vault_id
    )
    if existing_by_vid is not None:
        collision = await _find_by_fingerprint(
            session, user_id=user.id, vault_fingerprint=payload.vault_fingerprint
        )
        if collision is not None and collision.id != existing_by_vid.id:
            # A different row already owns this fingerprint: fold the
            # caller's row into it and answer with the survivor's vault_id.
            await merge_obsidian_connectors(
                session, source=existing_by_vid, target=collision
            )
            # Only name and last-connect time are refreshed; the survivor
            # keeps its own vault_id, fingerprint and search space.
            collision_cfg = dict(collision.config or {})
            collision_cfg["vault_name"] = payload.vault_name
            collision_cfg["last_connect_at"] = now_iso
            collision.config = collision_cfg
            collision.name = _display_name(payload.vault_name)
            # NOTE(review): every branch builds the response before
            # committing, presumably so ORM attributes are read before the
            # commit expires them; confirm before reordering.
            response = ConnectResponse(
                connector_id=collision.id,
                vault_id=collision_cfg["vault_id"],
                search_space_id=collision.search_space_id,
                server_time_utc=datetime.now(UTC),
                **_build_handshake(),
            )
            await session.commit()
            return response
        # Plain refresh: overwrite the config wholesale and follow the
        # search space requested by the caller.
        existing_by_vid.name = display_name
        existing_by_vid.config = cfg
        existing_by_vid.search_space_id = payload.search_space_id
        existing_by_vid.is_indexable = False
        response = ConnectResponse(
            connector_id=existing_by_vid.id,
            vault_id=payload.vault_id,
            search_space_id=existing_by_vid.search_space_id,
            server_time_utc=datetime.now(UTC),
            **_build_handshake(),
        )
        await session.commit()
        return response

    # --- Step 2: unknown vault_id but a known fingerprint. ---
    existing_by_fp = await _find_by_fingerprint(
        session, user_id=user.id, vault_fingerprint=payload.vault_fingerprint
    )
    if existing_by_fp is not None:
        # The caller adopts the existing row; its vault_id and search space
        # are returned unchanged rather than the values in the request.
        survivor_cfg = dict(existing_by_fp.config or {})
        survivor_cfg["vault_name"] = payload.vault_name
        survivor_cfg["last_connect_at"] = now_iso
        existing_by_fp.config = survivor_cfg
        existing_by_fp.name = display_name
        response = ConnectResponse(
            connector_id=existing_by_fp.id,
            vault_id=survivor_cfg["vault_id"],
            search_space_id=existing_by_fp.search_space_id,
            server_time_utc=datetime.now(UTC),
            **_build_handshake(),
        )
        await session.commit()
        return response

    # --- Step 3: first registration for this vault. ---
    # ON CONFLICT DO NOTHING matches any unique index (vault_id OR
    # fingerprint), so concurrent first-time connects from two devices
    # of the same vault never raise IntegrityError — the loser just
    # gets an empty RETURNING and falls through to re-fetch the winner.
    insert_stmt = (
        pg_insert(SearchSourceConnector)
        .values(
            name=display_name,
            connector_type=SearchSourceConnectorType.OBSIDIAN_CONNECTOR,
            is_indexable=False,
            config=cfg,
            user_id=user.id,
            search_space_id=payload.search_space_id,
        )
        .on_conflict_do_nothing()
        .returning(
            SearchSourceConnector.id,
            SearchSourceConnector.search_space_id,
        )
    )
    inserted = (await session.execute(insert_stmt)).first()
    if inserted is not None:
        response = ConnectResponse(
            connector_id=inserted.id,
            vault_id=payload.vault_id,
            search_space_id=inserted.search_space_id,
            server_time_utc=datetime.now(UTC),
            **_build_handshake(),
        )
        await session.commit()
        return response

    # Lost the insert race: look up the row that won, by fingerprint first
    # and then by vault_id.
    winner = await _find_by_fingerprint(
        session, user_id=user.id, vault_fingerprint=payload.vault_fingerprint
    )
    if winner is None:
        winner = await _find_by_vault_id(
            session, user_id=user.id, vault_id=payload.vault_id
        )
    if winner is None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="vault registration conflicted but winning row could not be located",
        )
    # NOTE(review): ``(winner.config or {})["vault_id"]`` still raises
    # KeyError when the key is absent; the ``or {}`` only guards against a
    # ``None`` config object.
    response = ConnectResponse(
        connector_id=winner.id,
        vault_id=(winner.config or {})["vault_id"],
        search_space_id=winner.search_space_id,
        server_time_utc=datetime.now(UTC),
        **_build_handshake(),
    )
    await session.commit()
    return response
@router.post("/sync", response_model=SyncAck)
async def obsidian_sync(
    payload: SyncBatchRequest,
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
) -> SyncAck:
    """Batch-upsert notes; returns per-note ack so the plugin can dequeue/retry.

    Markdown notes are upserted inline via ``upsert_note`` and acked as
    ``ok``. Binary attachments are validated against the extension/MIME
    whitelist, enqueued for background processing, and acked as ``queued``.
    Per-note failures are recorded as ``error`` items instead of failing the
    whole request.

    Raises:
        HTTPException: 404 (from ``_resolve_vault_connector``) when the vault
            is not registered; any ``HTTPException`` raised while processing
            a note is re-raised unchanged.
    """
    connector = await _resolve_vault_connector(
        session, user=user, vault_id=payload.vault_id
    )
    # Notification handling is best-effort: a failure here is logged and the
    # sync itself proceeds without an inbox item.
    notification = None
    try:
        notification = await _start_obsidian_sync_notification(
            session, user=user, connector=connector, total_count=len(payload.notes)
        )
    except Exception:
        logger.warning(
            "obsidian sync notification start failed connector=%s user=%s",
            connector.id,
            user.id,
            exc_info=True,
        )
    items: list[SyncAckItem] = []
    indexed = 0
    failed = 0
    for note in payload.notes:
        try:
            if note.is_binary:
                # Normalise so ".PDF" and "pdf" hit the same whitelist entry.
                ext = note.extension.lstrip(".").lower()
                if ext not in ALLOWED_ATTACHMENT_EXTENSIONS:
                    failed += 1
                    items.append(
                        SyncAckItem(
                            path=note.path,
                            status="error",
                            error=f"unsupported attachment extension: .{ext}",
                        )
                    )
                    continue
                # The declared MIME type must match the one the whitelist
                # associates with this extension.
                expected_mime = ATTACHMENT_MIME_TYPES[ext]
                if note.mime_type != expected_mime:
                    failed += 1
                    items.append(
                        SyncAckItem(
                            path=note.path,
                            status="error",
                            error=(
                                f"mime_type '{note.mime_type}' does not match "
                                f"extension .{ext}"
                            ),
                        )
                    )
                    continue
                _queue_obsidian_attachment(
                    connector_id=connector.id,
                    note_payload=note.model_dump(mode="json"),
                    user_id=str(user.id),
                )
                # Queued attachments are counted as indexed even though the
                # background task has not run yet.
                indexed += 1
                items.append(SyncAckItem(path=note.path, status="queued"))
                continue
            doc = await upsert_note(
                session, connector=connector, payload=note, user_id=str(user.id)
            )
            indexed += 1
            items.append(SyncAckItem(path=note.path, status="ok", document_id=doc.id))
        except HTTPException:
            # NOTE(review): re-raising aborts the whole batch and skips the
            # notification finish step below, so a started notification
            # stays ``in_progress`` — confirm this is intended.
            raise
        except Exception as exc:
            failed += 1
            logger.exception(
                "obsidian /sync failed for path=%s vault=%s",
                note.path,
                payload.vault_id,
            )
            # Error text is truncated to 300 characters in the ack.
            items.append(
                SyncAckItem(path=note.path, status="error", error=str(exc)[:300])
            )
    if notification is not None:
        try:
            await _finish_obsidian_sync_notification(
                session,
                notification=notification,
                indexed=indexed,
                failed=failed,
            )
        except Exception:
            logger.warning(
                "obsidian sync notification finish failed connector=%s user=%s",
                connector.id,
                user.id,
                exc_info=True,
            )
    return SyncAck(
        vault_id=payload.vault_id,
        indexed=indexed,
        failed=failed,
        items=items,
    )
@router.post("/rename", response_model=RenameAck)
async def obsidian_rename(
    payload: RenameBatchRequest,
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
) -> RenameAck:
    """Apply a batch of vault rename events and ack each one individually.

    Each rename is acked as ``ok`` (document found and renamed), ``missing``
    (``rename_note`` returned ``None``) or ``error`` (an exception was raised
    and logged). Only ``ok`` and ``missing`` are counted in the summary.

    Raises:
        HTTPException: 404 (from ``_resolve_vault_connector``) when the vault
            is not registered.
    """
    connector = await _resolve_vault_connector(
        session, user=user, vault_id=payload.vault_id
    )
    acks: list[RenameAckItem] = []
    renamed = 0
    missing = 0
    for rename in payload.renames:
        old_path, new_path = rename.old_path, rename.new_path
        try:
            doc = await rename_note(
                session,
                connector=connector,
                old_path=old_path,
                new_path=new_path,
                vault_id=payload.vault_id,
            )
            if doc is not None:
                renamed += 1
                acks.append(
                    RenameAckItem(
                        old_path=old_path,
                        new_path=new_path,
                        status="ok",
                        document_id=doc.id,
                    )
                )
            else:
                missing += 1
                acks.append(
                    RenameAckItem(
                        old_path=old_path,
                        new_path=new_path,
                        status="missing",
                    )
                )
        except Exception as exc:
            logger.exception(
                "obsidian /rename failed for old=%s new=%s vault=%s",
                old_path,
                new_path,
                payload.vault_id,
            )
            # Error text is truncated to 300 characters in the ack.
            acks.append(
                RenameAckItem(
                    old_path=old_path,
                    new_path=new_path,
                    status="error",
                    error=str(exc)[:300],
                )
            )
    return RenameAck(
        vault_id=payload.vault_id,
        renamed=renamed,
        missing=missing,
        items=acks,
    )
@router.delete("/notes", response_model=DeleteAck)
async def obsidian_delete_notes(
    payload: DeleteBatchRequest,
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
) -> DeleteAck:
    """Soft-delete a batch of notes by vault-relative path, acking each one.

    Each path is acked as ``ok`` (``delete_note`` returned a truthy value),
    ``missing`` (falsy result) or ``error`` (an exception was raised and
    logged). Only ``ok`` and ``missing`` are counted in the summary.

    Raises:
        HTTPException: 404 (from ``_resolve_vault_connector``) when the vault
            is not registered.
    """
    connector = await _resolve_vault_connector(
        session, user=user, vault_id=payload.vault_id
    )
    acks: list[DeleteAckItem] = []
    deleted = 0
    missing = 0
    for note_path in payload.paths:
        try:
            removed = await delete_note(
                session,
                connector=connector,
                vault_id=payload.vault_id,
                path=note_path,
            )
            if removed:
                deleted += 1
                outcome = "ok"
            else:
                missing += 1
                outcome = "missing"
            acks.append(DeleteAckItem(path=note_path, status=outcome))
        except Exception as exc:
            logger.exception(
                "obsidian DELETE /notes failed for path=%s vault=%s",
                note_path,
                payload.vault_id,
            )
            # Error text is truncated to 300 characters in the ack.
            acks.append(
                DeleteAckItem(path=note_path, status="error", error=str(exc)[:300])
            )
    return DeleteAck(
        vault_id=payload.vault_id,
        deleted=deleted,
        missing=missing,
        items=acks,
    )
@router.get("/manifest", response_model=ManifestResponse)
async def obsidian_manifest(
    vault_id: str = Query(..., description="Plugin-side stable vault UUID"),
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
) -> ManifestResponse:
    """Serve the vault manifest used by the plugin's onload reconcile diff.

    Raises:
        HTTPException: 404 (from ``_resolve_vault_connector``) when the vault
            is not registered.
    """
    owning_connector = await _resolve_vault_connector(
        session, user=user, vault_id=vault_id
    )
    manifest = await get_manifest(
        session, connector=owning_connector, vault_id=vault_id
    )
    return manifest
@router.get("/stats", response_model=StatsResponse)
async def obsidian_stats(
    vault_id: str = Query(..., description="Plugin-side stable vault UUID"),
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
) -> StatsResponse:
    """Report the active-note count and last sync time for the web tile.

    ``files_synced`` counts only documents without a ``deleted_at`` tombstone
    so it matches ``/manifest``; ``last_sync_at`` is taken over every
    document of the connector, so deletes still advance the freshness signal.

    Raises:
        HTTPException: 404 (from ``_resolve_vault_connector``) when the vault
            is not registered.
    """
    connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id)
    not_tombstoned = Document.document_metadata["deleted_at"].as_string().is_(None)
    stats_query = select(
        # CASE without ELSE yields NULL for tombstoned rows, which COUNT skips.
        func.count(case((not_tombstoned, 1))).label("files_synced"),
        func.max(Document.updated_at).label("last_sync_at"),
    ).where(
        Document.connector_id == connector.id,
        Document.document_type == DocumentType.OBSIDIAN_CONNECTOR,
    )
    files_synced, last_sync_at = (await session.execute(stats_query)).first()
    return StatsResponse(
        vault_id=vault_id,
        files_synced=int(files_synced or 0),
        last_sync_at=last_sync_at,
    )

View file

@ -81,6 +81,36 @@ _heartbeat_redis_client: redis.Redis | None = None
# Redis key TTL - notification is stale if no heartbeat in this time
HEARTBEAT_TTL_SECONDS = 120  # 2 minutes
# How often the background loop refreshes the Redis key. Must be < TTL so
# the key cannot expire between refreshes when the indexing function is
# doing blocking work (e.g. gitingest in Phase 1) that doesn't trigger
# on_heartbeat_callback.
HEARTBEAT_REFRESH_INTERVAL = 60  # seconds; passed straight to asyncio.sleep()
async def _run_indexing_heartbeat_loop(notification_id: int) -> None:
    """Keep the Redis heartbeat key for ``notification_id`` alive until cancelled.

    Sleeps ``HEARTBEAT_REFRESH_INTERVAL`` seconds, then re-sets the key with
    a ``HEARTBEAT_TTL_SECONDS`` expiry, and repeats forever. A failed refresh
    is logged and the loop continues.

    Mirrors `_run_heartbeat_loop` in app/tasks/celery_tasks/document_tasks.py.
    The caller stops the loop with ``heartbeat_task.cancel()`` once indexing
    returns. If the worker dies the coroutine dies with it and the key
    expires naturally on its TTL.
    """
    heartbeat_key = _get_heartbeat_key(notification_id)
    try:
        while True:
            await asyncio.sleep(HEARTBEAT_REFRESH_INTERVAL)
            try:
                redis_client = get_heartbeat_redis_client()
                redis_client.setex(heartbeat_key, HEARTBEAT_TTL_SECONDS, "alive")
            except Exception as refresh_error:
                logger.warning(
                    f"Failed to refresh Redis heartbeat for notification "
                    f"{notification_id}: {refresh_error}"
                )
    except asyncio.CancelledError:
        # Cancellation is the normal shutdown signal once indexing completes.
        return
def get_heartbeat_redis_client() -> redis.Redis:
@ -1028,25 +1058,6 @@ async def index_connector_content(
)
response_message = "Web page indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.OBSIDIAN_CONNECTOR:
from app.config import config as app_config
from app.tasks.celery_tasks.connector_tasks import index_obsidian_vault_task
# Obsidian connector only available in self-hosted mode
if not app_config.is_self_hosted():
raise HTTPException(
status_code=400,
detail="Obsidian connector is only available in self-hosted mode",
)
logger.info(
f"Triggering Obsidian vault indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_obsidian_vault_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Obsidian vault indexing started in the background."
elif (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
@ -1284,6 +1295,7 @@ async def _run_indexing_with_notifications(
notification = None
connector_lock_acquired = False
heartbeat_task: asyncio.Task | None = None
# Track indexed count for retry notifications and heartbeat
current_indexed_count = 0
@ -1329,6 +1341,16 @@ async def _run_indexing_with_notifications(
except Exception as e:
logger.warning(f"Failed to set initial Redis heartbeat: {e}")
# Start a background coroutine that refreshes the
# heartbeat every HEARTBEAT_REFRESH_INTERVAL seconds.
# Without this the cleanup_stale_indexing_notifications
# task can mark the doc failed when on_heartbeat_callback
# doesn't fire — for example during the GitHub
# connector's Phase 1 gitingest blocking call (#1295).
heartbeat_task = asyncio.create_task(
_run_indexing_heartbeat_loop(notification.id)
)
# Update notification to fetching stage
if notification:
await NotificationService.connector_indexing.notify_indexing_progress(
@ -1619,6 +1641,13 @@ async def _run_indexing_with_notifications(
except Exception as notif_error:
logger.error(f"Failed to update notification: {notif_error!s}")
finally:
# Stop the background heartbeat refresher BEFORE deleting the
# Redis key, so the loop cannot race and re-create the key
# after we delete it.
if heartbeat_task is not None:
heartbeat_task.cancel()
with suppress(Exception):
await asyncio.gather(heartbeat_task, return_exceptions=True)
# Clean up Redis heartbeat key when task completes (success or failure)
if notification:
try:
@ -2501,59 +2530,6 @@ async def run_bookstack_indexing(
)
# Add new helper functions for Obsidian indexing
async def run_obsidian_indexing_with_new_session(
    connector_id: int,
    search_space_id: int,
    user_id: str,
    start_date: str,
    end_date: str,
):
    """Run Obsidian indexing inside a database session opened just for it.

    Logs the start and end of the background task and delegates the work to
    ``run_obsidian_indexing``.
    """
    logger.info(
        f"Background task started: Indexing Obsidian connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
    )
    async with async_session_maker() as db_session:
        await run_obsidian_indexing(
            session=db_session,
            connector_id=connector_id,
            search_space_id=search_space_id,
            user_id=user_id,
            start_date=start_date,
            end_date=end_date,
        )
    logger.info(f"Background task finished: Indexing Obsidian connector {connector_id}")
async def run_obsidian_indexing(
    session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    start_date: str,
    end_date: str,
):
    """
    Run Obsidian vault indexing through the shared notification wrapper.

    Args:
        session: Database session
        connector_id: ID of the Obsidian connector
        search_space_id: ID of the search space
        user_id: ID of the user
        start_date: Start date for indexing
        end_date: End date for indexing
    """
    # Imported at call time, as in the original implementation.
    from app.tasks.connector_indexers import index_obsidian_vault

    indexing_kwargs = {
        "session": session,
        "connector_id": connector_id,
        "search_space_id": search_space_id,
        "user_id": user_id,
        "start_date": start_date,
        "end_date": end_date,
        "indexing_function": index_obsidian_vault,
        "update_timestamp_func": _update_connector_timestamp_by_id,
        "supports_heartbeat_callback": True,
    }
    await _run_indexing_with_notifications(**indexing_kwargs)
async def run_composio_indexing_with_new_session(
connector_id: int,
search_space_id: int,

View file

@ -1,59 +0,0 @@
"""
Obsidian Connector Credentials Schema.
Obsidian is a local-first note-taking app that stores notes as markdown files.
This connector supports indexing from local file system (self-hosted only).
"""
from pydantic import BaseModel, field_validator
class ObsidianAuthCredentialsBase(BaseModel):
    """
    Configuration for the file-system based Obsidian connector.

    Obsidian vaults are local directories, so this model carries the vault
    path and indexing options rather than API tokens.
    """

    vault_path: str
    vault_name: str | None = None
    # NOTE(review): pydantic v2 does not run field validators on default
    # values unless ``validate_default`` is enabled, so an omitted
    # ``exclude_folders`` presumably stays ``None`` instead of picking up the
    # fallback list below — confirm.
    exclude_folders: list[str] | None = None
    include_attachments: bool = False

    @field_validator("vault_path")
    @classmethod
    def validate_vault_path(cls, v: str) -> str:
        """Reject blank vault paths; return the path without surrounding whitespace."""
        stripped = v.strip() if v else ""
        if not stripped:
            raise ValueError("Vault path is required")
        return stripped

    @field_validator("exclude_folders", mode="before")
    @classmethod
    def parse_exclude_folders(cls, v):
        """Accept a comma-separated string, or fall back to defaults for ``None``."""
        if isinstance(v, str):
            return [part.strip() for part in v.split(",") if part.strip()]
        if v is None:
            return [".trash", ".obsidian", "templates"]
        return v

    def to_dict(self) -> dict:
        """Return the stored representation of these credentials."""
        stored_fields = (
            "vault_path",
            "vault_name",
            "exclude_folders",
            "include_attachments",
        )
        return {name: getattr(self, name) for name in stored_fields}

    @classmethod
    def from_dict(cls, data: dict) -> "ObsidianAuthCredentialsBase":
        """Build credentials from a stored dict, filling in missing keys."""
        fallbacks = {
            "vault_path": "",
            "vault_name": None,
            "exclude_folders": None,
            "include_attachments": False,
        }
        return cls(**{key: data.get(key, default) for key, default in fallbacks.items()})

View file

@ -0,0 +1,234 @@
"""Wire schemas spoken between the SurfSense Obsidian plugin and the backend.
All schemas inherit ``extra='ignore'`` from :class:`_PluginBase` so additive
field changes never break either side; hard breaks live behind a new URL
prefix (``/api/v2/...``).
"""
from __future__ import annotations
from datetime import datetime
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, model_validator
# Shared pydantic config for every wire schema: fields not declared on the
# model are ignored rather than rejected.
_PLUGIN_MODEL_CONFIG = ConfigDict(extra="ignore")
# Source of truth for the attachment whitelist. Mirrors MIME_BY_EXTENSION in
# surfsense_obsidian/src/sync-engine.ts — keep in sync.
# Keys are lowercase extensions without the leading dot; values are the only
# MIME type accepted for that extension.
ATTACHMENT_MIME_TYPES: dict[str, str] = {
    "pdf": "application/pdf",
    "png": "image/png",
    "jpg": "image/jpeg",
    "jpeg": "image/jpeg",
    "gif": "image/gif",
    "webp": "image/webp",
    "svg": "image/svg+xml",
    "txt": "text/plain",
}
# Extension-only view of the whitelist (the keys of ATTACHMENT_MIME_TYPES).
ALLOWED_ATTACHMENT_EXTENSIONS: frozenset[str] = frozenset(ATTACHMENT_MIME_TYPES)
class _PluginBase(BaseModel):
    """Base schema carrying the shared forward-compatibility config."""

    # ``extra="ignore"`` (see _PLUGIN_MODEL_CONFIG): undeclared fields sent by
    # the other side are dropped instead of failing validation.
    model_config = _PLUGIN_MODEL_CONFIG
class HeadingRef(_PluginBase):
    """One markdown heading extracted from Obsidian metadata cache."""

    # Heading text.
    heading: str
    # Heading depth, validated to the range 1-6.
    level: int = Field(ge=1, le=6)
class NotePayload(_PluginBase):
    """One Obsidian note as pushed by the plugin (the source of truth)."""

    # --- Identity -----------------------------------------------------
    vault_id: str = Field(
        ..., description="Stable plugin-generated UUID for this vault"
    )
    path: str = Field(..., description="Vault-relative path, e.g. 'notes/foo.md'")
    name: str = Field(..., description="File stem (no extension)")
    extension: str = Field(
        default="md", description="File extension without leading dot"
    )
    # --- Content and metadata extracted by the plugin -----------------
    content: str = Field(default="", description="Raw markdown body (post-frontmatter)")
    frontmatter: dict[str, Any] = Field(default_factory=dict)
    tags: list[str] = Field(default_factory=list)
    headings: list[HeadingRef] = Field(default_factory=list)
    resolved_links: list[str] = Field(default_factory=list)
    unresolved_links: list[str] = Field(default_factory=list)
    embeds: list[str] = Field(default_factory=list)
    aliases: list[str] = Field(default_factory=list)
    content_hash: str = Field(
        ..., description="Plugin-computed SHA-256 of the raw content"
    )
    # --- Binary attachment fields --------------------------------------
    # NOTE(review): the descriptions below call ``binary_base64`` and
    # ``mime_type`` optional, but ``_enforce_binary_invariants`` requires
    # both whenever ``is_binary`` is True.
    is_binary: bool = Field(
        default=False,
        description=(
            "True when payload represents a non-markdown attachment. "
            "If set, the plugin may include binary_base64 for ETL extraction."
        ),
    )
    binary_base64: str | None = Field(
        default=None,
        description=(
            "Base64-encoded raw file bytes for binary attachments. "
            "Used by the backend ETL pipeline."
        ),
    )
    mime_type: str | None = Field(
        default=None,
        description="Optional MIME type hint for binary attachments.",
    )
    # --- File stat information -----------------------------------------
    size: int | None = Field(
        default=None,
        ge=0,
        description="Byte size of the local file (mtime+size short-circuit signal). Optional for forward compatibility.",
    )
    mtime: datetime
    ctime: datetime

    @model_validator(mode="after")
    def _enforce_binary_invariants(self) -> NotePayload:
        """Validate that the binary fields agree with ``is_binary``.

        Raises:
            ValueError: when ``is_binary`` is True but ``binary_base64`` or
                ``mime_type`` is missing/empty, or when ``is_binary`` is
                False but either field is set.
        """
        if self.is_binary:
            # Truthiness check: an empty string is rejected as well as None.
            if not self.binary_base64:
                raise ValueError("binary_base64 is required when is_binary is True")
            if not self.mime_type:
                raise ValueError("mime_type is required when is_binary is True")
        elif self.binary_base64 is not None or self.mime_type is not None:
            raise ValueError(
                "binary_base64 and mime_type must be omitted when is_binary is False",
            )
        return self
class SyncBatchRequest(_PluginBase):
    """Batch upsert; plugin sends 10-20 notes per request."""

    # Vault the batch belongs to; the route resolves the connector from it.
    vault_id: str
    # Validation caps a single request at 100 notes.
    notes: list[NotePayload] = Field(default_factory=list, max_length=100)
class RenameItem(_PluginBase):
    # One rename event: the note's previous and new path.
    old_path: str
    new_path: str
class RenameBatchRequest(_PluginBase):
    # Request body for the batch rename endpoint.
    vault_id: str
    # Validation caps a single request at 200 renames.
    renames: list[RenameItem] = Field(default_factory=list, max_length=200)
class DeleteBatchRequest(_PluginBase):
    # Request body for the batch delete endpoint.
    vault_id: str
    # Paths of the notes to delete; validation caps a request at 500 paths.
    paths: list[str] = Field(default_factory=list, max_length=500)
class ManifestEntry(_PluginBase):
    # One manifest value: what the server last recorded for a single path.
    hash: str
    mtime: datetime
    size: int | None = Field(
        default=None,
        description="Byte size last seen by the server. Enables mtime+size short-circuit; absent when not yet recorded.",
    )
class ManifestResponse(_PluginBase):
    """Path-keyed manifest of every non-deleted note for a vault."""

    vault_id: str
    # Maps note path -> ManifestEntry.
    items: dict[str, ManifestEntry] = Field(default_factory=dict)
class ConnectRequest(_PluginBase):
    """Vault registration / heartbeat. Replayed on every plugin onload."""

    # Plugin-generated vault identifier.
    vault_id: str
    # Human-readable vault name.
    vault_name: str
    # Search space the connector should be attached to.
    search_space_id: int
    vault_fingerprint: str = Field(
        ...,
        description=(
            "Deterministic SHA-256 over the sorted markdown paths in the vault "
            "(plus vault_name). Same vault content on any device produces the "
            "same value; the server uses it to dedup connectors across devices."
        ),
    )
class ConnectResponse(_PluginBase):
    """Carries the same handshake fields as ``HealthResponse`` so the plugin
    learns the contract without a separate ``GET /health`` round-trip."""

    # Server-side id of the connector row serving this vault.
    connector_id: int
    # Vault id the caller should use from now on.
    vault_id: str
    # Search space the connector is attached to.
    search_space_id: int
    # Handshake fields shared with HealthResponse.
    capabilities: list[str]
    server_time_utc: datetime
class HealthResponse(_PluginBase):
    """API contract handshake. ``capabilities`` is additive-only string list."""

    # Feature names the plugin can feature-gate on.
    capabilities: list[str]
    # Server clock at response time.
    server_time_utc: datetime
# Per-item batch ack schemas — wire shape is load-bearing for the plugin
# queue (see api-client.ts / sync-engine.ts:processBatch).
class SyncAckItem(_PluginBase):
    # Per-note result of a sync batch.
    path: str
    # "ok": upserted inline; "queued": attachment handed to a background
    # task; "error": see ``error``.
    status: Literal["ok", "queued", "error"]
    # Optional; defaults to None.
    document_id: int | None = None
    # Optional error text; defaults to None.
    error: str | None = None
class SyncAck(_PluginBase):
    # Batch-level summary returned by the sync endpoint.
    vault_id: str
    # Count of notes that succeeded.
    indexed: int
    # Count of notes acked with status "error".
    failed: int
    # One entry per submitted note.
    items: list[SyncAckItem] = Field(default_factory=list)
class RenameAckItem(_PluginBase):
    # Per-rename result of a rename batch.
    old_path: str
    new_path: str
    # ``missing`` is treated as success client-side (end state reached).
    status: Literal["ok", "error", "missing"]
    # Optional; defaults to None.
    document_id: int | None = None
    # Optional error text; defaults to None.
    error: str | None = None
class RenameAck(_PluginBase):
    # Batch-level summary returned by the rename endpoint.
    vault_id: str
    # Count of renames acked "ok".
    renamed: int
    # Count of renames acked "missing".
    missing: int
    # One entry per submitted rename.
    items: list[RenameAckItem] = Field(default_factory=list)
class DeleteAckItem(_PluginBase):
    # Per-path result of a delete batch.
    path: str
    status: Literal["ok", "error", "missing"]
    # Optional error text; defaults to None.
    error: str | None = None
class DeleteAck(_PluginBase):
    # Batch-level summary returned by the delete endpoint.
    vault_id: str
    # Count of paths acked "ok".
    deleted: int
    # Count of paths acked "missing".
    missing: int
    # One entry per submitted path.
    items: list[DeleteAckItem] = Field(default_factory=list)
class StatsResponse(_PluginBase):
    """Backs the Obsidian connector tile in the web UI."""

    vault_id: str
    # Number of synced files reported for the vault.
    files_synced: int
    # Optional timestamp of the last sync; defaults to None.
    last_sync_at: datetime | None = None

View file

@ -0,0 +1,616 @@
"""
Obsidian plugin indexer service.
Bridges the SurfSense Obsidian plugin's HTTP payloads
(see ``app/schemas/obsidian_plugin.py``) into the shared
``IndexingPipelineService``.
Responsibilities:
- ``upsert_note`` — push one note through the indexing pipeline; respects
  unchanged content (skip) and version-snapshots existing rows before
  rewrite.
- ``rename_note`` — rewrite path-derived fields (path metadata,
  ``unique_identifier_hash``, ``source_url``) without re-indexing content.
- ``delete_note`` — soft delete with a tombstone in ``document_metadata``
  so reconciliation can distinguish "user explicitly killed this in the UI"
  from "plugin hasn't synced yet".
- ``get_manifest`` — return ``{path: {hash, mtime, size}}`` for every
  non-deleted note belonging to a vault, used by the plugin's reconcile
  pass on ``onload``.
Design notes
------------
The plugin's content hash and the backend's ``content_hash`` are computed
differently (plugin uses raw SHA-256 of the markdown body; backend salts
with ``search_space_id``). We persist the plugin's hash in
``document_metadata['plugin_content_hash']`` so the manifest endpoint can
return what the plugin sent — that's the only number the plugin can
compare without re-downloading content.
"""
from __future__ import annotations
import base64
import contextlib
import logging
import os
import tempfile
from datetime import UTC, datetime
from typing import Any
from urllib.parse import quote
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import (
Document,
DocumentStatus,
DocumentType,
SearchSourceConnector,
)
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.schemas.obsidian_plugin import (
ManifestEntry,
ManifestResponse,
NotePayload,
)
from app.utils.document_converters import generate_unique_identifier_hash
from app.utils.document_versioning import create_version_snapshot
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _vault_path_unique_id(vault_id: str, path: str) -> str:
    """Return the ``"<vault_id>:<path>"`` identifier for a note.

    Prefixing the path with the vault id keeps the same relative path in two
    different vaults from producing the same identifier.
    """
    return ":".join((vault_id, path))
def _build_source_url(vault_name: str, path: str) -> str:
    """Return the ``obsidian://open`` deep link for a note.

    Both the vault name and the file path are percent-encoded with no
    "safe" characters, so spaces, ``/``, ``#``, ``?`` and similar characters
    cannot break the query string.
    """
    encoded_vault = quote(vault_name, safe="")
    encoded_file = quote(path, safe="")
    return f"obsidian://open?vault={encoded_vault}&file={encoded_file}"
def _build_metadata(
    payload: NotePayload,
    *,
    vault_name: str,
    connector_id: int,
    extra: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Flatten a plugin note payload into a ``document_metadata`` dict.

    Args:
        payload: Note payload sent by the plugin.
        vault_name: Display name of the vault; also used for the deep link.
        connector_id: Id of the connector the note belongs to.
        extra: Optional additional keys; applied last, so they override any
            key produced here.

    Returns:
        A new dict. ``is_binary`` / ``mime_type`` are only present for
        binary payloads.
    """
    heading_dicts = [heading.model_dump() for heading in payload.headings]
    deep_link = _build_source_url(vault_name, payload.path)

    meta: dict[str, Any] = {
        "source": "plugin",
        "vault_id": payload.vault_id,
        "vault_name": vault_name,
        "file_path": payload.path,
        "file_name": payload.name,
        "extension": payload.extension,
        "frontmatter": payload.frontmatter,
        "tags": payload.tags,
        "headings": heading_dicts,
        "outgoing_links": payload.resolved_links,
        "unresolved_links": payload.unresolved_links,
        "embeds": payload.embeds,
        "aliases": payload.aliases,
        # The plugin's own hash/size are stored so the manifest can echo
        # them back for comparison.
        "plugin_content_hash": payload.content_hash,
        "plugin_file_size": payload.size,
        "mtime": payload.mtime.isoformat(),
        "ctime": payload.ctime.isoformat(),
        "connector_id": connector_id,
        "url": deep_link,
    }

    if payload.is_binary:
        meta.update(is_binary=True, mime_type=payload.mime_type)
    meta.update(extra or {})
    return meta
def _build_document_string(
    payload: NotePayload, vault_name: str, *, content_override: str | None = None
) -> str:
    """Compose the METADATA + CONTENT text that is handed to the pipeline.

    Args:
        payload: Note payload; ``name``, ``path``, ``tags``,
            ``resolved_links`` and ``content`` are read.
        vault_name: Vault display name for the ``Vault:`` line.
        content_override: When not ``None`` (an empty string counts), used
            as the body instead of ``payload.content``.

    Returns:
        The framed document string, terminated by a newline.
    """
    if payload.tags:
        tags_line = ", ".join(payload.tags)
    else:
        tags_line = "None"

    if payload.resolved_links:
        links_line = ", ".join(payload.resolved_links)
    else:
        links_line = "None"

    body = content_override if content_override is not None else payload.content

    lines = [
        "<METADATA>",
        f"Title: {payload.name}",
        f"Vault: {vault_name}",
        f"Path: {payload.path}",
        f"Tags: {tags_line}",
        f"Links to: {links_line}",
        "</METADATA>",
        "",
        "<CONTENT>",
        f"{body}",
        "</CONTENT>",
        "",  # yields the trailing newline after </CONTENT>
    ]
    return "\n".join(lines)
async def _extract_binary_attachment_markdown(
    payload: NotePayload, *, vision_llm
) -> tuple[str, dict[str, Any]]:
    """Decode a base64 attachment and extract markdown from it via ETL.

    The decoded bytes are written to a temporary file whose suffix comes
    from ``payload.extension``; that file is passed to ``_run_etl_extract``
    and is always removed afterwards.

    Args:
        payload: Attachment payload; ``binary_base64``, ``extension``,
            ``path`` and ``name`` are read.
        vision_llm: Forwarded unchanged to ``_run_etl_extract``.

    Returns:
        ``(markdown, metadata)``. Failures never raise: invalid base64 yields
        ``("", {"attachment_extraction_status": "invalid_binary_payload"})``
        and any ETL error yields ``""`` plus an ``"etl_failed"`` status with a
        truncated error message.
    """
    try:
        raw_bytes = base64.b64decode(payload.binary_base64, validate=True)
    except Exception:
        logger.warning("obsidian attachment payload had invalid base64: %s", payload.path)
        return "", {"attachment_extraction_status": "invalid_binary_payload"}

    suffix = f".{payload.extension.lstrip('.')}"
    filename = payload.path.rsplit("/", 1)[-1] or payload.name
    temp_path: str | None = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the path *before* writing: the file is created with
            # delete=False, so if the write fails the ``finally`` block
            # below must still know which file to remove.
            temp_path = tmp.name
            tmp.write(raw_bytes)
        result = await _run_etl_extract(
            file_path=temp_path,
            filename=filename,
            vision_llm=vision_llm,
        )
        metadata: dict[str, Any] = {
            "attachment_extraction_status": "ok",
            "attachment_etl_service": result.etl_service,
            "attachment_content_type": result.content_type,
        }
        return result.markdown_content, metadata
    except Exception as exc:
        logger.warning(
            "obsidian attachment ETL failed for %s: %s", payload.path, exc, exc_info=True
        )
        return "", {
            "attachment_extraction_status": "etl_failed",
            # Cap the stored message so metadata stays small.
            "attachment_extraction_error": str(exc)[:300],
        }
    finally:
        # Best-effort cleanup; a failure to unlink must not mask the result.
        if temp_path and os.path.exists(temp_path):
            with contextlib.suppress(Exception):
                os.unlink(temp_path)
async def _run_etl_extract(*, file_path: str, filename: str, vision_llm):
    """Run the ETL pipeline's ``extract`` on one file and return its result.

    The ETL modules are imported inside the function to avoid module-import
    cycles.
    """
    from app.etl_pipeline.etl_document import EtlRequest
    from app.etl_pipeline.etl_pipeline_service import EtlPipelineService

    service = EtlPipelineService(vision_llm=vision_llm)
    request = EtlRequest(file_path=file_path, filename=filename)
    return await service.extract(request)
def _is_image_attachment(payload: NotePayload) -> bool:
    """Return True when the payload's extension names a known image format.

    The comparison ignores case and any leading dots in the extension.
    """
    image_extensions = ("png", "jpg", "jpeg", "gif", "webp", "svg")
    normalized = payload.extension.lower().lstrip(".")
    return normalized in image_extensions
async def _resolve_attachment_vision_llm(
    session: AsyncSession,
    *,
    connector: SearchSourceConnector,
    search_space_id: int,
    payload: NotePayload,
):
    """Return a vision LLM for image attachments, otherwise ``None``.

    A model is fetched only when the payload is binary, its extension is an
    image type, and the connector has a truthy ``enable_vision_llm``
    attribute (a missing attribute counts as disabled).
    """
    wants_vision = (
        payload.is_binary
        and _is_image_attachment(payload)
        and getattr(connector, "enable_vision_llm", False)
    )
    if not wants_vision:
        return None

    # Imported lazily, only on the path that actually needs the service.
    from app.services.llm_service import get_vision_llm

    return await get_vision_llm(session, search_space_id)
async def _resolve_summary_llm(
    session: AsyncSession, *, user_id: str, search_space_id: int, should_summarize: bool
):
    """Return the user's long-context LLM, or ``None`` when summaries are off."""
    if should_summarize:
        # Imported lazily, only when a summary model is actually needed.
        from app.services.llm_service import get_user_long_context_llm

        return await get_user_long_context_llm(session, user_id, search_space_id)
    return None
def _require_extracted_attachment_content(
    *, content: str, etl_meta: dict[str, Any], path: str
) -> str:
    """Return the stripped extracted text, or raise if there is none.

    Args:
        content: Text produced by attachment extraction.
        etl_meta: Extraction metadata; ``attachment_extraction_status`` and
            ``attachment_extraction_error`` feed the error message.
        path: Attachment path, used only in the error message.

    Raises:
        RuntimeError: ``content`` is empty or whitespace-only.
    """
    text = content.strip()
    if not text:
        status = etl_meta.get("attachment_extraction_status", "unknown")
        reason = etl_meta.get("attachment_extraction_error")
        message = f"Attachment extraction failed for {path} ({status})"
        if reason:
            message = f"{message}: {reason}"
        raise RuntimeError(message)
    return text
async def _find_existing_document(
    session: AsyncSession,
    *,
    search_space_id: int,
    vault_id: str,
    path: str,
) -> Document | None:
    """Look up the document for ``(vault_id, path)`` in a search space.

    The lookup key is the unique-identifier hash derived from the
    vault-scoped note id, the Obsidian document type and the search space.

    Returns:
        The first matching ``Document``, or ``None`` if there is none.
    """
    uid_hash = generate_unique_identifier_hash(
        DocumentType.OBSIDIAN_CONNECTOR,
        _vault_path_unique_id(vault_id, path),
        search_space_id,
    )
    query = select(Document).where(Document.unique_identifier_hash == uid_hash)
    rows = await session.execute(query)
    return rows.scalars().first()
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def upsert_note(
    session: AsyncSession,
    *,
    connector: SearchSourceConnector,
    payload: NotePayload,
    user_id: str,
) -> Document:
    """Index or refresh one note pushed by the plugin.

    Flow:
      1. Find the existing document for ``(vault_id, path)``.
      2. If it exists, is not tombstoned, already carries the same plugin
         content hash and is READY, return it untouched.
      3. Otherwise snapshot the existing row (best effort) and run the note
         through the indexing pipeline. Binary payloads are first converted
         to markdown and must yield non-empty text.

    Returns:
        The resulting ``Document`` (new, updated, or the unchanged hit).

    Raises:
        RuntimeError: A binary attachment produced no extractable content,
            or the pipeline rejected a note that has no existing document.
    """
    vault_name: str = (connector.config or {}).get("vault_name") or "Vault"
    search_space_id = connector.search_space_id

    existing = await _find_existing_document(
        session,
        search_space_id=search_space_id,
        vault_id=payload.vault_id,
        path=payload.path,
    )
    plugin_hash = payload.content_hash

    if existing is not None:
        prior_meta = existing.document_metadata or {}
        is_unchanged = (
            not prior_meta.get("deleted_at")
            and prior_meta.get("plugin_content_hash") == plugin_hash
            and DocumentStatus.is_state(existing.status, DocumentStatus.READY)
        )
        if is_unchanged:
            return existing
        # Snapshotting is best effort: a failure here must not block the
        # re-index itself.
        try:
            await create_version_snapshot(session, existing)
        except Exception:
            logger.debug(
                "version snapshot failed for obsidian doc %s",
                existing.id,
                exc_info=True,
            )

    body_text = payload.content
    extra_meta: dict[str, Any] = {}
    if payload.is_binary:
        vision_llm = await _resolve_attachment_vision_llm(
            session,
            connector=connector,
            search_space_id=search_space_id,
            payload=payload,
        )
        extracted, etl_meta = await _extract_binary_attachment_markdown(
            payload, vision_llm=vision_llm
        )
        extra_meta.update(etl_meta)
        # Strict KB behavior: do not index metadata-only attachments.
        body_text = _require_extracted_attachment_content(
            content=extracted,
            etl_meta=etl_meta,
            path=payload.path,
        )

    llm = await _resolve_summary_llm(
        session,
        user_id=str(user_id),
        search_space_id=search_space_id,
        should_summarize=connector.enable_summary,
    )

    connector_doc = ConnectorDocument(
        title=payload.name,
        source_markdown=_build_document_string(
            payload, vault_name, content_override=body_text
        ),
        unique_id=_vault_path_unique_id(payload.vault_id, payload.path),
        document_type=DocumentType.OBSIDIAN_CONNECTOR,
        search_space_id=search_space_id,
        connector_id=connector.id,
        created_by_id=str(user_id),
        should_summarize=connector.enable_summary,
        fallback_summary=f"Obsidian Note: {payload.name}\n\n{body_text}",
        metadata=_build_metadata(
            payload,
            vault_name=vault_name,
            connector_id=connector.id,
            extra=extra_meta,
        ),
    )

    pipeline = IndexingPipelineService(session)
    prepared = await pipeline.prepare_for_indexing([connector_doc])
    if prepared:
        return await pipeline.index(prepared[0], connector_doc, llm)
    if existing is not None:
        # Nothing to index, but a previous row exists; hand that back.
        return existing
    raise RuntimeError(f"Indexing pipeline rejected obsidian note {payload.path}")
async def rename_note(
    session: AsyncSession,
    *,
    connector: SearchSourceConnector,
    old_path: str,
    new_path: str,
    vault_id: str,
) -> Document | None:
    """Rewrite path-derived columns without re-indexing content.

    Updates ``unique_identifier_hash``, ``title`` and the ``file_path`` /
    ``file_name`` / ``url`` metadata keys, then commits.

    Args:
        session: Active async session; committed on a successful rename.
        connector: Connector owning the note (supplies vault name and
            search space).
        old_path: Vault-relative path the note is currently stored under.
        new_path: Vault-relative path after the rename.
        vault_id: Plugin vault identifier used in the unique id.

    Returns:
        The updated document; the *unchanged* existing document when another
        row already occupies ``new_path``; or ``None`` if no row matched
        ``old_path`` (the plugin renamed a file that was never synced — safe
        to ignore, the next ``sync`` will create it under the new path).
    """
    vault_name: str = (connector.config or {}).get("vault_name") or "Vault"
    search_space_id = connector.search_space_id

    existing = await _find_existing_document(
        session,
        search_space_id=search_space_id,
        vault_id=vault_id,
        path=old_path,
    )
    if existing is None:
        return None

    new_uid_hash = generate_unique_identifier_hash(
        DocumentType.OBSIDIAN_CONNECTOR,
        _vault_path_unique_id(vault_id, new_path),
        search_space_id,
    )

    collision = await session.execute(
        select(Document).where(
            and_(
                Document.unique_identifier_hash == new_uid_hash,
                Document.id != existing.id,
            )
        )
    )
    if collision.scalars().first() is not None:
        logger.warning(
            "obsidian rename target already exists "
            "(vault=%s old=%s new=%s); skipping rename so the next /sync "
            "can resolve the conflict via content_hash",
            vault_id,
            old_path,
            new_path,
        )
        return existing

    new_filename = new_path.rsplit("/", 1)[-1]
    stem, dot, _extension = new_filename.rpartition(".")
    # Drop the last extension for the title. A dot-leading name such as
    # ".gitignore" has nothing before its only dot; keep the full filename
    # there instead of producing an empty title.
    new_stem = stem if dot and stem else new_filename

    existing.unique_identifier_hash = new_uid_hash
    existing.title = new_stem

    # Copy before mutating so the attribute is reassigned with a new dict.
    meta = dict(existing.document_metadata or {})
    meta["file_path"] = new_path
    meta["file_name"] = new_stem
    meta["url"] = _build_source_url(vault_name, new_path)
    existing.document_metadata = meta
    existing.updated_at = datetime.now(UTC)

    await session.commit()
    return existing
async def delete_note(
    session: AsyncSession,
    *,
    connector: SearchSourceConnector,
    vault_id: str,
    path: str,
) -> bool:
    """Soft-delete a note by writing a tombstone into ``document_metadata``.

    The row itself is kept; only ``deleted_at`` / ``deleted_by_source`` are
    added to its metadata and ``updated_at`` is refreshed, followed by a
    commit. A row that is already tombstoned is left untouched.

    Returns:
        True if a matching row exists (tombstoned now or previously), False
        if no row matched ``(vault_id, path)``.
    """
    doc = await _find_existing_document(
        session,
        search_space_id=connector.search_space_id,
        vault_id=vault_id,
        path=path,
    )
    if doc is None:
        return False

    tombstone = dict(doc.document_metadata or {})
    if not tombstone.get("deleted_at"):
        tombstone["deleted_at"] = datetime.now(UTC).isoformat()
        tombstone["deleted_by_source"] = "plugin"
        doc.document_metadata = tombstone
        doc.updated_at = datetime.now(UTC)
        await session.commit()
    return True
async def merge_obsidian_connectors(
    session: AsyncSession,
    *,
    source: SearchSourceConnector,
    target: SearchSourceConnector,
) -> None:
    """Fold ``source``'s documents into ``target`` and delete ``source``.

    Source documents whose ``file_path`` is missing or already present on
    ``target`` are deleted; all others are re-keyed to the target's vault
    id, connector and search space. The session is flushed but not
    committed here. Merging a connector into itself is a no-op.

    Raises:
        RuntimeError: ``target`` has no ``vault_id`` in its config.
    """
    if source.id == target.id:
        return

    target_vault_id = (target.config or {}).get("vault_id")
    target_search_space_id = target.search_space_id
    if not target_vault_id:
        raise RuntimeError("merge target is missing vault_id")

    target_rows = await session.execute(
        select(Document).where(
            and_(
                Document.connector_id == target.id,
                Document.document_type == DocumentType.OBSIDIAN_CONNECTOR,
            )
        )
    )
    # Paths already owned by the surviving connector; these win collisions.
    target_paths: set[str] = {
        existing_path
        for target_doc in target_rows.scalars().all()
        if (existing_path := (target_doc.document_metadata or {}).get("file_path"))
    }

    source_rows = await session.execute(
        select(Document).where(
            and_(
                Document.connector_id == source.id,
                Document.document_type == DocumentType.OBSIDIAN_CONNECTOR,
            )
        )
    )
    for doc in source_rows.scalars().all():
        meta = dict(doc.document_metadata or {})
        path = meta.get("file_path")
        if path and path not in target_paths:
            # Re-key the document so it is addressable under the target vault.
            meta["vault_id"] = target_vault_id
            meta["connector_id"] = target.id
            doc.document_metadata = meta
            doc.connector_id = target.id
            doc.search_space_id = target_search_space_id
            doc.unique_identifier_hash = generate_unique_identifier_hash(
                DocumentType.OBSIDIAN_CONNECTOR,
                _vault_path_unique_id(target_vault_id, path),
                target_search_space_id,
            )
            # Guard against two source docs sharing one path.
            target_paths.add(path)
        else:
            await session.delete(doc)

    # Push document changes before removing the connector they pointed at.
    await session.flush()
    await session.delete(source)
async def get_manifest(
    session: AsyncSession,
    *,
    connector: SearchSourceConnector,
    vault_id: str,
) -> ManifestResponse:
    """Return ``{path: {hash, mtime, size}}`` for every non-deleted note in
    this vault.

    Only documents of this connector and search space are considered. A row
    is skipped when it is tombstoned, belongs to another ``vault_id``, lacks
    ``file_path`` / ``plugin_content_hash`` / ``mtime``, or has an ``mtime``
    that cannot be parsed, so one malformed row never fails the whole
    manifest.

    Returns:
        A ``ManifestResponse`` for ``vault_id``; ``size`` is ``None`` when
        the stored size is not an integer.
    """
    result = await session.execute(
        select(Document).where(
            and_(
                Document.search_space_id == connector.search_space_id,
                Document.connector_id == connector.id,
                Document.document_type == DocumentType.OBSIDIAN_CONNECTOR,
            )
        )
    )

    items: dict[str, ManifestEntry] = {}
    for doc in result.scalars().all():
        meta = doc.document_metadata or {}
        if meta.get("deleted_at"):
            continue
        if meta.get("vault_id") != vault_id:
            continue

        path = meta.get("file_path")
        plugin_hash = meta.get("plugin_content_hash")
        mtime_raw = meta.get("mtime")
        if not path or not plugin_hash or not mtime_raw:
            continue

        try:
            mtime = datetime.fromisoformat(mtime_raw)
        except (TypeError, ValueError):
            # ValueError: unparsable string. TypeError: the stored value is
            # not a string at all (e.g. a number). Either way, skip the row.
            continue

        size_raw = meta.get("plugin_file_size")
        size = int(size_raw) if isinstance(size_raw, int) else None
        items[path] = ManifestEntry(hash=plugin_hash, mtime=mtime, size=size)

    return ManifestResponse(vault_id=vault_id, items=items)

View file

@ -536,49 +536,6 @@ async def _index_bookstack_pages(
)
@celery_app.task(name="index_obsidian_vault", bind=True)
def index_obsidian_vault_task(
    self,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    start_date: str,
    end_date: str,
):
    """Celery entry point: index an Obsidian vault on a private event loop.

    A fresh loop is created, installed as the current loop, used to run the
    async indexing coroutine to completion, and always closed afterwards.
    """
    import asyncio

    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    try:
        event_loop.run_until_complete(
            _index_obsidian_vault(
                connector_id=connector_id,
                search_space_id=search_space_id,
                user_id=user_id,
                start_date=start_date,
                end_date=end_date,
            )
        )
    finally:
        event_loop.close()
async def _index_obsidian_vault(
    connector_id: int,
    search_space_id: int,
    user_id: str,
    start_date: str,
    end_date: str,
):
    """Open a new Celery DB session and run Obsidian indexing inside it."""
    # Imported here rather than at module import time.
    from app.routes.search_source_connectors_routes import (
        run_obsidian_indexing,
    )

    session_factory = get_celery_session_maker()
    async with session_factory() as session:
        await run_obsidian_indexing(
            session, connector_id, search_space_id, user_id, start_date, end_date
        )
@celery_app.task(name="index_composio_connector", bind=True)
def index_composio_connector_task(
self,

View file

@ -0,0 +1,59 @@
"""Celery tasks for Obsidian plugin background processing."""
from __future__ import annotations
import asyncio
import logging
from app.celery_app import celery_app
from app.db import SearchSourceConnector
from app.schemas.obsidian_plugin import NotePayload
from app.services.obsidian_plugin_indexer import upsert_note
from app.tasks.celery_tasks import get_celery_session_maker
logger = logging.getLogger(__name__)
@celery_app.task(name="index_obsidian_attachment", bind=True)
def index_obsidian_attachment_task(
    self,
    connector_id: int,
    payload_data: dict,
    user_id: str,
) -> None:
    """Celery entry point: index one Obsidian non-markdown attachment.

    Runs the async worker on a freshly created event loop, which is always
    closed when the task finishes or fails.
    """
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    try:
        event_loop.run_until_complete(
            _index_obsidian_attachment(
                connector_id=connector_id,
                payload_data=payload_data,
                user_id=user_id,
            )
        )
    finally:
        event_loop.close()
async def _index_obsidian_attachment(
    *,
    connector_id: int,
    payload_data: dict,
    user_id: str,
) -> None:
    """Load the connector, validate the payload and upsert the attachment.

    If the connector no longer exists, a warning is logged and nothing is
    indexed. Payload validation and indexing errors propagate to the caller.
    """
    session_factory = get_celery_session_maker()
    async with session_factory() as session:
        connector = await session.get(SearchSourceConnector, connector_id)
        if connector is None:
            logger.warning(
                "obsidian attachment task skipped: connector %s not found", connector_id
            )
            return

        # The task receives a plain dict; rebuild the typed payload from it.
        note = NotePayload.model_validate(payload_data)
        await upsert_note(
            session,
            connector=connector,
            payload=note,
            user_id=user_id,
        )

View file

@ -14,18 +14,16 @@ from .google_calendar_indexer import index_google_calendar_events
from .google_drive_indexer import index_google_drive_files
from .google_gmail_indexer import index_google_gmail_messages
from .notion_indexer import index_notion_pages
from .obsidian_indexer import index_obsidian_vault
from .webcrawler_indexer import index_crawled_urls
# Public names re-exported by this package.
# NOTE(review): "index_crawled_urls" is listed twice below (once in sorted
# position, once as the last entry) — confirm and drop the trailing duplicate.
__all__ = [
"index_bookstack_pages",
"index_confluence_pages",
"index_crawled_urls",
"index_elasticsearch_documents",
"index_github_repos",
"index_google_calendar_events",
"index_google_drive_files",
"index_google_gmail_messages",
"index_notion_pages",
"index_obsidian_vault",
"index_crawled_urls",
]

View file

@ -1,676 +0,0 @@
"""
Obsidian connector indexer.
Indexes markdown notes from a local Obsidian vault.
This connector is only available in self-hosted mode.
Implements 2-phase document status updates for real-time UI feedback:
- Phase 1: Create all documents with 'pending' status (visible in UI immediately)
- Phase 2: Process each document: pending → processing → ready/failed
"""
import os
import re
import time
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
from pathlib import Path
import yaml
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.document_converters import (
create_document_chunks,
embed_text,
generate_content_hash,
generate_document_summary,
generate_unique_identifier_hash,
)
from .base import (
build_document_metadata_string,
check_document_by_unique_identifier,
check_duplicate_document_by_hash,
get_connector_by_id,
get_current_timestamp,
logger,
safe_set_chunks,
update_connector_last_indexed,
)
# Type of the optional heartbeat callback: an async callable that takes one
# int and returns nothing.
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Minimum number of seconds between two heartbeat callbacks.
HEARTBEAT_INTERVAL_SECONDS = 30
def parse_frontmatter(content: str) -> tuple[dict | None, str]:
"""
Parse YAML frontmatter from markdown content.
Args:
content: The full markdown content
Returns:
Tuple of (frontmatter dict or None, content without frontmatter)
"""
if not content.startswith("---"):
return None, content
# Find the closing ---
end_match = re.search(r"\n---\n", content[3:])
if not end_match:
return None, content
frontmatter_str = content[3 : end_match.start() + 3]
remaining_content = content[end_match.end() + 3 :]
try:
frontmatter = yaml.safe_load(frontmatter_str)
return frontmatter, remaining_content.strip()
except yaml.YAMLError:
return None, content
def extract_wiki_links(content: str) -> list[str]:
    """
    Extract [[wiki-style links]] from content.

    Args:
        content: Markdown content

    Returns:
        List of linked note names without duplicates, in order of first
        appearance. For ``[[link|alias]]`` only ``link`` is returned.
    """
    # Match [[link]] or [[link|alias]]
    pattern = r"\[\[([^\]|]+)(?:\|[^\]]+)?\]\]"
    matches = re.findall(pattern, content)
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result is stable across runs (a set's iteration order is not).
    return list(dict.fromkeys(matches))
def extract_tags(content: str) -> list[str]:
    """
    Extract inline #tags from content.

    A tag starts with ``#`` at the start of the text or after whitespace,
    begins with a letter, and may contain letters, digits, ``_``, ``/`` and
    ``-``. Markdown headers such as ``## Title`` are not matched.

    Args:
        content: Markdown content

    Returns:
        List of tags (without # prefix), without duplicates, in order of
        first appearance.
    """
    # Match #tag but not ## headers
    pattern = r"(?<!\S)#([a-zA-Z][a-zA-Z0-9_/-]*)"
    matches = re.findall(pattern, content)
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result is stable across runs (a set's iteration order is not).
    return list(dict.fromkeys(matches))
def scan_vault(
vault_path: str,
exclude_folders: list[str] | None = None,
) -> list[dict]:
"""
Scan an Obsidian vault for markdown files.
Args:
vault_path: Path to the Obsidian vault
exclude_folders: List of folder names to exclude
Returns:
List of file info dicts with path, name, modified time
"""
if exclude_folders is None:
exclude_folders = [".trash", ".obsidian", "templates"]
vault = Path(vault_path)
if not vault.exists():
raise ValueError(f"Vault path does not exist: {vault_path}")
files = []
for md_file in vault.rglob("*.md"):
# Check if file is in an excluded folder
relative_path = md_file.relative_to(vault)
parts = relative_path.parts
if any(excluded in parts for excluded in exclude_folders):
continue
try:
stat = md_file.stat()
files.append(
{
"path": str(md_file),
"relative_path": str(relative_path),
"name": md_file.stem,
"modified_at": datetime.fromtimestamp(stat.st_mtime, tz=UTC),
"created_at": datetime.fromtimestamp(stat.st_ctime, tz=UTC),
"size": stat.st_size,
}
)
except OSError as e:
logger.warning(f"Could not stat file {md_file}: {e}")
return files
async def index_obsidian_vault(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str | None = None,
end_date: str | None = None,
update_last_indexed: bool = True,
on_heartbeat_callback: HeartbeatCallbackType | None = None,
) -> tuple[int, str | None]:
"""
Index notes from a local Obsidian vault.
This indexer is only available in self-hosted mode as it requires
direct file system access to the user's Obsidian vault.
Args:
session: Database session
connector_id: ID of the Obsidian connector
search_space_id: ID of the search space to store documents in
user_id: ID of the user
start_date: Start date for filtering (YYYY-MM-DD format) - optional
end_date: End date for filtering (YYYY-MM-DD format) - optional
update_last_indexed: Whether to update the last_indexed_at timestamp
on_heartbeat_callback: Optional callback to update notification during long-running indexing.
Returns:
Tuple containing (number of documents indexed, error message or None)
"""
task_logger = TaskLoggingService(session, search_space_id)
# Check if self-hosted mode
if not config.is_self_hosted():
return 0, "Obsidian connector is only available in self-hosted mode"
# Log task start
log_entry = await task_logger.log_task_start(
task_name="obsidian_vault_indexing",
source="connector_indexing_task",
message=f"Starting Obsidian vault indexing for connector {connector_id}",
metadata={
"connector_id": connector_id,
"user_id": str(user_id),
"start_date": start_date,
"end_date": end_date,
},
)
try:
# Get the connector
await task_logger.log_task_progress(
log_entry,
f"Retrieving Obsidian connector {connector_id} from database",
{"stage": "connector_retrieval"},
)
connector = await get_connector_by_id(
session, connector_id, SearchSourceConnectorType.OBSIDIAN_CONNECTOR
)
if not connector:
await task_logger.log_task_failure(
log_entry,
f"Connector with ID {connector_id} not found or is not an Obsidian connector",
"Connector not found",
{"error_type": "ConnectorNotFound"},
)
return (
0,
f"Connector with ID {connector_id} not found or is not an Obsidian connector",
)
# Get vault path from connector config
vault_path = connector.config.get("vault_path")
if not vault_path:
await task_logger.log_task_failure(
log_entry,
"Vault path not configured for this connector",
"Missing vault path",
{"error_type": "MissingVaultPath"},
)
return 0, "Vault path not configured for this connector"
# Validate vault path exists
if not os.path.exists(vault_path):
await task_logger.log_task_failure(
log_entry,
f"Vault path does not exist: {vault_path}",
"Vault path not found",
{"error_type": "VaultNotFound", "vault_path": vault_path},
)
return 0, f"Vault path does not exist: {vault_path}"
# Get configuration options
exclude_folders = connector.config.get(
"exclude_folders", [".trash", ".obsidian", "templates"]
)
vault_name = connector.config.get("vault_name") or os.path.basename(vault_path)
await task_logger.log_task_progress(
log_entry,
f"Scanning Obsidian vault: {vault_name}",
{"stage": "vault_scan", "vault_path": vault_path},
)
# Scan vault for markdown files
try:
files = scan_vault(vault_path, exclude_folders)
except Exception as e:
await task_logger.log_task_failure(
log_entry,
f"Failed to scan vault: {e}",
"Vault scan error",
{"error_type": "VaultScanError"},
)
return 0, f"Failed to scan vault: {e}"
logger.info(f"Found {len(files)} markdown files in vault")
await task_logger.log_task_progress(
log_entry,
f"Found {len(files)} markdown files to process",
{"stage": "files_discovered", "file_count": len(files)},
)
# Filter by date if provided (handle "undefined" string from frontend)
# Also handle inverted dates (start > end) by skipping filtering
start_dt = None
end_dt = None
if start_date and start_date != "undefined":
start_dt = datetime.strptime(start_date, "%Y-%m-%d").replace(tzinfo=UTC)
if end_date and end_date != "undefined":
# Make end_date inclusive (end of day)
end_dt = datetime.strptime(end_date, "%Y-%m-%d").replace(tzinfo=UTC)
end_dt = end_dt.replace(hour=23, minute=59, second=59)
# Only apply date filtering if dates are valid and in correct order
if start_dt and end_dt and start_dt > end_dt:
logger.warning(
f"start_date ({start_date}) is after end_date ({end_date}), skipping date filter"
)
else:
if start_dt:
files = [f for f in files if f["modified_at"] >= start_dt]
logger.info(
f"After start_date filter ({start_date}): {len(files)} files"
)
if end_dt:
files = [f for f in files if f["modified_at"] <= end_dt]
logger.info(f"After end_date filter ({end_date}): {len(files)} files")
logger.info(f"Processing {len(files)} files after date filtering")
indexed_count = 0
skipped_count = 0
failed_count = 0
duplicate_content_count = 0
# Heartbeat tracking - update notification periodically to prevent appearing stuck
last_heartbeat_time = time.time()
# =======================================================================
# PHASE 1: Analyze all files, create pending documents
# This makes ALL documents visible in the UI immediately with pending status
# =======================================================================
files_to_process = [] # List of dicts with document and file data
new_documents_created = False
for file_info in files:
try:
file_path = file_info["path"]
relative_path = file_info["relative_path"]
# Read file content
try:
with open(file_path, encoding="utf-8") as f:
content = f.read()
except UnicodeDecodeError:
logger.warning(f"Could not decode file {file_path}, skipping")
skipped_count += 1
continue
if not content.strip():
logger.debug(f"Empty file {file_path}, skipping")
skipped_count += 1
continue
# Parse frontmatter and extract metadata
frontmatter, body_content = parse_frontmatter(content)
wiki_links = extract_wiki_links(content)
tags = extract_tags(content)
# Get title from frontmatter or filename
title = file_info["name"]
if frontmatter:
title = frontmatter.get("title", title)
# Also extract tags from frontmatter
fm_tags = frontmatter.get("tags", [])
if isinstance(fm_tags, list):
tags = list({*tags, *fm_tags})
elif isinstance(fm_tags, str):
tags = list({*tags, fm_tags})
# Generate unique identifier using vault name and relative path
unique_identifier = f"{vault_name}:{relative_path}"
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.OBSIDIAN_CONNECTOR,
unique_identifier,
search_space_id,
)
# Generate content hash
content_hash = generate_content_hash(content, search_space_id)
# Check for existing document
existing_document = await check_document_by_unique_identifier(
session, unique_identifier_hash
)
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
# Ensure status is ready (might have been stuck in processing/pending)
if not DocumentStatus.is_state(
existing_document.status, DocumentStatus.READY
):
existing_document.status = DocumentStatus.ready()
logger.debug(f"Note {title} unchanged, skipping")
skipped_count += 1
continue
# Queue existing document for update (will be set to processing in Phase 2)
files_to_process.append(
{
"document": existing_document,
"is_new": False,
"file_info": file_info,
"content": content,
"body_content": body_content,
"frontmatter": frontmatter,
"wiki_links": wiki_links,
"tags": tags,
"title": title,
"relative_path": relative_path,
"content_hash": content_hash,
"unique_identifier_hash": unique_identifier_hash,
}
)
continue
# Document doesn't exist by unique_identifier_hash
# Check if a document with the same content_hash exists (from another connector)
with session.no_autoflush:
duplicate_by_content = await check_duplicate_document_by_hash(
session, content_hash
)
if duplicate_by_content:
logger.info(
f"Obsidian note {title} already indexed by another connector "
f"(existing document ID: {duplicate_by_content.id}, "
f"type: {duplicate_by_content.document_type}). Skipping."
)
duplicate_content_count += 1
skipped_count += 1
continue
# Create new document with PENDING status (visible in UI immediately)
document = Document(
search_space_id=search_space_id,
title=title,
document_type=DocumentType.OBSIDIAN_CONNECTOR,
document_metadata={
"vault_name": vault_name,
"file_path": relative_path,
"connector_id": connector_id,
},
content="Pending...", # Placeholder until processed
content_hash=unique_identifier_hash, # Temporary unique value - updated when ready
unique_identifier_hash=unique_identifier_hash,
embedding=None,
chunks=[], # Empty at creation - safe for async
status=DocumentStatus.pending(), # Pending until processing starts
updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
)
session.add(document)
new_documents_created = True
files_to_process.append(
{
"document": document,
"is_new": True,
"file_info": file_info,
"content": content,
"body_content": body_content,
"frontmatter": frontmatter,
"wiki_links": wiki_links,
"tags": tags,
"title": title,
"relative_path": relative_path,
"content_hash": content_hash,
"unique_identifier_hash": unique_identifier_hash,
}
)
except Exception as e:
logger.exception(
f"Error in Phase 1 for file {file_info.get('path', 'unknown')}: {e}"
)
failed_count += 1
continue
# Commit all pending documents - they all appear in UI now
if new_documents_created:
logger.info(
f"Phase 1: Committing {len([f for f in files_to_process if f['is_new']])} pending documents"
)
await session.commit()
# =======================================================================
# PHASE 2: Process each document one by one
# Each document transitions: pending → processing → ready/failed
# =======================================================================
logger.info(f"Phase 2: Processing {len(files_to_process)} documents")
# Get LLM for summarization
long_context_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
for item in files_to_process:
# Send heartbeat periodically
if on_heartbeat_callback:
current_time = time.time()
if current_time - last_heartbeat_time >= HEARTBEAT_INTERVAL_SECONDS:
await on_heartbeat_callback(indexed_count)
last_heartbeat_time = current_time
document = item["document"]
try:
# Set to PROCESSING and commit - shows "processing" in UI for THIS document only
document.status = DocumentStatus.processing()
await session.commit()
# Extract data from item
title = item["title"]
relative_path = item["relative_path"]
content = item["content"]
body_content = item["body_content"]
frontmatter = item["frontmatter"]
wiki_links = item["wiki_links"]
tags = item["tags"]
content_hash = item["content_hash"]
file_info = item["file_info"]
# Build metadata
document_metadata = {
"vault_name": vault_name,
"file_path": relative_path,
"tags": tags,
"outgoing_links": wiki_links,
"frontmatter": frontmatter,
"modified_at": file_info["modified_at"].isoformat(),
"created_at": file_info["created_at"].isoformat(),
"word_count": len(body_content.split()),
}
# Build document content with metadata
metadata_sections = [
(
"METADATA",
[
f"Title: {title}",
f"Vault: {vault_name}",
f"Path: {relative_path}",
f"Tags: {', '.join(tags) if tags else 'None'}",
f"Links to: {', '.join(wiki_links) if wiki_links else 'None'}",
],
),
("CONTENT", [body_content]),
]
document_string = build_document_metadata_string(metadata_sections)
# Generate summary
summary_content = ""
if long_context_llm and connector.enable_summary:
summary_content, _ = await generate_document_summary(
document_string,
long_context_llm,
document_metadata,
)
# Generate embedding
embedding = embed_text(document_string)
# Add URL and summary to metadata
document_metadata["url"] = f"obsidian://{vault_name}/{relative_path}"
document_metadata["summary"] = summary_content
document_metadata["connector_id"] = connector_id
# Create chunks
chunks = await create_document_chunks(document_string)
# Update document to READY with actual content
document.title = title
document.content = document_string
document.content_hash = content_hash
document.embedding = embedding
document.document_metadata = document_metadata
await safe_set_chunks(session, document, chunks)
document.updated_at = get_current_timestamp()
document.status = DocumentStatus.ready()
indexed_count += 1
# Batch commit every 10 documents (for ready status updates)
if indexed_count % 10 == 0:
logger.info(
f"Committing batch: {indexed_count} Obsidian notes processed so far"
)
await session.commit()
except Exception as e:
logger.exception(
f"Error processing file {item.get('file_info', {}).get('path', 'unknown')}: {e}"
)
# Mark document as failed with reason (visible in UI)
try:
document.status = DocumentStatus.failed(str(e))
document.updated_at = get_current_timestamp()
except Exception as status_error:
logger.error(
f"Failed to update document status to failed: {status_error}"
)
failed_count += 1
continue
# CRITICAL: Always update timestamp (even if 0 documents indexed) so Zero syncs
await update_connector_last_indexed(session, connector, update_last_indexed)
# Final commit for any remaining documents not yet committed in batches
logger.info(f"Final commit: Total {indexed_count} Obsidian notes processed")
try:
await session.commit()
logger.info(
"Successfully committed all Obsidian document changes to database"
)
except Exception as e:
# Handle any remaining integrity errors gracefully (race conditions, etc.)
if (
"duplicate key value violates unique constraint" in str(e).lower()
or "uniqueviolationerror" in str(e).lower()
):
logger.warning(
f"Duplicate content_hash detected during final commit. "
f"This may occur if the same note was indexed by multiple connectors. "
f"Rolling back and continuing. Error: {e!s}"
)
await session.rollback()
# Don't fail the entire task - some documents may have been successfully indexed
else:
raise
# Build warning message if there were issues
warning_parts = []
if duplicate_content_count > 0:
warning_parts.append(f"{duplicate_content_count} duplicate")
if failed_count > 0:
warning_parts.append(f"{failed_count} failed")
warning_message = ", ".join(warning_parts) if warning_parts else None
total_processed = indexed_count
await task_logger.log_task_success(
log_entry,
f"Successfully completed Obsidian vault indexing for connector {connector_id}",
{
"notes_processed": total_processed,
"documents_indexed": indexed_count,
"documents_skipped": skipped_count,
"documents_failed": failed_count,
"duplicate_content_count": duplicate_content_count,
},
)
logger.info(
f"Obsidian vault indexing completed: {indexed_count} ready, "
f"{skipped_count} skipped, {failed_count} failed "
f"({duplicate_content_count} duplicate content)"
)
return total_processed, warning_message
except SQLAlchemyError as e:
logger.exception(f"Database error during Obsidian indexing: {e}")
await session.rollback()
await task_logger.log_task_failure(
log_entry,
f"Database error during Obsidian indexing: {e}",
"Database error",
{"error_type": "SQLAlchemyError"},
)
return 0, f"Database error: {e}"
except Exception as e:
logger.exception(f"Error during Obsidian indexing: {e}")
await task_logger.log_task_failure(
log_entry,
f"Error during Obsidian indexing: {e}",
"Unexpected error",
{"error_type": type(e).__name__},
)
return 0, str(e)

View file

@ -24,7 +24,6 @@ CONNECTOR_TASK_MAP = {
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: "index_elasticsearch_documents",
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: "index_crawled_urls",
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: "index_bookstack_pages",
SearchSourceConnectorType.OBSIDIAN_CONNECTOR: "index_obsidian_vault",
}
@ -81,7 +80,6 @@ def create_periodic_schedule(
index_elasticsearch_documents_task,
index_github_repos_task,
index_notion_pages_task,
index_obsidian_vault_task,
)
task_map = {
@ -91,7 +89,6 @@ def create_periodic_schedule(
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task,
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task,
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: index_bookstack_pages_task,
SearchSourceConnectorType.OBSIDIAN_CONNECTOR: index_obsidian_vault_task,
}
# Trigger the first run immediately

View file

@ -0,0 +1,625 @@
"""Integration tests for the Obsidian plugin HTTP wire contract.
Three concerns:
1. The /connect upsert really collapses concurrent first-time connects to
exactly one row. This locks the partial unique index from migration 129
to its purpose.
2. The fingerprint dedup path: a second device connecting with a fresh
``vault_id`` but the same ``vault_fingerprint`` adopts the existing
connector instead of creating a duplicate.
3. The end-to-end response shapes returned by /connect /sync /rename
/notes /manifest /stats match the schemas the plugin's TypeScript
decoders expect. Each renamed field is a contract change, and a smoke
pass like this is the cheapest way to catch a future drift before it
ships.
"""
from __future__ import annotations
import asyncio
import uuid
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch
import pytest
import pytest_asyncio
from sqlalchemy import func, select, text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
SearchSpace,
User,
)
from app.routes.obsidian_plugin_routes import (
obsidian_connect,
obsidian_delete_notes,
obsidian_manifest,
obsidian_rename,
obsidian_stats,
obsidian_sync,
)
from app.schemas.obsidian_plugin import (
ConnectRequest,
DeleteAck,
DeleteBatchRequest,
HeadingRef,
ManifestResponse,
NotePayload,
RenameAck,
RenameBatchRequest,
RenameItem,
StatsResponse,
SyncAck,
SyncBatchRequest,
)
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_note_payload(vault_id: str, path: str, content_hash: str) -> NotePayload:
    """Build a small markdown ``NotePayload`` for the sync tests.

    The note name is the last ``/``-separated segment of ``path`` with its
    final ``.suffix`` removed. ``mtime`` and ``ctime`` share a single
    timezone-aware UTC timestamp taken at call time. The tests in this module
    patch the indexer entry points, so the body text is only a placeholder.
    """
    timestamp = datetime.now(UTC)
    file_name = path.rsplit("/", 1)[-1]
    stem = file_name.rsplit(".", 1)[0]
    top_heading = HeadingRef(heading="Test", level=1)
    return NotePayload(
        vault_id=vault_id,
        path=path,
        name=stem,
        extension="md",
        content="# Test\n\nbody",
        headings=[top_heading],
        content_hash=content_hash,
        mtime=timestamp,
        ctime=timestamp,
    )
@pytest_asyncio.fixture
async def race_user_and_space(async_engine):
    """Yield ``(user_id, space_id)`` for a committed User and SearchSpace.

    The rows are committed through their own session on ``async_engine`` so
    that independent sessions opened by a test can read them; a
    savepoint-wrapped session fixture would keep them invisible to other
    connections.

    After the test, the rows are removed again with raw DELETE statements.
    """
    user_id = uuid.uuid4()

    async with AsyncSession(async_engine) as setup:
        owner = User(
            id=user_id,
            email=f"obsidian-race-{uuid.uuid4()}@surfsense.test",
            hashed_password="x",
            is_active=True,
            is_superuser=False,
            is_verified=True,
        )
        space = SearchSpace(name="Race Space", user_id=user_id)
        setup.add_all([owner, space])
        await setup.commit()
        # Reload the space so its generated primary key can be read.
        await setup.refresh(space)
        space_id = space.id

    yield user_id, space_id

    # Teardown order: connectors first, then the search space, then the user,
    # i.e. rows that reference the space/user go before the rows they point at.
    # NOTE(review): no explicit DELETE is issued for documents here; presumably
    # any leftovers are removed by cascading foreign keys — confirm.
    teardown_statements = (
        (
            "DELETE FROM search_source_connectors WHERE user_id = :uid",
            {"uid": user_id},
        ),
        ("DELETE FROM searchspaces WHERE id = :id", {"id": space_id}),
        ('DELETE FROM "user" WHERE id = :uid', {"uid": user_id}),
    )
    async with AsyncSession(async_engine) as cleanup:
        for sql, params in teardown_statements:
            await cleanup.execute(text(sql), params)
        await cleanup.commit()
# ---------------------------------------------------------------------------
# /connect race + index enforcement
# ---------------------------------------------------------------------------
class TestConnectRace:
    """Checks that /connect and the partial unique indexes from migration 129
    never leave more than one plugin connector row per vault identity."""

    @staticmethod
    async def _insert_plugin_connector(
        engine,
        *,
        name: str,
        vault_id: str,
        vault_name: str,
        fingerprint: str,
        user_id,
        space_id,
    ) -> None:
        """Insert and commit one plugin-sourced Obsidian connector row in a
        fresh session, without going through the route handler."""
        async with AsyncSession(engine) as session:
            session.add(
                SearchSourceConnector(
                    name=name,
                    connector_type=SearchSourceConnectorType.OBSIDIAN_CONNECTOR,
                    is_indexable=False,
                    config={
                        "vault_id": vault_id,
                        "vault_name": vault_name,
                        "source": "plugin",
                        "vault_fingerprint": fingerprint,
                    },
                    user_id=user_id,
                    search_space_id=space_id,
                )
            )
            await session.commit()

    @staticmethod
    async def _connector_count(engine, user_id) -> int:
        """Return how many connector rows belong to ``user_id``."""
        async with AsyncSession(engine) as session:
            stmt = select(func.count(SearchSourceConnector.id)).where(
                SearchSourceConnector.user_id == user_id,
            )
            return (await session.execute(stmt)).scalar_one()

    async def test_concurrent_first_connects_collapse_to_one_row(
        self, async_engine, race_user_and_space
    ):
        """Two /connect calls issued together for the same vault_id and
        fingerprint must both succeed and leave exactly one connector row."""
        user_id, space_id = race_user_and_space
        vault_id = str(uuid.uuid4())
        fingerprint = "fp-" + uuid.uuid4().hex

        async def _connect_once(name_suffix: str) -> None:
            # Each call gets its own session so the two run independently.
            async with AsyncSession(async_engine) as session:
                caller = await session.get(User, user_id)
                request = ConnectRequest(
                    vault_id=vault_id,
                    vault_name=f"My Vault {name_suffix}",
                    search_space_id=space_id,
                    vault_fingerprint=fingerprint,
                )
                await obsidian_connect(request, user=caller, session=session)

        outcomes = await asyncio.gather(
            _connect_once("a"), _connect_once("b"), return_exceptions=True
        )
        for outcome in outcomes:
            assert not isinstance(outcome, Exception), f"Connect raised: {outcome!r}"

        assert await self._connector_count(async_engine, user_id) == 1

    async def test_partial_unique_index_blocks_raw_duplicate(
        self, async_engine, race_user_and_space
    ):
        """A second raw INSERT with the same vault_id (different fingerprint)
        must fail with ``IntegrityError`` even though the route is bypassed."""
        user_id, space_id = race_user_and_space
        vault_id = str(uuid.uuid4())

        await self._insert_plugin_connector(
            async_engine,
            name="Obsidian - First",
            vault_id=vault_id,
            vault_name="First",
            fingerprint="fp-1",
            user_id=user_id,
            space_id=space_id,
        )

        with pytest.raises(IntegrityError):
            await self._insert_plugin_connector(
                async_engine,
                name="Obsidian - Second",
                vault_id=vault_id,
                vault_name="Second",
                fingerprint="fp-2",
                user_id=user_id,
                space_id=space_id,
            )

    async def test_fingerprint_blocks_raw_cross_device_duplicate(
        self, async_engine, race_user_and_space
    ):
        """Two rows for one user with distinct vault_ids but an identical
        fingerprint cannot both be committed."""
        user_id, space_id = race_user_and_space
        fingerprint = "fp-" + uuid.uuid4().hex

        await self._insert_plugin_connector(
            async_engine,
            name="Obsidian - Desktop",
            vault_id=str(uuid.uuid4()),
            vault_name="Vault",
            fingerprint=fingerprint,
            user_id=user_id,
            space_id=space_id,
        )

        with pytest.raises(IntegrityError):
            await self._insert_plugin_connector(
                async_engine,
                name="Obsidian - Mobile",
                vault_id=str(uuid.uuid4()),
                vault_name="Vault",
                fingerprint=fingerprint,
                user_id=user_id,
                space_id=space_id,
            )

    async def test_second_device_adopts_existing_connector_via_fingerprint(
        self, async_engine, race_user_and_space
    ):
        """A second /connect with a new vault_id but the first device's
        fingerprint must answer with the first device's vault_id and
        connector_id, and must not add a row."""
        user_id, space_id = race_user_and_space
        first_vault_id = str(uuid.uuid4())
        second_vault_id = str(uuid.uuid4())
        fingerprint = "fp-" + uuid.uuid4().hex

        async def _connect(vault_id: str):
            async with AsyncSession(async_engine) as session:
                caller = await session.get(User, user_id)
                return await obsidian_connect(
                    ConnectRequest(
                        vault_id=vault_id,
                        vault_name="Shared Vault",
                        search_space_id=space_id,
                        vault_fingerprint=fingerprint,
                    ),
                    user=caller,
                    session=session,
                )

        first_resp = await _connect(first_vault_id)
        second_resp = await _connect(second_vault_id)

        assert second_resp.vault_id == first_vault_id
        assert second_resp.connector_id == first_resp.connector_id
        assert await self._connector_count(async_engine, user_id) == 1
# ---------------------------------------------------------------------------
# Combined wire-shape smoke test
# ---------------------------------------------------------------------------
class TestWireContractSmoke:
    """Calls the plugin route handlers directly and asserts on the names and
    values of the response fields, so an accidental field rename is caught
    here rather than by the plugin's decoder. The indexer entry points are
    patched out; only the wire shapes are under test."""

    @staticmethod
    async def _connect_vault(session, user, space, vault_name: str):
        """Connect a brand-new vault via /connect.

        Returns ``(vault_id, connect_response)``.
        """
        vault_id = str(uuid.uuid4())
        response = await obsidian_connect(
            ConnectRequest(
                vault_id=vault_id,
                vault_name=vault_name,
                search_space_id=space.id,
                vault_fingerprint="fp-" + uuid.uuid4().hex,
            ),
            user=user,
            session=session,
        )
        return vault_id, response

    @staticmethod
    def _binary_note(
        vault_id: str,
        path: str,
        content_hash: str,
        *,
        extension: str,
        mime_type: str,
    ) -> NotePayload:
        """Start from a valid markdown payload and then overwrite its fields
        so it describes a binary attachment."""
        note = _make_note_payload(vault_id, path, content_hash)
        note.extension = extension
        note.is_binary = True
        note.binary_base64 = "aGVsbG8="
        note.mime_type = mime_type
        note.content = ""
        return note

    async def test_full_flow_returns_typed_payloads(
        self, db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
    ):
        """Walk /connect -> /sync -> /rename -> /notes -> /manifest -> /stats
        in order and check each response."""
        caller = {"user": db_user, "session": db_session}

        # 1. /connect
        vault_id, connect_resp = await self._connect_vault(
            db_session, db_user, db_search_space, "Smoke Vault"
        )
        assert connect_resp.connector_id > 0
        assert connect_resp.vault_id == vault_id
        assert "sync" in connect_resp.capabilities
        assert connect_resp.server_time_utc is not None

        def _two_note_batch(ok_hash: str, fail_hash: str) -> SyncBatchRequest:
            return SyncBatchRequest(
                vault_id=vault_id,
                notes=[
                    _make_note_payload(vault_id, "ok.md", ok_hash),
                    _make_note_payload(vault_id, "fail.md", fail_hash),
                ],
            )

        # 2. /sync with the indexer stubbed, so no LLM / embedding work runs.
        stub_doc = type("FakeDoc", (), {"id": 12345})()
        with patch(
            "app.routes.obsidian_plugin_routes.upsert_note",
            new=AsyncMock(return_value=stub_doc),
        ):
            sync_resp = await obsidian_sync(
                _two_note_batch("hash-ok", "hash-fail"), **caller
            )
        assert isinstance(sync_resp, SyncAck)
        assert sync_resp.vault_id == vault_id
        assert sync_resp.indexed == 2
        assert sync_resp.failed == 0
        assert len(sync_resp.items) == 2
        assert all(it.status == "ok" for it in sync_resp.items)
        # Both ``status`` and ``path`` must exist under exactly these names.
        assert {it.path for it in sync_resp.items} == {"ok.md", "fail.md"}

        # 2b. /sync again, with the stub raising for one note so the
        # per-item error reporting is exercised.
        async def _selective_upsert(session, *, connector, payload, user_id):
            if payload.path == "fail.md":
                raise RuntimeError("simulated indexing failure")
            return stub_doc

        with patch(
            "app.routes.obsidian_plugin_routes.upsert_note",
            new=AsyncMock(side_effect=_selective_upsert),
        ):
            sync_resp = await obsidian_sync(_two_note_batch("h1", "h2"), **caller)
        assert sync_resp.indexed == 1
        assert sync_resp.failed == 1
        status_by_path = {it.path: it.status for it in sync_resp.items}
        assert status_by_path == {"ok.md": "ok", "fail.md": "error"}

        # 3. /rename with rename_note stubbed: one hit, one miss.
        async def _rename(*args, **kwargs) -> object:
            if kwargs.get("old_path") == "missing.md":
                return None
            return stub_doc

        with patch(
            "app.routes.obsidian_plugin_routes.rename_note",
            new=AsyncMock(side_effect=_rename),
        ):
            rename_resp = await obsidian_rename(
                RenameBatchRequest(
                    vault_id=vault_id,
                    renames=[
                        RenameItem(old_path="a.md", new_path="b.md"),
                        RenameItem(old_path="missing.md", new_path="x.md"),
                    ],
                ),
                **caller,
            )
        assert isinstance(rename_resp, RenameAck)
        assert rename_resp.renamed == 1
        assert rename_resp.missing == 1
        assert {it.status for it in rename_resp.items} == {"ok", "missing"}
        # The response items expose snake_case ``old_path`` / ``new_path``.
        assert all(it.old_path and it.new_path for it in rename_resp.items)

        # 4. DELETE /notes with delete_note stubbed: one hit, one miss.
        async def _delete(*args, **kwargs) -> bool:
            return kwargs.get("path") != "ghost.md"

        with patch(
            "app.routes.obsidian_plugin_routes.delete_note",
            new=AsyncMock(side_effect=_delete),
        ):
            delete_resp = await obsidian_delete_notes(
                DeleteBatchRequest(vault_id=vault_id, paths=["b.md", "ghost.md"]),
                **caller,
            )
        assert isinstance(delete_resp, DeleteAck)
        assert delete_resp.deleted == 1
        assert delete_resp.missing == 1
        assert {it.path: it.status for it in delete_resp.items} == {
            "b.md": "ok",
            "ghost.md": "missing",
        }

        # 5. /manifest is expected to be empty because upsert_note was
        # stubbed; only the response shape matters.
        manifest_resp = await obsidian_manifest(vault_id=vault_id, **caller)
        assert isinstance(manifest_resp, ManifestResponse)
        assert manifest_resp.vault_id == vault_id
        assert manifest_resp.items == {}

        # 6. /stats reports zero files for the same reason.
        stats_resp = await obsidian_stats(vault_id=vault_id, **caller)
        assert isinstance(stats_resp, StatsResponse)
        assert stats_resp.vault_id == vault_id
        assert stats_resp.files_synced == 0
        assert stats_resp.last_sync_at is None

    async def test_sync_queues_binary_attachments(
        self, db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
    ):
        """A supported binary attachment is reported as ``queued``, counts as
        indexed, goes to the queue helper once, and skips ``upsert_note``."""
        vault_id, _ = await self._connect_vault(
            db_session, db_user, db_search_space, "Queue Vault"
        )
        stub_doc = type("FakeDoc", (), {"id": 12345})()
        attachment = self._binary_note(
            vault_id,
            "image.png",
            "hash-bin",
            extension="png",
            mime_type="image/png",
        )

        with (
            patch(
                "app.routes.obsidian_plugin_routes.upsert_note",
                new=AsyncMock(return_value=stub_doc),
            ) as upsert_mock,
            patch(
                "app.routes.obsidian_plugin_routes._queue_obsidian_attachment"
            ) as queue_mock,
        ):
            sync_resp = await obsidian_sync(
                SyncBatchRequest(
                    vault_id=vault_id,
                    notes=[
                        _make_note_payload(vault_id, "ok.md", "hash-ok"),
                        attachment,
                    ],
                ),
                user=db_user,
                session=db_session,
            )

        assert sync_resp.indexed == 2
        assert sync_resp.failed == 0
        status_by_path = {it.path: it.status for it in sync_resp.items}
        assert status_by_path == {"ok.md": "ok", "image.png": "queued"}
        assert upsert_mock.await_count == 1
        queue_mock.assert_called_once()

    async def test_sync_rejects_unsupported_attachment_extension(
        self, db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
    ):
        """A ``.heic`` attachment is reported as an item-level error and is
        never handed to the queue helper; the markdown note still succeeds."""
        vault_id, _ = await self._connect_vault(
            db_session, db_user, db_search_space, "Reject Vault"
        )
        stub_doc = type("FakeDoc", (), {"id": 12345})()
        rejected = self._binary_note(
            vault_id,
            "photo.heic",
            "hash-heic",
            extension="heic",
            mime_type="image/heic",
        )

        with (
            patch(
                "app.routes.obsidian_plugin_routes.upsert_note",
                new=AsyncMock(return_value=stub_doc),
            ),
            patch(
                "app.routes.obsidian_plugin_routes._queue_obsidian_attachment"
            ) as queue_mock,
        ):
            sync_resp = await obsidian_sync(
                SyncBatchRequest(
                    vault_id=vault_id,
                    notes=[
                        _make_note_payload(vault_id, "ok.md", "hash-ok"),
                        rejected,
                    ],
                ),
                user=db_user,
                session=db_session,
            )

        assert sync_resp.indexed == 1
        assert sync_resp.failed == 1
        item_for = {it.path: it for it in sync_resp.items}
        assert item_for["ok.md"].status == "ok"
        assert item_for["photo.heic"].status == "error"
        assert "unsupported attachment extension" in (
            item_for["photo.heic"].error or ""
        )
        queue_mock.assert_not_called()

    async def test_sync_rejects_mime_extension_mismatch(
        self, db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
    ):
        """A ``.png`` attachment declared as ``application/pdf`` is reported
        as an item-level error and is never handed to the queue helper."""
        vault_id, _ = await self._connect_vault(
            db_session, db_user, db_search_space, "Mismatch Vault"
        )
        stub_doc = type("FakeDoc", (), {"id": 12345})()
        mismatched = self._binary_note(
            vault_id,
            "image.png",
            "hash-png",
            extension="png",
            mime_type="application/pdf",
        )

        with (
            patch(
                "app.routes.obsidian_plugin_routes.upsert_note",
                new=AsyncMock(return_value=stub_doc),
            ),
            patch(
                "app.routes.obsidian_plugin_routes._queue_obsidian_attachment"
            ) as queue_mock,
        ):
            sync_resp = await obsidian_sync(
                SyncBatchRequest(
                    vault_id=vault_id,
                    notes=[
                        _make_note_payload(vault_id, "ok.md", "hash-ok"),
                        mismatched,
                    ],
                ),
                user=db_user,
                session=db_session,
            )

        assert sync_resp.indexed == 1
        assert sync_resp.failed == 1
        item_for = {it.path: it for it in sync_resp.items}
        assert item_for["ok.md"].status == "ok"
        assert item_for["image.png"].status == "error"
        assert "does not match extension" in (item_for["image.png"].error or "")
        queue_mock.assert_not_called()

View file

@ -79,7 +79,7 @@ async def test_file_write_null_filename_uses_semantic_default_path():
@pytest.mark.asyncio
async def test_file_write_null_filename_infers_json_extension():
async def test_file_write_null_filename_defaults_to_markdown_path():
llm = _FakeLLM(
'{"intent":"file_write","confidence":0.71,"suggested_filename":null}'
)
@ -94,7 +94,7 @@ async def test_file_write_null_filename_infers_json_extension():
assert result is not None
contract = result["file_operation_contract"]
assert contract["intent"] == FileOperationIntent.FILE_WRITE.value
assert contract["suggested_path"] == "/notes.json"
assert contract["suggested_path"] == "/notes.md"
@pytest.mark.asyncio

View file

@ -30,6 +30,7 @@ def test_backend_resolver_returns_multi_root_backend_for_single_root(tmp_path: P
backend = resolver(_RuntimeStub())
assert isinstance(backend, MultiRootLocalFolderBackend)
assert backend.list_mounts() == ("tmp",)
def test_backend_resolver_uses_cloud_mode_by_default():
@ -57,3 +58,4 @@ def test_backend_resolver_returns_multi_root_backend_for_multiple_roots(tmp_path
backend = resolver(_RuntimeStub())
assert isinstance(backend, MultiRootLocalFolderBackend)
assert backend.list_mounts() == ("resume", "notes")

View file

@ -34,6 +34,11 @@ class _RuntimeNoSuggestedPath:
state = {"file_operation_contract": {}}
class _RuntimeWithSuggestedPath:
def __init__(self, suggested_path: str) -> None:
self.state = {"file_operation_contract": {"suggested_path": suggested_path}}
def test_verify_written_content_prefers_raw_sync() -> None:
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
expected = "line1\nline2"
@ -162,3 +167,47 @@ def test_normalize_local_mount_path_prefixes_posix_absolute_path_for_linux_and_m
resolved = middleware._normalize_local_mount_path("/var/log/app.log", runtime) # type: ignore[arg-type]
assert resolved == "/pc_backups/var/log/app.log"
def test_normalize_local_mount_path_prefers_unique_existing_parent_mount(
    tmp_path: Path,
) -> None:
    """When only one mount already contains the target's parent directory,
    the path is expected to be prefixed with that mount's id."""
    first_root = tmp_path / "RootA"
    second_root = tmp_path / "RootB"
    # Only the second root has a ``nested/deep`` directory.
    for directory in (first_root / "other", second_root / "nested" / "deep"):
        directory.mkdir(parents=True)

    mounts = (("root_a", str(first_root)), ("root_b", str(second_root)))
    backend = MultiRootLocalFolderBackend(mounts)
    runtime = _RuntimeNoSuggestedPath()

    # Bypass __init__ and hand the middleware the backend built above.
    middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
    middleware._get_backend = lambda _runtime: backend  # type: ignore[method-assign]

    resolved = middleware._normalize_local_mount_path(  # type: ignore[arg-type]
        "/nested/deep/new-note.md",
        runtime,
    )

    assert resolved == "/root_b/nested/deep/new-note.md"
def test_normalize_local_mount_path_uses_suggested_mount_when_ambiguous(
    tmp_path: Path,
) -> None:
    """With two empty roots, the mount named in the runtime's suggested path
    is expected to be used as the prefix."""
    first_root = tmp_path / "RootA"
    second_root = tmp_path / "RootB"
    for directory in (first_root, second_root):
        directory.mkdir(parents=True)

    mounts = (("root_a", str(first_root)), ("root_b", str(second_root)))
    backend = MultiRootLocalFolderBackend(mounts)
    runtime = _RuntimeWithSuggestedPath("/root_b/notes/context.md")

    # Bypass __init__ and hand the middleware the backend built above.
    middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
    middleware._get_backend = lambda _runtime: backend  # type: ignore[method-assign]

    resolved = middleware._normalize_local_mount_path(  # type: ignore[arg-type]
        "/brand-new-note.md",
        runtime,
    )

    assert resolved == "/root_b/brand-new-note.md"

View file

@ -9,6 +9,7 @@ pytestmark = pytest.mark.unit
def test_local_backend_write_read_edit_roundtrip(tmp_path: Path):
backend = LocalFolderBackend(str(tmp_path))
(tmp_path / "notes").mkdir()
write = backend.write("/notes/test.md", "line1\nline2")
assert write.error is None
@ -51,9 +52,20 @@ def test_local_backend_glob_and_grep(tmp_path: Path):
def test_local_backend_read_raw_returns_exact_content(tmp_path: Path):
    """``read_raw`` returns the written text unchanged, trailing newline
    included."""
    notes_dir = tmp_path / "notes"
    backend = LocalFolderBackend(str(tmp_path))
    # The parent directory is created up front, before the write.
    notes_dir.mkdir()
    expected = "# Title\n\nline 1\nline 2\n"

    write_result = backend.write("/notes/raw.md", expected)
    assert write_result.error is None

    assert backend.read_raw("/notes/raw.md") == expected
def test_local_backend_write_rejects_missing_parent_directory(tmp_path: Path):
    """Writing under a directory that does not exist reports an error that
    mentions the parent directory and does not create that directory."""
    missing_dir = tmp_path / "tempoo"
    backend = LocalFolderBackend(str(tmp_path))

    outcome = backend.write("/tempoo/new-note.md", "# New note")

    assert outcome.error is not None
    assert "parent directory" in outcome.error
    assert not missing_dir.exists()

View file

@ -26,3 +26,12 @@ def test_mount_ids_preserve_client_mapping_order(tmp_path: Path) -> None:
)
assert backend.list_mounts() == ("pc_backups", "pc_backups_2", "notes_2026")
def test_mount_id_is_authoritative_not_folder_name(tmp_path: Path) -> None:
    """``list_mounts`` reports the caller-supplied mount id, not the name of
    the folder on disk."""
    folder = tmp_path / "Resume Folder"
    folder.mkdir()
    mapping = (("custom_resume_mount", str(folder)),)

    backend = MultiRootLocalFolderBackend(mapping)

    assert backend.list_mounts() == ("custom_resume_mount",)

View file

@ -0,0 +1,225 @@
from __future__ import annotations
import base64
from datetime import UTC, datetime
import pytest
from pydantic import ValidationError
from app.etl_pipeline.etl_document import EtlResult
from app.schemas.obsidian_plugin import HeadingRef, NotePayload
from app.services.obsidian_plugin_indexer import (
_build_metadata,
_extract_binary_attachment_markdown,
_is_image_attachment,
_require_extracted_attachment_content,
)
_FAKE_PNG_B64 = base64.b64encode(b"\x89PNG\r\n\x1a\n").decode("ascii")
def test_build_metadata_serializes_headings_to_plain_json() -> None:
    """``HeadingRef`` entries on the payload appear in the built metadata as
    plain ``{"heading": ..., "level": ...}`` dicts."""
    stamp = datetime.now(UTC)
    note = NotePayload(
        vault_id="vault-1",
        path="notes.md",
        name="notes",
        extension="md",
        content="# Notes",
        headings=[HeadingRef(heading="Notes", level=1)],
        content_hash="abc123",
        mtime=stamp,
        ctime=stamp,
    )
    expected_headings = [{"heading": "Notes", "level": 1}]

    built = _build_metadata(note, vault_name="My Vault", connector_id=42)

    assert built["headings"] == expected_headings
def test_build_metadata_marks_binary_attachment_fields() -> None:
    """For a binary payload the built metadata carries ``is_binary`` and the
    payload's ``mime_type``."""
    stamp = datetime.now(UTC)
    attachment = NotePayload(
        vault_id="vault-1",
        path="assets/diagram.png",
        name="diagram",
        extension="png",
        content="",
        content_hash="abc123",
        mtime=stamp,
        ctime=stamp,
        is_binary=True,
        binary_base64=_FAKE_PNG_B64,
        mime_type="image/png",
    )

    built = _build_metadata(attachment, vault_name="My Vault", connector_id=42)

    assert built["is_binary"] is True
    assert built["mime_type"] == "image/png"
@pytest.mark.asyncio
async def test_extract_binary_attachment_markdown_handles_invalid_base64() -> None:
    """An undecodable ``binary_base64`` value yields empty content and the
    ``invalid_binary_payload`` extraction status instead of raising."""
    stamp = datetime.now(UTC)
    broken = NotePayload(
        vault_id="vault-1",
        path="assets/diagram.png",
        name="diagram",
        extension="png",
        content="",
        content_hash="abc123",
        mtime=stamp,
        ctime=stamp,
        is_binary=True,
        binary_base64="not-valid-base64!!",
        mime_type="image/png",
    )

    extracted = await _extract_binary_attachment_markdown(broken, vision_llm=None)
    content, metadata = extracted

    assert content == ""
    assert metadata["attachment_extraction_status"] == "invalid_binary_payload"
@pytest.mark.asyncio
async def test_extract_binary_attachment_markdown_uses_etl(monkeypatch) -> None:
    """A decodable attachment is passed to ``_run_etl_extract`` and the ETL
    result's markdown and service name are returned to the caller."""
    stamp = datetime.now(UTC)
    encoded_pdf = base64.b64encode(b"%PDF-1.7 fake bytes").decode("ascii")
    attachment = NotePayload(
        vault_id="vault-1",
        path="assets/spec.pdf",
        name="spec",
        extension="pdf",
        content="",
        content_hash="abc123",
        mtime=stamp,
        ctime=stamp,
        is_binary=True,
        binary_base64=encoded_pdf,
        mime_type="application/pdf",
    )

    async def _fake_run_etl_extract(  # noqa: ANN001
        *, file_path, filename, vision_llm
    ):
        # The stub also checks the arguments the extractor passes in.
        assert filename == "spec.pdf"
        assert file_path
        assert vision_llm is None
        return EtlResult(
            markdown_content="Extracted content",
            etl_service="TEST_ETL",
            content_type="document",
        )

    monkeypatch.setattr(
        "app.services.obsidian_plugin_indexer._run_etl_extract",
        _fake_run_etl_extract,
    )

    extracted = await _extract_binary_attachment_markdown(
        attachment, vision_llm=None
    )
    content, metadata = extracted

    assert content == "Extracted content"
    assert metadata["attachment_extraction_status"] == "ok"
    assert metadata["attachment_etl_service"] == "TEST_ETL"
def test_is_image_attachment_detects_image_extensions() -> None:
    """An upper-case ``PNG`` attachment counts as an image; a ``pdf``
    attachment does not."""
    stamp = datetime.now(UTC)
    # Fields that are the same for both payloads.
    shared = {
        "vault_id": "vault-1",
        "content": "",
        "content_hash": "abc123",
        "mtime": stamp,
        "ctime": stamp,
        "is_binary": True,
        "binary_base64": _FAKE_PNG_B64,
    }
    screenshot = NotePayload(
        path="assets/screenshot.PNG",
        name="screenshot",
        extension="PNG",
        mime_type="image/png",
        **shared,
    )
    spec = NotePayload(
        path="assets/spec.pdf",
        name="spec",
        extension="pdf",
        mime_type="application/pdf",
        **shared,
    )

    assert _is_image_attachment(screenshot) is True
    assert _is_image_attachment(spec) is False
def test_note_payload_rejects_binary_without_base64() -> None:
    """``is_binary=True`` without ``binary_base64`` fails schema validation."""
    stamp = datetime.now(UTC)
    fields = {
        "vault_id": "vault-1",
        "path": "assets/diagram.png",
        "name": "diagram",
        "extension": "png",
        "content": "",
        "content_hash": "abc123",
        "mtime": stamp,
        "ctime": stamp,
        "is_binary": True,
        "mime_type": "image/png",
    }

    with pytest.raises(ValidationError, match="binary_base64 is required"):
        NotePayload(**fields)
def test_note_payload_rejects_binary_without_mime_type() -> None:
    """``is_binary=True`` without ``mime_type`` fails schema validation."""
    stamp = datetime.now(UTC)
    fields = {
        "vault_id": "vault-1",
        "path": "assets/diagram.png",
        "name": "diagram",
        "extension": "png",
        "content": "",
        "content_hash": "abc123",
        "mtime": stamp,
        "ctime": stamp,
        "is_binary": True,
        "binary_base64": _FAKE_PNG_B64,
    }

    with pytest.raises(ValidationError, match="mime_type is required"):
        NotePayload(**fields)
def test_note_payload_rejects_markdown_with_binary_fields() -> None:
    """Supplying ``binary_base64`` on a payload that does not set
    ``is_binary`` fails schema validation."""
    stamp = datetime.now(UTC)
    fields = {
        "vault_id": "vault-1",
        "path": "notes.md",
        "name": "notes",
        "extension": "md",
        "content": "# Notes",
        "content_hash": "abc123",
        "mtime": stamp,
        "ctime": stamp,
        "binary_base64": _FAKE_PNG_B64,
    }
    expected_message = (
        "binary_base64 and mime_type must be omitted when is_binary is False"
    )

    with pytest.raises(ValidationError, match=expected_message):
        NotePayload(**fields)
def test_require_extracted_attachment_content_rejects_empty_content() -> None:
    """Whitespace-only extracted content raises ``RuntimeError`` naming the
    attachment path."""
    expected_message = "Attachment extraction failed for assets/img.png"
    failed_meta = {"attachment_extraction_status": "etl_failed"}

    with pytest.raises(RuntimeError, match=expected_message):
        _require_extracted_attachment_content(
            content="   ",
            etl_meta=failed_meta,
            path="assets/img.png",
        )