diff --git a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py
index d61a56533..06a27bc96 100644
--- a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py
+++ b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py
@@ -61,6 +61,9 @@ class _ThreadLockManager:
         self._cancel_events: dict[str, asyncio.Event] = {}
         self._cancel_requested_at_ms: dict[str, int] = {}
         self._cancel_attempt_count: dict[str, int] = {}
+        # Monotonic per-thread epoch used to prevent stale middleware
+        # teardown from releasing a newer turn's lock.
+        self._turn_epoch: dict[str, int] = {}
 
     def lock_for(self, thread_id: str) -> asyncio.Lock:
         lock = self._locks.get(thread_id)
@@ -107,6 +110,14 @@
         self._cancel_requested_at_ms.pop(thread_id, None)
         self._cancel_attempt_count.pop(thread_id, None)
 
+    def bump_turn_epoch(self, thread_id: str) -> int:
+        epoch = self._turn_epoch.get(thread_id, 0) + 1
+        self._turn_epoch[thread_id] = epoch
+        return epoch
+
+    def current_turn_epoch(self, thread_id: str) -> int:
+        return self._turn_epoch.get(thread_id, 0)
+
     def end_turn(self, thread_id: str) -> None:
         """Best-effort terminal cleanup for a thread turn.
 
@@ -114,6 +125,10 @@
         finally-blocks where middleware teardown might be skipped due to
         abort or disconnect edge-cases.
         """
+        # Invalidate any in-flight middleware holder first. This guarantees a
+        # stale ``aafter_agent`` from an older attempt cannot unlock a newer
+        # retry that already acquired the lock for the same thread.
+        self.bump_turn_epoch(thread_id)
         lock = self._locks.get(thread_id)
         if lock is not None and lock.locked():
             lock.release()
@@ -178,10 +193,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
         super().__init__()
         self._require_thread_id = require_thread_id
         self.tools = []
-        # Per-call locks owned by this middleware. We track them as
-        # an instance attribute so ``aafter_agent`` knows which lock
-        # to release.
-        self._held_locks: dict[str, asyncio.Lock] = {}
+        # Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
+        # only releases when its epoch still matches the manager's current
+        # epoch for the thread, preventing stale unlock races.
+        self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
 
     @staticmethod
     def _thread_id(runtime: Runtime[ContextT]) -> str | None:
@@ -232,7 +247,8 @@
         if lock.locked():
             raise BusyError(request_id=thread_id)
         await lock.acquire()
-        self._held_locks[thread_id] = lock
+        epoch = manager.bump_turn_epoch(thread_id)
+        self._held_locks[thread_id] = (lock, epoch)
         # Reset the cancel event so this turn starts fresh
         reset_cancel(thread_id)
         return None
@@ -246,8 +262,15 @@
         thread_id = self._thread_id(runtime)
         if thread_id is None:
             return None
-        lock = self._held_locks.pop(thread_id, None)
-        if lock is not None and lock.locked():
+        held = self._held_locks.pop(thread_id, None)
+        if held is None:
+            return None
+        lock, held_epoch = held
+        if held_epoch != manager.current_turn_epoch(thread_id):
+            # Stale teardown from an older attempt (e.g. runtime-recovery path
+            # already advanced epoch). Do not touch current lock/cancel state.
+            return None
+        if lock.locked():
             lock.release()
         # Always clear cancel event between turns so a stale signal
         # doesn't leak into the next request.
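Note for reviewers: the epoch guard in this file reduces to a small, self-contained pattern. The sketch below is illustrative only — ThreadLockManager, acquire, and release are simplified stand-ins for _ThreadLockManager and the BusyMutexMiddleware hooks, not the real classes — but it runs as written and shows why a teardown whose recorded epoch no longer matches the manager's current epoch must leave the lock alone.

import asyncio


class ThreadLockManager:
    """Simplified stand-in: one asyncio.Lock and one monotonic epoch per thread id."""

    def __init__(self) -> None:
        self._locks: dict[str, asyncio.Lock] = {}
        self._epochs: dict[str, int] = {}

    def lock_for(self, thread_id: str) -> asyncio.Lock:
        return self._locks.setdefault(thread_id, asyncio.Lock())

    def bump_epoch(self, thread_id: str) -> int:
        self._epochs[thread_id] = self._epochs.get(thread_id, 0) + 1
        return self._epochs[thread_id]

    def current_epoch(self, thread_id: str) -> int:
        return self._epochs.get(thread_id, 0)

    def end_turn(self, thread_id: str) -> None:
        # Invalidate any in-flight holder before force-releasing the lock.
        self.bump_epoch(thread_id)
        lock = self._locks.get(thread_id)
        if lock is not None and lock.locked():
            lock.release()


async def acquire(manager: ThreadLockManager, thread_id: str) -> int:
    # Mirrors abefore_agent: take the lock, then record which epoch owns it.
    lock = manager.lock_for(thread_id)
    await lock.acquire()
    return manager.bump_epoch(thread_id)


def release(manager: ThreadLockManager, thread_id: str, held_epoch: int) -> None:
    # Mirrors aafter_agent: a stale teardown (epoch mismatch) is a no-op.
    if held_epoch != manager.current_epoch(thread_id):
        return
    lock = manager.lock_for(thread_id)
    if lock.locked():
        lock.release()


async def main() -> None:
    manager = ThreadLockManager()
    epoch_a = await acquire(manager, "t1")  # attempt A owns the lock
    manager.end_turn("t1")                  # forced recovery releases A, bumps epoch
    await acquire(manager, "t1")            # attempt B re-acquires for the retry
    release(manager, "t1", epoch_a)         # late teardown from A: must be a no-op
    assert manager.lock_for("t1").locked()  # B still holds the lock


asyncio.run(main())

The invariant is the same one the middleware relies on: end_turn bumps the epoch before force-releasing, so any holder recorded before that bump is treated as stale.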
diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py
index 05a54b257..f6a223866 100644
--- a/surfsense_backend/app/services/auto_model_pin_service.py
+++ b/surfsense_backend/app/services/auto_model_pin_service.py
@@ -179,6 +179,7 @@ async def resolve_or_get_pinned_llm_config_id(
     user_id: str | UUID | None,
     selected_llm_config_id: int,
     force_repin_free: bool = False,
+    exclude_config_ids: set[int] | None = None,
 ) -> AutoPinResolution:
     """Resolve Auto (Fastest) to one concrete config id and persist the pin.
 
@@ -214,9 +215,14 @@
             from_existing_pin=False,
         )
 
-    candidates = _global_candidates()
+    excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
+    candidates = [
+        c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids
+    ]
     if not candidates:
-        raise ValueError("No usable global LLM configs are available for Auto mode")
+        raise ValueError(
+            "No usable global LLM configs are available for Auto mode"
+        )
     candidate_by_id = {int(c["id"]): c for c in candidates}
 
     # Reuse an existing valid pin without re-checking current quota (no silent
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
index 8f596927d..dbfd5e2ea 100644
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -2784,6 +2784,10 @@
 
                     runtime_rate_limit_recovered = True
                     previous_config_id = llm_config_id
+                    # The failed attempt may still hold the per-thread busy mutex
+                    # (middleware teardown can lag behind raised provider errors).
+                    # Force release before we retry within the same request.
+                    end_turn(str(chat_id))
                     mark_runtime_cooldown(
                         previous_config_id,
                         reason="provider_rate_limited",
@@ -2796,6 +2800,7 @@
                             search_space_id=search_space_id,
                             user_id=user_id,
                             selected_llm_config_id=0,
+                            exclude_config_ids={previous_config_id},
                         )
                     ).resolved_llm_config_id
 
@@ -3442,6 +3447,9 @@
 
                     runtime_rate_limit_recovered = True
                     previous_config_id = llm_config_id
+                    # Ensure the same-request recovery retry does not trip the
+                    # BusyMutex lock retained by the failed attempt.
+                    end_turn(str(chat_id))
                     mark_runtime_cooldown(
                         previous_config_id,
                         reason="provider_rate_limited",
@@ -3453,6 +3461,7 @@
                             search_space_id=search_space_id,
                             user_id=user_id,
                             selected_llm_config_id=0,
+                            exclude_config_ids={previous_config_id},
                         )
                     ).resolved_llm_config_id
 
diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py
index c923dc499..f0161f605 100644
--- a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py
+++ b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py
@@ -118,3 +118,37 @@ async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
     assert not manager.lock_for(thread_id).locked()
     assert not get_cancel_event(thread_id).is_set()
     assert is_cancel_requested(thread_id) is False
+
+
+@pytest.mark.asyncio
+async def test_busy_mutex_stale_aafter_does_not_release_new_attempt_lock() -> None:
+    """A stale aafter call from attempt A must not unlock attempt B.
+
+    Repro flow:
+    1) attempt A acquires thread lock
+    2) forced end_turn clears A so retry can proceed
+    3) attempt B acquires same thread lock
+    4) stale attempt-A aafter runs late
+
+    Expected: B lock remains held.
+    """
+    thread_id = "stale-aafter-lock"
+    runtime = _Runtime(thread_id)
+    attempt_a = BusyMutexMiddleware()
+    attempt_b = BusyMutexMiddleware()
+
+    await attempt_a.abefore_agent({}, runtime)
+    lock = manager.lock_for(thread_id)
+    assert lock.locked()
+
+    end_turn(thread_id)
+    assert not lock.locked()
+
+    await attempt_b.abefore_agent({}, runtime)
+    assert lock.locked()
+
+    # Stale cleanup from attempt A must not release attempt B's lock.
+    await attempt_a.aafter_agent({}, runtime)
+    assert lock.locked()
+
+    await attempt_b.aafter_agent({}, runtime)
diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
index 8261fdfe0..8696a8829 100644
--- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
+++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
@@ -813,3 +813,56 @@
     )
     assert result.resolved_llm_config_id == -1
     assert result.from_existing_pin is True
+
+
+@pytest.mark.asyncio
+async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypatch):
+    """Runtime retry should never repin the just-failed config."""
+    from app.config import config
+
+    session = _FakeSession(_thread(pinned_llm_config_id=-1))
+    monkeypatch.setattr(
+        config,
+        "GLOBAL_LLM_CONFIGS",
+        [
+            {
+                "id": -1,
+                "provider": "OPENROUTER",
+                "model_name": "google/gemma-4-26b-a4b-it:free",
+                "api_key": "k",
+                "billing_tier": "free",
+                "auto_pin_tier": "C",
+                "quality_score": 90,
+                "health_gated": False,
+            },
+            {
+                "id": -2,
+                "provider": "OPENROUTER",
+                "model_name": "google/gemini-2.5-flash:free",
+                "api_key": "k",
+                "billing_tier": "free",
+                "auto_pin_tier": "C",
+                "quality_score": 80,
+                "health_gated": False,
+            },
+        ],
+    )
+
+    async def _blocked(*_args, **_kwargs):
+        return _FakeQuotaResult(allowed=False)
+
+    monkeypatch.setattr(
+        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
+        _blocked,
+    )
+
+    result = await resolve_or_get_pinned_llm_config_id(
+        session,
+        thread_id=1,
+        search_space_id=10,
+        user_id="00000000-0000-0000-0000-000000000001",
+        selected_llm_config_id=0,
+        exclude_config_ids={-1},
+    )
+    assert result.resolved_llm_config_id == -2
+    assert result.from_existing_pin is False