mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
feat(busy_mutex): enhance thread lock management to prevent stale middleware interference
This commit is contained in:
parent
f65b3be1ce
commit
25ccc959cf
5 changed files with 134 additions and 9 deletions
|
|
@ -61,6 +61,9 @@ class _ThreadLockManager:
|
|||
self._cancel_events: dict[str, asyncio.Event] = {}
|
||||
self._cancel_requested_at_ms: dict[str, int] = {}
|
||||
self._cancel_attempt_count: dict[str, int] = {}
|
||||
# Monotonic per-thread epoch used to prevent stale middleware
|
||||
# teardown from releasing a newer turn's lock.
|
||||
self._turn_epoch: dict[str, int] = {}
|
||||
|
||||
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
||||
lock = self._locks.get(thread_id)
|
||||
|
|
@ -107,6 +110,14 @@ class _ThreadLockManager:
|
|||
self._cancel_requested_at_ms.pop(thread_id, None)
|
||||
self._cancel_attempt_count.pop(thread_id, None)
|
||||
|
||||
def bump_turn_epoch(self, thread_id: str) -> int:
|
||||
epoch = self._turn_epoch.get(thread_id, 0) + 1
|
||||
self._turn_epoch[thread_id] = epoch
|
||||
return epoch
|
||||
|
||||
def current_turn_epoch(self, thread_id: str) -> int:
|
||||
return self._turn_epoch.get(thread_id, 0)
|
||||
|
||||
def end_turn(self, thread_id: str) -> None:
|
||||
"""Best-effort terminal cleanup for a thread turn.
|
||||
|
||||
|
|
@ -114,6 +125,10 @@ class _ThreadLockManager:
|
|||
finally-blocks where middleware teardown might be skipped due to abort
|
||||
or disconnect edge-cases.
|
||||
"""
|
||||
# Invalidate any in-flight middleware holder first. This guarantees a
|
||||
# stale ``aafter_agent`` from an older attempt cannot unlock a newer
|
||||
# retry that already acquired the lock for the same thread.
|
||||
self.bump_turn_epoch(thread_id)
|
||||
lock = self._locks.get(thread_id)
|
||||
if lock is not None and lock.locked():
|
||||
lock.release()
|
||||
|
|
@ -178,10 +193,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
super().__init__()
|
||||
self._require_thread_id = require_thread_id
|
||||
self.tools = []
|
||||
# Per-call locks owned by this middleware. We track them as
|
||||
# an instance attribute so ``aafter_agent`` knows which lock
|
||||
# to release.
|
||||
self._held_locks: dict[str, asyncio.Lock] = {}
|
||||
# Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
|
||||
# only releases when its epoch still matches the manager's current
|
||||
# epoch for the thread, preventing stale unlock races.
|
||||
self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
|
||||
|
|
@ -232,7 +247,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
if lock.locked():
|
||||
raise BusyError(request_id=thread_id)
|
||||
await lock.acquire()
|
||||
self._held_locks[thread_id] = lock
|
||||
epoch = manager.bump_turn_epoch(thread_id)
|
||||
self._held_locks[thread_id] = (lock, epoch)
|
||||
# Reset the cancel event so this turn starts fresh
|
||||
reset_cancel(thread_id)
|
||||
return None
|
||||
|
|
@ -246,8 +262,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
thread_id = self._thread_id(runtime)
|
||||
if thread_id is None:
|
||||
return None
|
||||
lock = self._held_locks.pop(thread_id, None)
|
||||
if lock is not None and lock.locked():
|
||||
held = self._held_locks.pop(thread_id, None)
|
||||
if held is None:
|
||||
return None
|
||||
lock, held_epoch = held
|
||||
if held_epoch != manager.current_turn_epoch(thread_id):
|
||||
# Stale teardown from an older attempt (e.g. runtime-recovery path
|
||||
# already advanced epoch). Do not touch current lock/cancel state.
|
||||
return None
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
# Always clear cancel event between turns so a stale signal
|
||||
# doesn't leak into the next request.
|
||||
|
|
|
|||
|
|
@ -179,6 +179,7 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
user_id: str | UUID | None,
|
||||
selected_llm_config_id: int,
|
||||
force_repin_free: bool = False,
|
||||
exclude_config_ids: set[int] | None = None,
|
||||
) -> AutoPinResolution:
|
||||
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
|
||||
|
||||
|
|
@ -214,9 +215,14 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
from_existing_pin=False,
|
||||
)
|
||||
|
||||
candidates = _global_candidates()
|
||||
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
|
||||
candidates = [
|
||||
c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids
|
||||
]
|
||||
if not candidates:
|
||||
raise ValueError("No usable global LLM configs are available for Auto mode")
|
||||
raise ValueError(
|
||||
"No usable global LLM configs are available for Auto mode"
|
||||
)
|
||||
candidate_by_id = {int(c["id"]): c for c in candidates}
|
||||
|
||||
# Reuse an existing valid pin without re-checking current quota (no silent
|
||||
|
|
|
|||
|
|
@ -2784,6 +2784,10 @@ async def stream_new_chat(
|
|||
|
||||
runtime_rate_limit_recovered = True
|
||||
previous_config_id = llm_config_id
|
||||
# The failed attempt may still hold the per-thread busy mutex
|
||||
# (middleware teardown can lag behind raised provider errors).
|
||||
# Force release before we retry within the same request.
|
||||
end_turn(str(chat_id))
|
||||
mark_runtime_cooldown(
|
||||
previous_config_id,
|
||||
reason="provider_rate_limited",
|
||||
|
|
@ -2796,6 +2800,7 @@ async def stream_new_chat(
|
|||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
exclude_config_ids={previous_config_id},
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
|
||||
|
|
@ -3442,6 +3447,9 @@ async def stream_resume_chat(
|
|||
|
||||
runtime_rate_limit_recovered = True
|
||||
previous_config_id = llm_config_id
|
||||
# Ensure the same-request recovery retry does not trip the
|
||||
# BusyMutex lock retained by the failed attempt.
|
||||
end_turn(str(chat_id))
|
||||
mark_runtime_cooldown(
|
||||
previous_config_id,
|
||||
reason="provider_rate_limited",
|
||||
|
|
@ -3453,6 +3461,7 @@ async def stream_resume_chat(
|
|||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
exclude_config_ids={previous_config_id},
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
|
||||
|
|
|
|||
|
|
@ -118,3 +118,37 @@ async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
|
|||
assert not manager.lock_for(thread_id).locked()
|
||||
assert not get_cancel_event(thread_id).is_set()
|
||||
assert is_cancel_requested(thread_id) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_busy_mutex_stale_aafter_does_not_release_new_attempt_lock() -> None:
|
||||
"""A stale aafter call from attempt A must not unlock attempt B.
|
||||
|
||||
Repro flow:
|
||||
1) attempt A acquires thread lock
|
||||
2) forced end_turn clears A so retry can proceed
|
||||
3) attempt B acquires same thread lock
|
||||
4) stale attempt-A aafter runs late
|
||||
|
||||
Expected: B lock remains held.
|
||||
"""
|
||||
thread_id = "stale-aafter-lock"
|
||||
runtime = _Runtime(thread_id)
|
||||
attempt_a = BusyMutexMiddleware()
|
||||
attempt_b = BusyMutexMiddleware()
|
||||
|
||||
await attempt_a.abefore_agent({}, runtime)
|
||||
lock = manager.lock_for(thread_id)
|
||||
assert lock.locked()
|
||||
|
||||
end_turn(thread_id)
|
||||
assert not lock.locked()
|
||||
|
||||
await attempt_b.abefore_agent({}, runtime)
|
||||
assert lock.locked()
|
||||
|
||||
# Stale cleanup from attempt A must not release attempt B's lock.
|
||||
await attempt_a.aafter_agent({}, runtime)
|
||||
assert lock.locked()
|
||||
|
||||
await attempt_b.aafter_agent({}, runtime)
|
||||
|
|
|
|||
|
|
@ -813,3 +813,56 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
|
|||
)
|
||||
assert result.resolved_llm_config_id == -1
|
||||
assert result.from_existing_pin is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypatch):
|
||||
"""Runtime retry should never repin the just-failed config."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemma-4-26b-a4b-it:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
"auto_pin_tier": "C",
|
||||
"quality_score": 90,
|
||||
"health_gated": False,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemini-2.5-flash:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
"auto_pin_tier": "C",
|
||||
"quality_score": 80,
|
||||
"health_gated": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
async def _blocked(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_blocked,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
exclude_config_ids={-1},
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert result.from_existing_pin is False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue