diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index 43b0af7d2..17f4e093e 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -37,6 +37,14 @@ from app.config import ( ) from app.db import User, create_db_and_tables, get_async_session from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError +from app.gateway.byo_long_poll import ( + start_byo_long_poll_supervisors, + stop_byo_long_poll_supervisors, +) +from app.gateway.inbox_worker import ( + start_gateway_inbox_worker, + stop_gateway_inbox_worker, +) from app.observability import metrics as ot_metrics from app.observability.bootstrap import init_otel, shutdown_otel from app.rate_limiter import get_real_client_ip, limiter @@ -597,12 +605,17 @@ async def lifespan(app: FastAPI): ) log_system_snapshot("startup_complete") + await start_gateway_inbox_worker() + await start_byo_long_poll_supervisors() - yield - - _stop_openrouter_background_refresh() - await close_checkpointer() - shutdown_otel() + try: + yield + finally: + await stop_byo_long_poll_supervisors() + await stop_gateway_inbox_worker() + _stop_openrouter_background_refresh() + await close_checkpointer() + shutdown_otel() def registration_allowed(): diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index a7739d6c4..89bf4c925 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -546,6 +546,9 @@ class Config: TELEGRAM_SHARED_BOT_USERNAME = os.getenv("TELEGRAM_SHARED_BOT_USERNAME") TELEGRAM_WEBHOOK_SECRET = os.getenv("TELEGRAM_WEBHOOK_SECRET") GATEWAY_BASE_URL = os.getenv("GATEWAY_BASE_URL", BACKEND_URL) + GATEWAY_BYO_LONGPOLL_ENABLED = ( + os.getenv("GATEWAY_BYO_LONGPOLL_ENABLED", "TRUE").upper() == "TRUE" + ) # Stripe checkout for pay-as-you-go page packs STRIPE_SECRET_KEY = os.getenv("STRIPE_SECRET_KEY") diff --git a/surfsense_backend/app/gateway/byo_long_poll.py b/surfsense_backend/app/gateway/byo_long_poll.py new file mode 100644 index 000000000..d02f19f95 --- /dev/null +++ b/surfsense_backend/app/gateway/byo_long_poll.py @@ -0,0 +1,94 @@ +"""FastAPI lifespan integration for self-hosted BYO Telegram long-polling.""" + +from __future__ import annotations + +import asyncio +import logging + +from sqlalchemy import select + +from app.config import config +from app.db import ExternalChatPlatform, ExternalChatAccount, async_session_maker +from app.gateway.accounts import account_token +from app.gateway.runner import _run_telegram_account + +logger = logging.getLogger(__name__) + +_tasks: set[asyncio.Task[None]] = set() +_shutdown_event: asyncio.Event | None = None + + +async def _sleep_or_shutdown(seconds: float) -> None: + if _shutdown_event is None: + await asyncio.sleep(seconds) + return + try: + await asyncio.wait_for(_shutdown_event.wait(), timeout=seconds) + except TimeoutError: + return + + +async def _byo_account_supervisor(account_id: int, token: str) -> None: + while _shutdown_event is None or not _shutdown_event.is_set(): + try: + await _run_telegram_account(account_id, token) + except asyncio.CancelledError: + raise + except Exception: + logger.exception( + "BYO Telegram long-poll failed account_id=%s; retrying in 30s", + account_id, + ) + await _sleep_or_shutdown(30) + + +async def start_byo_long_poll_supervisors() -> None: + """Start one BYO long-poll supervisor per active non-system Telegram account.""" + + global _shutdown_event + if not config.GATEWAY_BYO_LONGPOLL_ENABLED: + return + if _tasks: + return + + _shutdown_event = asyncio.Event() + async with async_session_maker() as session: + result = await session.execute( + select(ExternalChatAccount).where( + ExternalChatAccount.platform == ExternalChatPlatform.TELEGRAM, + ExternalChatAccount.is_system_account.is_(False), + ExternalChatAccount.suspended_at.is_(None), + ) + ) + accounts = list(result.scalars()) + + for account in accounts: + token = account_token(account) + if not token: + continue + task = asyncio.create_task( + _byo_account_supervisor(int(account.id), token), + name=f"gateway-byo-telegram-{account.id}", + ) + _tasks.add(task) + task.add_done_callback(_tasks.discard) + logger.info("Started BYO Telegram long-poll supervisor account_id=%s", account.id) + + +async def stop_byo_long_poll_supervisors() -> None: + """Cancel and await all BYO long-poll supervisors.""" + + global _shutdown_event + if _shutdown_event is not None: + _shutdown_event.set() + tasks = list(_tasks) + for task in tasks: + task.cancel() + if tasks: + try: + await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10) + except TimeoutError: + logger.warning("Timed out waiting for BYO Telegram long-poll supervisors to stop") + _tasks.clear() + _shutdown_event = None + diff --git a/surfsense_backend/app/gateway/inbox_worker.py b/surfsense_backend/app/gateway/inbox_worker.py new file mode 100644 index 000000000..e3ea7225c --- /dev/null +++ b/surfsense_backend/app/gateway/inbox_worker.py @@ -0,0 +1,55 @@ +"""FastAPI lifespan worker for gateway inbox processing.""" + +from __future__ import annotations + +import asyncio +import logging +from contextlib import suppress + +from app.gateway.inbox_processor import claim_next_inbound_event, process_inbound_event + +logger = logging.getLogger(__name__) + +_task: asyncio.Task[None] | None = None + + +async def _process_inbox_forever() -> None: + logger.info("Gateway inbox processor started in FastAPI process") + while True: + try: + inbox_id = await claim_next_inbound_event() + if inbox_id is None: + await asyncio.sleep(0.5) + continue + logger.info("Gateway processing inbox_id=%s", inbox_id) + await process_inbound_event(inbox_id) + logger.info("Gateway processed inbox_id=%s", inbox_id) + except asyncio.CancelledError: + raise + except RuntimeError as exc: + if str(exc) == "gateway_thread_busy": + logger.info("Gateway inbox_id busy; will retry from RECEIVED state") + else: + logger.exception("Gateway inbox processor failed one iteration") + await asyncio.sleep(1) + except Exception: + logger.exception("Gateway inbox processor failed one iteration") + await asyncio.sleep(1) + + +async def start_gateway_inbox_worker() -> None: + global _task + if _task is not None and not _task.done(): + return + _task = asyncio.create_task(_process_inbox_forever(), name="gateway-inbox-worker") + + +async def stop_gateway_inbox_worker() -> None: + global _task + if _task is None: + return + _task.cancel() + with suppress(TimeoutError, asyncio.CancelledError): + await asyncio.wait_for(_task, timeout=10) + _task = None + diff --git a/surfsense_backend/app/gateway/runner.py b/surfsense_backend/app/gateway/runner.py index 8ebd89253..83afc2353 100644 --- a/surfsense_backend/app/gateway/runner.py +++ b/surfsense_backend/app/gateway/runner.py @@ -1,18 +1,17 @@ -"""Long-lived messaging gateway runner.""" +"""Telegram BYO long-poll helper for FastAPI lifespan.""" from __future__ import annotations -import asyncio import hashlib import logging +import uuid -from sqlalchemy import select, text +from sqlalchemy import text -from app.db import GatewayPlatform, GatewayPlatformAccount, async_session_maker, engine -from app.gateway.accounts import account_token +from app.db import ExternalChatPlatform, ExternalChatAccount, async_session_maker, engine from app.gateway.inbox import persist_inbound_event, telegram_event_dedupe_key -from app.gateway.inbox_processor import claim_next_inbound_event, process_inbound_event from app.gateway.telegram.adapter import TelegramAdapter +from app.observability.metrics import record_gateway_byo_longpoll_running_delta logger = logging.getLogger(__name__) @@ -22,76 +21,45 @@ def _lock_key(token: str) -> int: return int.from_bytes(digest[:8], "big", signed=True) -class GatewayRunner: - async def run(self) -> None: - logger.info("Gateway runner started. Waiting for inbound events.") - tasks = [asyncio.create_task(self._process_inbox_forever())] - - async with async_session_maker() as session: - result = await session.execute( - select(GatewayPlatformAccount).where( - GatewayPlatformAccount.platform == GatewayPlatform.TELEGRAM, - GatewayPlatformAccount.is_system_account.is_(False), - GatewayPlatformAccount.suspended_at.is_(None), - ) - ) - accounts = list(result.scalars()) - - for account in accounts: - token = account_token(account) - if not token: - continue - logger.info("Starting Telegram long-poll loop for account_id=%s", account.id) - tasks.append(asyncio.create_task(self._run_telegram_account(account.id, token))) - - await asyncio.gather(*tasks) - - async def _process_inbox_forever(self) -> None: - logger.info("Gateway inbox processor started") - while True: - try: - inbox_id = await claim_next_inbound_event() - if inbox_id is None: - await asyncio.sleep(0.5) - continue - logger.info("Gateway processing inbox_id=%s", inbox_id) - await process_inbound_event(inbox_id) - logger.info("Gateway processed inbox_id=%s", inbox_id) - except Exception: - logger.exception("Gateway inbox processor failed one iteration") - await asyncio.sleep(1) - - async def _run_telegram_account(self, account_id: int, token: str) -> None: - async with engine.connect() as conn: - got_lock = await conn.scalar( - text("SELECT pg_try_advisory_lock(:key)"), - {"key": _lock_key(token)}, - ) - if not got_lock: - logger.warning("Another Telegram gateway runner is active; exiting") - return +async def _run_telegram_account(account_id: int, token: str) -> None: + async with engine.connect() as conn: + lock_key = _lock_key(token) + got_lock = await conn.scalar( + text("SELECT pg_try_advisory_lock(:key)"), + {"key": lock_key}, + ) + if not got_lock: + logger.warning("Another Telegram gateway runner is active; exiting") + return + record_gateway_byo_longpoll_running_delta(1, account_id=account_id) + try: adapter = TelegramAdapter(token) async with async_session_maker() as session: - account = await session.get(GatewayPlatformAccount, account_id) + account = await session.get(ExternalChatAccount, account_id) offset = None if account is not None: offset = int((account.cursor_state or {}).get("last_update_id", 0)) + 1 async for update in adapter.fetch_updates(offset=offset): + request_id = f"gateway_{uuid.uuid4().hex[:16]}" async with async_session_maker() as session: parsed = adapter.parse_inbound(update) inbox_id = await persist_inbound_event( session, account_id=account_id, - platform=GatewayPlatform.TELEGRAM, + platform=ExternalChatPlatform.TELEGRAM, event_dedupe_key=telegram_event_dedupe_key(update["update_id"]), external_event_id=str(update["update_id"]), external_message_id=parsed.external_message_id, event_kind=parsed.event_kind, raw_payload=update, + request_id=request_id, ) await session.commit() if inbox_id is not None: logger.debug("Persisted Telegram polling update inbox_id=%s", inbox_id) + finally: + record_gateway_byo_longpoll_running_delta(-1, account_id=account_id) + await conn.execute(text("SELECT pg_advisory_unlock(:key)"), {"key": lock_key}) diff --git a/surfsense_backend/gateway_runner.py b/surfsense_backend/gateway_runner.py deleted file mode 100644 index 27077ef48..000000000 --- a/surfsense_backend/gateway_runner.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Entrypoint for SERVICE_ROLE=gateway.""" - -from __future__ import annotations - -import asyncio -import logging - -from app.gateway.runner import GatewayRunner - -if __name__ == "__main__": - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s [%(name)s] %(message)s", - ) - asyncio.run(GatewayRunner().run()) - diff --git a/surfsense_backend/scripts/docker/entrypoint.sh b/surfsense_backend/scripts/docker/entrypoint.sh index 81db1ae84..0c1e66790 100644 --- a/surfsense_backend/scripts/docker/entrypoint.sh +++ b/surfsense_backend/scripts/docker/entrypoint.sh @@ -140,11 +140,11 @@ start_worker() { if [ -n "${CELERY_QUEUES}" ]; then QUEUE_ARGS="--queues=${CELERY_QUEUES}" else - # When no queues specified, consume from BOTH the default queue and - # the connectors queue. Without --queues, Celery only consumes from - # the default queue, leaving connector indexing tasks stuck. + # When no queues specified, consume from the default, connectors, and + # gateway maintenance queues. Without --queues, Celery only consumes + # from the default queue, leaving connector/gateway maintenance tasks stuck. DEFAULT_Q="${CELERY_TASK_DEFAULT_QUEUE:-surfsense}" - QUEUE_ARGS="--queues=${DEFAULT_Q},${DEFAULT_Q}.connectors" + QUEUE_ARGS="--queues=${DEFAULT_Q},${DEFAULT_Q}.connectors,${DEFAULT_Q}.gateway" fi echo "Starting Celery Worker (autoscale=${CELERY_MAX_WORKERS},${CELERY_MIN_WORKERS}, max-tasks-per-child=${CELERY_MAX_TASKS_PER_CHILD}, queues=${CELERY_QUEUES:-all})..." diff --git a/surfsense_backend/tests/unit/gateway/test_byo_long_poll_lifespan.py b/surfsense_backend/tests/unit/gateway/test_byo_long_poll_lifespan.py new file mode 100644 index 000000000..b8212ec9a --- /dev/null +++ b/surfsense_backend/tests/unit/gateway/test_byo_long_poll_lifespan.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import asyncio + +import pytest +import pytest_asyncio + +from app.gateway import byo_long_poll +from app.gateway import runner + + +class ScalarResult: + def __init__(self, rows): + self._rows = rows + + def scalars(self): + return self + + def __iter__(self): + return iter(self._rows) + + +class SessionContext: + def __init__(self, session): + self.session = session + + async def __aenter__(self): + return self.session + + async def __aexit__(self, exc_type, exc, tb): + return False + + +@pytest_asyncio.fixture(autouse=True) +async def cleanup_supervisors(): + yield + await byo_long_poll.stop_byo_long_poll_supervisors() + + +@pytest.mark.asyncio +async def test_start_byo_long_poll_noops_when_flag_off(monkeypatch): + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_BYO_LONGPOLL_ENABLED", False) + + await byo_long_poll.start_byo_long_poll_supervisors() + + assert byo_long_poll._tasks == set() + + +@pytest.mark.asyncio +async def test_start_byo_long_poll_noops_when_no_byo_accounts(mocker, monkeypatch): + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_BYO_LONGPOLL_ENABLED", True) + session = mocker.AsyncMock() + session.execute.return_value = ScalarResult([]) + monkeypatch.setattr( + byo_long_poll, + "async_session_maker", + lambda: SessionContext(session), + ) + + await byo_long_poll.start_byo_long_poll_supervisors() + + assert byo_long_poll._tasks == set() + + +@pytest.mark.asyncio +async def test_start_byo_long_poll_spawns_one_supervisor_per_account(mocker, monkeypatch): + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_BYO_LONGPOLL_ENABLED", True) + accounts = [mocker.Mock(id=1), mocker.Mock(id=2)] + session = mocker.AsyncMock() + session.execute.return_value = ScalarResult(accounts) + monkeypatch.setattr( + byo_long_poll, + "async_session_maker", + lambda: SessionContext(session), + ) + monkeypatch.setattr(byo_long_poll, "account_token", lambda account: f"token-{account.id}") + + async def forever(_account_id: int, _token: str) -> None: + await asyncio.Event().wait() + + monkeypatch.setattr(byo_long_poll, "_byo_account_supervisor", forever) + + await byo_long_poll.start_byo_long_poll_supervisors() + + assert len(byo_long_poll._tasks) == 2 + + +@pytest.mark.asyncio +async def test_supervisor_retries_after_run_returns(mocker, monkeypatch): + byo_long_poll._shutdown_event = asyncio.Event() + run = mocker.AsyncMock(side_effect=[None, None]) + monkeypatch.setattr(byo_long_poll, "_run_telegram_account", run) + sleep_count = 0 + + async def fake_sleep(_seconds: float) -> None: + nonlocal sleep_count + sleep_count += 1 + if sleep_count >= 2: + assert byo_long_poll._shutdown_event is not None + byo_long_poll._shutdown_event.set() + + monkeypatch.setattr(byo_long_poll, "_sleep_or_shutdown", fake_sleep) + + await byo_long_poll._byo_account_supervisor(7, "token") + + assert run.await_count == 2 + + +@pytest.mark.asyncio +async def test_shutdown_cancels_running_supervisors(mocker, monkeypatch): + monkeypatch.setattr(byo_long_poll.config, "GATEWAY_BYO_LONGPOLL_ENABLED", True) + session = mocker.AsyncMock() + session.execute.return_value = ScalarResult([mocker.Mock(id=1)]) + monkeypatch.setattr( + byo_long_poll, + "async_session_maker", + lambda: SessionContext(session), + ) + monkeypatch.setattr(byo_long_poll, "account_token", lambda _account: "token") + + async def forever(_account_id: int, _token: str) -> None: + await asyncio.Event().wait() + + monkeypatch.setattr(byo_long_poll, "_byo_account_supervisor", forever) + + await byo_long_poll.start_byo_long_poll_supervisors() + await byo_long_poll.stop_byo_long_poll_supervisors() + + assert byo_long_poll._tasks == set() + + +@pytest.mark.asyncio +async def test_run_telegram_account_persists_for_fastapi_inbox_worker(mocker, monkeypatch): + class ConnectionContext: + async def __aenter__(self): + conn = mocker.AsyncMock() + conn.scalar.return_value = True + return conn + + async def __aexit__(self, exc_type, exc, tb): + return False + + class EngineStub: + def connect(self): + return ConnectionContext() + + class AdapterStub: + def __init__(self, _token: str) -> None: + pass + + async def fetch_updates(self, *, offset: int | None): + yield {"update_id": 11, "message": {"message_id": 5}} + + def parse_inbound(self, update): + return mocker.Mock(external_message_id="5", event_kind="message") + + first_session = mocker.AsyncMock() + first_session.get.return_value = mocker.Mock(cursor_state={}) + second_session = mocker.AsyncMock() + contexts = iter([SessionContext(first_session), SessionContext(second_session)]) + monkeypatch.setattr(runner, "engine", EngineStub()) + monkeypatch.setattr(runner, "async_session_maker", lambda: next(contexts)) + monkeypatch.setattr(runner, "TelegramAdapter", AdapterStub) + persist = mocker.AsyncMock(return_value=42) + monkeypatch.setattr(runner, "persist_inbound_event", persist) + + await runner._run_telegram_account(123, "token") + + second_session.commit.assert_awaited_once() + persist.assert_awaited_once() + assert persist.await_args.kwargs["request_id"].startswith("gateway_") + diff --git a/surfsense_backend/tests/unit/gateway/test_inbox_worker.py b/surfsense_backend/tests/unit/gateway/test_inbox_worker.py new file mode 100644 index 000000000..8ecc4d86a --- /dev/null +++ b/surfsense_backend/tests/unit/gateway/test_inbox_worker.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import asyncio + +import pytest + +from app.gateway import inbox_worker + + +@pytest.mark.asyncio +async def test_inbox_worker_claims_and_processes_in_fastapi_process(mocker, monkeypatch): + claim = mocker.AsyncMock(return_value=7) + process = mocker.AsyncMock(side_effect=asyncio.CancelledError) + monkeypatch.setattr(inbox_worker, "claim_next_inbound_event", claim) + monkeypatch.setattr(inbox_worker, "process_inbound_event", process) + + with pytest.raises(asyncio.CancelledError): + await inbox_worker._process_inbox_forever() + + claim.assert_awaited_once() + process.assert_awaited_once_with(7) + + +@pytest.mark.asyncio +async def test_start_stop_gateway_inbox_worker(mocker, monkeypatch): + started = asyncio.Event() + stopped = asyncio.Event() + + async def run_forever(): + started.set() + try: + await asyncio.Event().wait() + finally: + stopped.set() + + monkeypatch.setattr(inbox_worker, "_process_inbox_forever", run_forever) + inbox_worker._task = None + + await inbox_worker.start_gateway_inbox_worker() + await asyncio.wait_for(started.wait(), timeout=1) + await inbox_worker.stop_gateway_inbox_worker() + + assert stopped.is_set() + assert inbox_worker._task is None +