From 87150a6d7f7a082627c90f16e9e1531e659b8dde Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Mon, 20 Apr 2026 19:24:36 +0530 Subject: [PATCH] feat: enhance Obsidian plugin routes with for_update parameter for improved concurrency handling --- .../app/routes/obsidian_plugin_routes.py | 72 ++++++++++--------- surfsense_obsidian/src/api-client.ts | 5 +- .../components/obsidian-config.tsx | 11 ++- 3 files changed, 46 insertions(+), 42 deletions(-) diff --git a/surfsense_backend/app/routes/obsidian_plugin_routes.py b/surfsense_backend/app/routes/obsidian_plugin_routes.py index 0d2ce703d..4315e0d33 100644 --- a/surfsense_backend/app/routes/obsidian_plugin_routes.py +++ b/surfsense_backend/app/routes/obsidian_plugin_routes.py @@ -96,22 +96,29 @@ async def _resolve_vault_connector( *, user: User, vault_id: str, + for_update: bool = False, ) -> SearchSourceConnector: - """Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user.""" - result = await session.execute( - select(SearchSourceConnector).where( - and_( - SearchSourceConnector.user_id == user.id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.OBSIDIAN_CONNECTOR, - ) + """Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user. + + Callers that mutate ``connector.config`` MUST pass ``for_update=True`` or + concurrent heartbeats will race and lose writes on ``config.devices`` / + ``config.files_synced``. + """ + stmt = select(SearchSourceConnector).where( + and_( + SearchSourceConnector.user_id == user.id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.OBSIDIAN_CONNECTOR, + SearchSourceConnector.config["vault_id"].astext == vault_id, + SearchSourceConnector.config["source"].astext == "plugin", ) ) - candidates = result.scalars().all() - for connector in candidates: - cfg = connector.config or {} - if cfg.get("vault_id") == vault_id and cfg.get("source") == "plugin": - return connector + if for_update: + stmt = stmt.with_for_update() + + connector = (await session.execute(stmt)).scalars().first() + if connector is not None: + return connector raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -182,21 +189,26 @@ async def obsidian_connect( session, user=user, search_space_id=payload.search_space_id ) - result = await session.execute( - select(SearchSourceConnector).where( - and_( - SearchSourceConnector.user_id == user.id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.OBSIDIAN_CONNECTOR, + # FOR UPDATE so concurrent heartbeats can't clobber each other's device entry. + existing: SearchSourceConnector | None = ( + ( + await session.execute( + select(SearchSourceConnector) + .where( + and_( + SearchSourceConnector.user_id == user.id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.OBSIDIAN_CONNECTOR, + SearchSourceConnector.config["vault_id"].astext + == payload.vault_id, + ) + ) + .with_for_update() ) ) + .scalars() + .first() ) - existing: SearchSourceConnector | None = None - for candidate in result.scalars().all(): - cfg = candidate.config or {} - if cfg.get("vault_id") == payload.vault_id: - existing = candidate - break now_iso = datetime.now(UTC).isoformat() @@ -210,12 +222,9 @@ async def obsidian_connect( "source": "plugin", "plugin_version": payload.plugin_version, "devices": devices, - "device_count": len(devices), "last_connect_at": now_iso, } ) - cfg.pop("legacy", None) - cfg.pop("vault_path", None) existing.config = cfg # Re-stamp on every connect so vault renames in Obsidian propagate; # the web UI hides the Name input for Obsidian connectors. @@ -237,7 +246,6 @@ async def obsidian_connect( "source": "plugin", "plugin_version": payload.plugin_version, "devices": devices, - "device_count": len(devices), "files_synced": 0, "last_connect_at": now_iso, }, @@ -264,7 +272,7 @@ async def obsidian_sync( ) -> dict[str, object]: """Batch-upsert notes; returns per-note ack so the plugin can dequeue/retry.""" connector = await _resolve_vault_connector( - session, user=user, vault_id=payload.vault_id + session, user=user, vault_id=payload.vault_id, for_update=True ) results: list[dict[str, object]] = [] @@ -315,7 +323,7 @@ async def obsidian_rename( ) -> dict[str, object]: """Apply a batch of vault rename events.""" connector = await _resolve_vault_connector( - session, user=user, vault_id=payload.vault_id + session, user=user, vault_id=payload.vault_id, for_update=True ) results: list[dict[str, object]] = [] @@ -382,7 +390,7 @@ async def obsidian_delete_notes( ) -> dict[str, object]: """Soft-delete a batch of notes by vault-relative path.""" connector = await _resolve_vault_connector( - session, user=user, vault_id=payload.vault_id + session, user=user, vault_id=payload.vault_id, for_update=True ) deleted = 0 diff --git a/surfsense_obsidian/src/api-client.ts b/surfsense_obsidian/src/api-client.ts index 4b5ae0e33..fcc5c3c49 100644 --- a/surfsense_obsidian/src/api-client.ts +++ b/surfsense_obsidian/src/api-client.ts @@ -108,12 +108,11 @@ export class SurfSenseApiClient { }): Promise { return await this.request( "POST", - `/api/v1/obsidian/connect?search_space_id=${encodeURIComponent( - String(input.searchSpaceId) - )}`, + "/api/v1/obsidian/connect", { vault_id: input.vaultId, vault_name: input.vaultName, + search_space_id: input.searchSpaceId, plugin_version: this.opts.pluginVersion, device_id: input.deviceId, } diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx index feca9c35e..a828d017a 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/obsidian-config.tsx @@ -100,14 +100,11 @@ const LegacyBanner: FC = () => { const PluginStats: FC<{ config: Record }> = ({ config }) => { const stats: { label: string; value: string }[] = useMemo(() => { const filesSynced = config.files_synced; - // Prefer the stamped count; fall back to len(devices) for rows the - // backend hasn't re-stamped yet. + // Derive from config.devices — a stored counter could drift under concurrent heartbeats. const deviceCount = - typeof config.device_count === "number" - ? config.device_count - : config.devices && typeof config.devices === "object" - ? Object.keys(config.devices as Record).length - : null; + config.devices && typeof config.devices === "object" + ? Object.keys(config.devices as Record).length + : null; return [ { label: "Vault", value: (config.vault_name as string) || "—" }, {