mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-05 13:52:40 +02:00
feat(busy_mutex): enhance thread lock management to prevent stale middleware interference
This commit is contained in:
parent
f65b3be1ce
commit
25ccc959cf
5 changed files with 134 additions and 9 deletions
|
|
@ -61,6 +61,9 @@ class _ThreadLockManager:
|
|||
self._cancel_events: dict[str, asyncio.Event] = {}
|
||||
self._cancel_requested_at_ms: dict[str, int] = {}
|
||||
self._cancel_attempt_count: dict[str, int] = {}
|
||||
# Monotonic per-thread epoch used to prevent stale middleware
|
||||
# teardown from releasing a newer turn's lock.
|
||||
self._turn_epoch: dict[str, int] = {}
|
||||
|
||||
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
||||
lock = self._locks.get(thread_id)
|
||||
|
|
@ -107,6 +110,14 @@ class _ThreadLockManager:
|
|||
self._cancel_requested_at_ms.pop(thread_id, None)
|
||||
self._cancel_attempt_count.pop(thread_id, None)
|
||||
|
||||
def bump_turn_epoch(self, thread_id: str) -> int:
|
||||
epoch = self._turn_epoch.get(thread_id, 0) + 1
|
||||
self._turn_epoch[thread_id] = epoch
|
||||
return epoch
|
||||
|
||||
def current_turn_epoch(self, thread_id: str) -> int:
|
||||
return self._turn_epoch.get(thread_id, 0)
|
||||
|
||||
def end_turn(self, thread_id: str) -> None:
|
||||
"""Best-effort terminal cleanup for a thread turn.
|
||||
|
||||
|
|
@ -114,6 +125,10 @@ class _ThreadLockManager:
|
|||
finally-blocks where middleware teardown might be skipped due to abort
|
||||
or disconnect edge-cases.
|
||||
"""
|
||||
# Invalidate any in-flight middleware holder first. This guarantees a
|
||||
# stale ``aafter_agent`` from an older attempt cannot unlock a newer
|
||||
# retry that already acquired the lock for the same thread.
|
||||
self.bump_turn_epoch(thread_id)
|
||||
lock = self._locks.get(thread_id)
|
||||
if lock is not None and lock.locked():
|
||||
lock.release()
|
||||
|
|
@ -178,10 +193,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
super().__init__()
|
||||
self._require_thread_id = require_thread_id
|
||||
self.tools = []
|
||||
# Per-call locks owned by this middleware. We track them as
|
||||
# an instance attribute so ``aafter_agent`` knows which lock
|
||||
# to release.
|
||||
self._held_locks: dict[str, asyncio.Lock] = {}
|
||||
# Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
|
||||
# only releases when its epoch still matches the manager's current
|
||||
# epoch for the thread, preventing stale unlock races.
|
||||
self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
|
||||
|
|
@ -232,7 +247,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
if lock.locked():
|
||||
raise BusyError(request_id=thread_id)
|
||||
await lock.acquire()
|
||||
self._held_locks[thread_id] = lock
|
||||
epoch = manager.bump_turn_epoch(thread_id)
|
||||
self._held_locks[thread_id] = (lock, epoch)
|
||||
# Reset the cancel event so this turn starts fresh
|
||||
reset_cancel(thread_id)
|
||||
return None
|
||||
|
|
@ -246,8 +262,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
thread_id = self._thread_id(runtime)
|
||||
if thread_id is None:
|
||||
return None
|
||||
lock = self._held_locks.pop(thread_id, None)
|
||||
if lock is not None and lock.locked():
|
||||
held = self._held_locks.pop(thread_id, None)
|
||||
if held is None:
|
||||
return None
|
||||
lock, held_epoch = held
|
||||
if held_epoch != manager.current_turn_epoch(thread_id):
|
||||
# Stale teardown from an older attempt (e.g. runtime-recovery path
|
||||
# already advanced epoch). Do not touch current lock/cancel state.
|
||||
return None
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
# Always clear cancel event between turns so a stale signal
|
||||
# doesn't leak into the next request.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue