feat: enhance Obsidian plugin routes with for_update parameter for improved concurrency handling

This commit is contained in:
Anish Sarkar 2026-04-20 19:24:36 +05:30
parent b5c9388c8a
commit 87150a6d7f
3 changed files with 46 additions and 42 deletions

View file

@ -96,22 +96,29 @@ async def _resolve_vault_connector(
*, *,
user: User, user: User,
vault_id: str, vault_id: str,
for_update: bool = False,
) -> SearchSourceConnector: ) -> SearchSourceConnector:
"""Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user.""" """Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user.
result = await session.execute(
select(SearchSourceConnector).where( Callers that mutate ``connector.config`` MUST pass ``for_update=True`` or
and_( concurrent heartbeats will race and lose writes on ``config.devices`` /
SearchSourceConnector.user_id == user.id, ``config.files_synced``.
SearchSourceConnector.connector_type """
== SearchSourceConnectorType.OBSIDIAN_CONNECTOR, stmt = select(SearchSourceConnector).where(
) and_(
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.OBSIDIAN_CONNECTOR,
SearchSourceConnector.config["vault_id"].astext == vault_id,
SearchSourceConnector.config["source"].astext == "plugin",
) )
) )
candidates = result.scalars().all() if for_update:
for connector in candidates: stmt = stmt.with_for_update()
cfg = connector.config or {}
if cfg.get("vault_id") == vault_id and cfg.get("source") == "plugin": connector = (await session.execute(stmt)).scalars().first()
return connector if connector is not None:
return connector
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -182,21 +189,26 @@ async def obsidian_connect(
session, user=user, search_space_id=payload.search_space_id session, user=user, search_space_id=payload.search_space_id
) )
result = await session.execute( # FOR UPDATE so concurrent heartbeats can't clobber each other's device entry.
select(SearchSourceConnector).where( existing: SearchSourceConnector | None = (
and_( (
SearchSourceConnector.user_id == user.id, await session.execute(
SearchSourceConnector.connector_type select(SearchSourceConnector)
== SearchSourceConnectorType.OBSIDIAN_CONNECTOR, .where(
and_(
SearchSourceConnector.user_id == user.id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.OBSIDIAN_CONNECTOR,
SearchSourceConnector.config["vault_id"].astext
== payload.vault_id,
)
)
.with_for_update()
) )
) )
.scalars()
.first()
) )
existing: SearchSourceConnector | None = None
for candidate in result.scalars().all():
cfg = candidate.config or {}
if cfg.get("vault_id") == payload.vault_id:
existing = candidate
break
now_iso = datetime.now(UTC).isoformat() now_iso = datetime.now(UTC).isoformat()
@ -210,12 +222,9 @@ async def obsidian_connect(
"source": "plugin", "source": "plugin",
"plugin_version": payload.plugin_version, "plugin_version": payload.plugin_version,
"devices": devices, "devices": devices,
"device_count": len(devices),
"last_connect_at": now_iso, "last_connect_at": now_iso,
} }
) )
cfg.pop("legacy", None)
cfg.pop("vault_path", None)
existing.config = cfg existing.config = cfg
# Re-stamp on every connect so vault renames in Obsidian propagate; # Re-stamp on every connect so vault renames in Obsidian propagate;
# the web UI hides the Name input for Obsidian connectors. # the web UI hides the Name input for Obsidian connectors.
@ -237,7 +246,6 @@ async def obsidian_connect(
"source": "plugin", "source": "plugin",
"plugin_version": payload.plugin_version, "plugin_version": payload.plugin_version,
"devices": devices, "devices": devices,
"device_count": len(devices),
"files_synced": 0, "files_synced": 0,
"last_connect_at": now_iso, "last_connect_at": now_iso,
}, },
@ -264,7 +272,7 @@ async def obsidian_sync(
) -> dict[str, object]: ) -> dict[str, object]:
"""Batch-upsert notes; returns per-note ack so the plugin can dequeue/retry.""" """Batch-upsert notes; returns per-note ack so the plugin can dequeue/retry."""
connector = await _resolve_vault_connector( connector = await _resolve_vault_connector(
session, user=user, vault_id=payload.vault_id session, user=user, vault_id=payload.vault_id, for_update=True
) )
results: list[dict[str, object]] = [] results: list[dict[str, object]] = []
@ -315,7 +323,7 @@ async def obsidian_rename(
) -> dict[str, object]: ) -> dict[str, object]:
"""Apply a batch of vault rename events.""" """Apply a batch of vault rename events."""
connector = await _resolve_vault_connector( connector = await _resolve_vault_connector(
session, user=user, vault_id=payload.vault_id session, user=user, vault_id=payload.vault_id, for_update=True
) )
results: list[dict[str, object]] = [] results: list[dict[str, object]] = []
@ -382,7 +390,7 @@ async def obsidian_delete_notes(
) -> dict[str, object]: ) -> dict[str, object]:
"""Soft-delete a batch of notes by vault-relative path.""" """Soft-delete a batch of notes by vault-relative path."""
connector = await _resolve_vault_connector( connector = await _resolve_vault_connector(
session, user=user, vault_id=payload.vault_id session, user=user, vault_id=payload.vault_id, for_update=True
) )
deleted = 0 deleted = 0

View file

@ -108,12 +108,11 @@ export class SurfSenseApiClient {
}): Promise<ConnectResponse> { }): Promise<ConnectResponse> {
return await this.request<ConnectResponse>( return await this.request<ConnectResponse>(
"POST", "POST",
`/api/v1/obsidian/connect?search_space_id=${encodeURIComponent( "/api/v1/obsidian/connect",
String(input.searchSpaceId)
)}`,
{ {
vault_id: input.vaultId, vault_id: input.vaultId,
vault_name: input.vaultName, vault_name: input.vaultName,
search_space_id: input.searchSpaceId,
plugin_version: this.opts.pluginVersion, plugin_version: this.opts.pluginVersion,
device_id: input.deviceId, device_id: input.deviceId,
} }

View file

@ -100,14 +100,11 @@ const LegacyBanner: FC = () => {
const PluginStats: FC<{ config: Record<string, unknown> }> = ({ config }) => { const PluginStats: FC<{ config: Record<string, unknown> }> = ({ config }) => {
const stats: { label: string; value: string }[] = useMemo(() => { const stats: { label: string; value: string }[] = useMemo(() => {
const filesSynced = config.files_synced; const filesSynced = config.files_synced;
// Prefer the stamped count; fall back to len(devices) for rows the // Derive from config.devices — a stored counter could drift under concurrent heartbeats.
// backend hasn't re-stamped yet.
const deviceCount = const deviceCount =
typeof config.device_count === "number" config.devices && typeof config.devices === "object"
? config.device_count ? Object.keys(config.devices as Record<string, unknown>).length
: config.devices && typeof config.devices === "object" : null;
? Object.keys(config.devices as Record<string, unknown>).length
: null;
return [ return [
{ label: "Vault", value: (config.vault_name as string) || "—" }, { label: "Vault", value: (config.vault_name as string) || "—" },
{ {