feat(busy_mutex): enhance thread lock management to prevent stale middleware interference

This commit is contained in:
Anish Sarkar 2026-05-02 01:35:30 +05:30
parent f65b3be1ce
commit 25ccc959cf
5 changed files with 134 additions and 9 deletions

View file

@ -61,6 +61,9 @@ class _ThreadLockManager:
self._cancel_events: dict[str, asyncio.Event] = {}
self._cancel_requested_at_ms: dict[str, int] = {}
self._cancel_attempt_count: dict[str, int] = {}
# Monotonic per-thread epoch used to prevent stale middleware
# teardown from releasing a newer turn's lock.
self._turn_epoch: dict[str, int] = {}
def lock_for(self, thread_id: str) -> asyncio.Lock:
lock = self._locks.get(thread_id)
@ -107,6 +110,14 @@ class _ThreadLockManager:
self._cancel_requested_at_ms.pop(thread_id, None)
self._cancel_attempt_count.pop(thread_id, None)
def bump_turn_epoch(self, thread_id: str) -> int:
epoch = self._turn_epoch.get(thread_id, 0) + 1
self._turn_epoch[thread_id] = epoch
return epoch
def current_turn_epoch(self, thread_id: str) -> int:
return self._turn_epoch.get(thread_id, 0)
def end_turn(self, thread_id: str) -> None:
"""Best-effort terminal cleanup for a thread turn.
@ -114,6 +125,10 @@ class _ThreadLockManager:
finally-blocks where middleware teardown might be skipped due to abort
or disconnect edge-cases.
"""
# Invalidate any in-flight middleware holder first. This guarantees a
# stale ``aafter_agent`` from an older attempt cannot unlock a newer
# retry that already acquired the lock for the same thread.
self.bump_turn_epoch(thread_id)
lock = self._locks.get(thread_id)
if lock is not None and lock.locked():
lock.release()
@ -178,10 +193,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
super().__init__()
self._require_thread_id = require_thread_id
self.tools = []
# Per-call locks owned by this middleware. We track them as
# an instance attribute so ``aafter_agent`` knows which lock
# to release.
self._held_locks: dict[str, asyncio.Lock] = {}
# Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
# only releases when its epoch still matches the manager's current
# epoch for the thread, preventing stale unlock races.
self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
@staticmethod
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
@ -232,7 +247,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
if lock.locked():
raise BusyError(request_id=thread_id)
await lock.acquire()
self._held_locks[thread_id] = lock
epoch = manager.bump_turn_epoch(thread_id)
self._held_locks[thread_id] = (lock, epoch)
# Reset the cancel event so this turn starts fresh
reset_cancel(thread_id)
return None
@ -246,8 +262,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
thread_id = self._thread_id(runtime)
if thread_id is None:
return None
lock = self._held_locks.pop(thread_id, None)
if lock is not None and lock.locked():
held = self._held_locks.pop(thread_id, None)
if held is None:
return None
lock, held_epoch = held
if held_epoch != manager.current_turn_epoch(thread_id):
# Stale teardown from an older attempt (e.g. runtime-recovery path
# already advanced epoch). Do not touch current lock/cancel state.
return None
if lock.locked():
lock.release()
# Always clear cancel event between turns so a stale signal
# doesn't leak into the next request.

View file

@ -179,6 +179,7 @@ async def resolve_or_get_pinned_llm_config_id(
user_id: str | UUID | None,
selected_llm_config_id: int,
force_repin_free: bool = False,
exclude_config_ids: set[int] | None = None,
) -> AutoPinResolution:
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
@ -214,9 +215,14 @@ async def resolve_or_get_pinned_llm_config_id(
from_existing_pin=False,
)
candidates = _global_candidates()
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
candidates = [
c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids
]
if not candidates:
raise ValueError("No usable global LLM configs are available for Auto mode")
raise ValueError(
"No usable global LLM configs are available for Auto mode"
)
candidate_by_id = {int(c["id"]): c for c in candidates}
# Reuse an existing valid pin without re-checking current quota (no silent

View file

@ -2784,6 +2784,10 @@ async def stream_new_chat(
runtime_rate_limit_recovered = True
previous_config_id = llm_config_id
# The failed attempt may still hold the per-thread busy mutex
# (middleware teardown can lag behind raised provider errors).
# Force release before we retry within the same request.
end_turn(str(chat_id))
mark_runtime_cooldown(
previous_config_id,
reason="provider_rate_limited",
@ -2796,6 +2800,7 @@ async def stream_new_chat(
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
exclude_config_ids={previous_config_id},
)
).resolved_llm_config_id
@ -3442,6 +3447,9 @@ async def stream_resume_chat(
runtime_rate_limit_recovered = True
previous_config_id = llm_config_id
# Ensure the same-request recovery retry does not trip the
# BusyMutex lock retained by the failed attempt.
end_turn(str(chat_id))
mark_runtime_cooldown(
previous_config_id,
reason="provider_rate_limited",
@ -3453,6 +3461,7 @@ async def stream_resume_chat(
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=0,
exclude_config_ids={previous_config_id},
)
).resolved_llm_config_id

View file

@ -118,3 +118,37 @@ async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
assert not manager.lock_for(thread_id).locked()
assert not get_cancel_event(thread_id).is_set()
assert is_cancel_requested(thread_id) is False
@pytest.mark.asyncio
async def test_busy_mutex_stale_aafter_does_not_release_new_attempt_lock() -> None:
"""A stale aafter call from attempt A must not unlock attempt B.
Repro flow:
1) attempt A acquires thread lock
2) forced end_turn clears A so retry can proceed
3) attempt B acquires same thread lock
4) stale attempt-A aafter runs late
Expected: B lock remains held.
"""
thread_id = "stale-aafter-lock"
runtime = _Runtime(thread_id)
attempt_a = BusyMutexMiddleware()
attempt_b = BusyMutexMiddleware()
await attempt_a.abefore_agent({}, runtime)
lock = manager.lock_for(thread_id)
assert lock.locked()
end_turn(thread_id)
assert not lock.locked()
await attempt_b.abefore_agent({}, runtime)
assert lock.locked()
# Stale cleanup from attempt A must not release attempt B's lock.
await attempt_a.aafter_agent({}, runtime)
assert lock.locked()
await attempt_b.aafter_agent({}, runtime)

View file

@ -813,3 +813,56 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
)
assert result.resolved_llm_config_id == -1
assert result.from_existing_pin is True
@pytest.mark.asyncio
async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypatch):
"""Runtime retry should never repin the just-failed config."""
from app.config import config
session = _FakeSession(_thread(pinned_llm_config_id=-1))
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"model_name": "google/gemma-4-26b-a4b-it:free",
"api_key": "k",
"billing_tier": "free",
"auto_pin_tier": "C",
"quality_score": 90,
"health_gated": False,
},
{
"id": -2,
"provider": "OPENROUTER",
"model_name": "google/gemini-2.5-flash:free",
"api_key": "k",
"billing_tier": "free",
"auto_pin_tier": "C",
"quality_score": 80,
"health_gated": False,
},
],
)
async def _blocked(*_args, **_kwargs):
return _FakeQuotaResult(allowed=False)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_blocked,
)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
exclude_config_ids={-1},
)
assert result.resolved_llm_config_id == -2
assert result.from_existing_pin is False