mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 13:22:41 +02:00
refactor(chat): implement turn cancellation and status management in new chat routes for improved user experience and error handling
This commit is contained in:
parent
4056bd1d69
commit
af66fbf106
12 changed files with 671 additions and 81 deletions
|
|
@ -33,6 +33,7 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -58,6 +59,8 @@ class _ThreadLockManager:
|
|||
weakref.WeakValueDictionary()
|
||||
)
|
||||
self._cancel_events: dict[str, asyncio.Event] = {}
|
||||
self._cancel_requested_at_ms: dict[str, int] = {}
|
||||
self._cancel_attempt_count: dict[str, int] = {}
|
||||
|
||||
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
||||
lock = self._locks.get(thread_id)
|
||||
|
|
@ -76,14 +79,45 @@ class _ThreadLockManager:
|
|||
def request_cancel(self, thread_id: str) -> bool:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
if event is None:
|
||||
return False
|
||||
event = asyncio.Event()
|
||||
self._cancel_events[thread_id] = event
|
||||
event.set()
|
||||
now_ms = int(time.time() * 1000)
|
||||
self._cancel_requested_at_ms[thread_id] = now_ms
|
||||
self._cancel_attempt_count[thread_id] = (
|
||||
self._cancel_attempt_count.get(thread_id, 0) + 1
|
||||
)
|
||||
return True
|
||||
|
||||
def is_cancel_requested(self, thread_id: str) -> bool:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
return bool(event and event.is_set())
|
||||
|
||||
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
|
||||
if not self.is_cancel_requested(thread_id):
|
||||
return None
|
||||
attempts = self._cancel_attempt_count.get(thread_id, 1)
|
||||
requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
|
||||
return attempts, requested_at_ms
|
||||
|
||||
def reset(self, thread_id: str) -> None:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
if event is not None:
|
||||
event.clear()
|
||||
self._cancel_requested_at_ms.pop(thread_id, None)
|
||||
self._cancel_attempt_count.pop(thread_id, None)
|
||||
|
||||
def end_turn(self, thread_id: str) -> None:
|
||||
"""Best-effort terminal cleanup for a thread turn.
|
||||
|
||||
This is intentionally idempotent and safe to call from outer stream
|
||||
finally-blocks where middleware teardown might be skipped due to abort
|
||||
or disconnect edge-cases.
|
||||
"""
|
||||
lock = self._locks.get(thread_id)
|
||||
if lock is not None and lock.locked():
|
||||
lock.release()
|
||||
self.reset(thread_id)
|
||||
|
||||
|
||||
# Module-level singleton — process-local but reused across all agent
|
||||
|
|
@ -98,15 +132,30 @@ def get_cancel_event(thread_id: str) -> asyncio.Event:
|
|||
|
||||
|
||||
def request_cancel(thread_id: str) -> bool:
|
||||
"""Trip the cancel event for ``thread_id``. Returns True if found."""
|
||||
"""Trip the cancel event for ``thread_id``. Always returns True."""
|
||||
return manager.request_cancel(thread_id)
|
||||
|
||||
|
||||
def is_cancel_requested(thread_id: str) -> bool:
|
||||
"""Return whether ``thread_id`` currently has a pending cancel signal."""
|
||||
return manager.is_cancel_requested(thread_id)
|
||||
|
||||
|
||||
def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
|
||||
"""Return ``(attempt_count, requested_at_ms)`` for pending cancel state."""
|
||||
return manager.cancel_state(thread_id)
|
||||
|
||||
|
||||
def reset_cancel(thread_id: str) -> None:
|
||||
"""Reset the cancel event for ``thread_id`` (called between turns)."""
|
||||
manager.reset(thread_id)
|
||||
|
||||
|
||||
def end_turn(thread_id: str) -> None:
|
||||
"""Force end-of-turn cleanup for lock + cancel state."""
|
||||
manager.end_turn(thread_id)
|
||||
|
||||
|
||||
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Block concurrent prompts on the same thread.
|
||||
|
||||
|
|
@ -229,7 +278,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
|
||||
__all__ = [
|
||||
"BusyMutexMiddleware",
|
||||
"end_turn",
|
||||
"get_cancel_event",
|
||||
"get_cancel_state",
|
||||
"is_cancel_requested",
|
||||
"manager",
|
||||
"request_cancel",
|
||||
"reset_cancel",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue