mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-02 19:55:18 +02:00
refactor(gateway): run inbox and BYO polling from FastAPI lifespan
This commit is contained in:
parent
72024353f9
commit
08bf3cc023
9 changed files with 415 additions and 81 deletions
|
|
@ -37,6 +37,14 @@ from app.config import (
|
||||||
)
|
)
|
||||||
from app.db import User, create_db_and_tables, get_async_session
|
from app.db import User, create_db_and_tables, get_async_session
|
||||||
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
|
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
|
||||||
|
from app.gateway.byo_long_poll import (
|
||||||
|
start_byo_long_poll_supervisors,
|
||||||
|
stop_byo_long_poll_supervisors,
|
||||||
|
)
|
||||||
|
from app.gateway.inbox_worker import (
|
||||||
|
start_gateway_inbox_worker,
|
||||||
|
stop_gateway_inbox_worker,
|
||||||
|
)
|
||||||
from app.observability import metrics as ot_metrics
|
from app.observability import metrics as ot_metrics
|
||||||
from app.observability.bootstrap import init_otel, shutdown_otel
|
from app.observability.bootstrap import init_otel, shutdown_otel
|
||||||
from app.rate_limiter import get_real_client_ip, limiter
|
from app.rate_limiter import get_real_client_ip, limiter
|
||||||
|
|
@ -597,12 +605,17 @@ async def lifespan(app: FastAPI):
|
||||||
)
|
)
|
||||||
|
|
||||||
log_system_snapshot("startup_complete")
|
log_system_snapshot("startup_complete")
|
||||||
|
await start_gateway_inbox_worker()
|
||||||
|
await start_byo_long_poll_supervisors()
|
||||||
|
|
||||||
yield
|
try:
|
||||||
|
yield
|
||||||
_stop_openrouter_background_refresh()
|
finally:
|
||||||
await close_checkpointer()
|
await stop_byo_long_poll_supervisors()
|
||||||
shutdown_otel()
|
await stop_gateway_inbox_worker()
|
||||||
|
_stop_openrouter_background_refresh()
|
||||||
|
await close_checkpointer()
|
||||||
|
shutdown_otel()
|
||||||
|
|
||||||
|
|
||||||
def registration_allowed():
|
def registration_allowed():
|
||||||
|
|
|
||||||
|
|
@ -546,6 +546,9 @@ class Config:
|
||||||
TELEGRAM_SHARED_BOT_USERNAME = os.getenv("TELEGRAM_SHARED_BOT_USERNAME")
|
TELEGRAM_SHARED_BOT_USERNAME = os.getenv("TELEGRAM_SHARED_BOT_USERNAME")
|
||||||
TELEGRAM_WEBHOOK_SECRET = os.getenv("TELEGRAM_WEBHOOK_SECRET")
|
TELEGRAM_WEBHOOK_SECRET = os.getenv("TELEGRAM_WEBHOOK_SECRET")
|
||||||
GATEWAY_BASE_URL = os.getenv("GATEWAY_BASE_URL", BACKEND_URL)
|
GATEWAY_BASE_URL = os.getenv("GATEWAY_BASE_URL", BACKEND_URL)
|
||||||
|
GATEWAY_BYO_LONGPOLL_ENABLED = (
|
||||||
|
os.getenv("GATEWAY_BYO_LONGPOLL_ENABLED", "TRUE").upper() == "TRUE"
|
||||||
|
)
|
||||||
|
|
||||||
# Stripe checkout for pay-as-you-go page packs
|
# Stripe checkout for pay-as-you-go page packs
|
||||||
STRIPE_SECRET_KEY = os.getenv("STRIPE_SECRET_KEY")
|
STRIPE_SECRET_KEY = os.getenv("STRIPE_SECRET_KEY")
|
||||||
|
|
|
||||||
94
surfsense_backend/app/gateway/byo_long_poll.py
Normal file
94
surfsense_backend/app/gateway/byo_long_poll.py
Normal file
|
|
@ -0,0 +1,94 @@
|
||||||
|
"""FastAPI lifespan integration for self-hosted BYO Telegram long-polling."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.db import ExternalChatPlatform, ExternalChatAccount, async_session_maker
|
||||||
|
from app.gateway.accounts import account_token
|
||||||
|
from app.gateway.runner import _run_telegram_account
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_tasks: set[asyncio.Task[None]] = set()
|
||||||
|
_shutdown_event: asyncio.Event | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def _sleep_or_shutdown(seconds: float) -> None:
|
||||||
|
if _shutdown_event is None:
|
||||||
|
await asyncio.sleep(seconds)
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(_shutdown_event.wait(), timeout=seconds)
|
||||||
|
except TimeoutError:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
async def _byo_account_supervisor(account_id: int, token: str) -> None:
|
||||||
|
while _shutdown_event is None or not _shutdown_event.is_set():
|
||||||
|
try:
|
||||||
|
await _run_telegram_account(account_id, token)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"BYO Telegram long-poll failed account_id=%s; retrying in 30s",
|
||||||
|
account_id,
|
||||||
|
)
|
||||||
|
await _sleep_or_shutdown(30)
|
||||||
|
|
||||||
|
|
||||||
|
async def start_byo_long_poll_supervisors() -> None:
|
||||||
|
"""Start one BYO long-poll supervisor per active non-system Telegram account."""
|
||||||
|
|
||||||
|
global _shutdown_event
|
||||||
|
if not config.GATEWAY_BYO_LONGPOLL_ENABLED:
|
||||||
|
return
|
||||||
|
if _tasks:
|
||||||
|
return
|
||||||
|
|
||||||
|
_shutdown_event = asyncio.Event()
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(ExternalChatAccount).where(
|
||||||
|
ExternalChatAccount.platform == ExternalChatPlatform.TELEGRAM,
|
||||||
|
ExternalChatAccount.is_system_account.is_(False),
|
||||||
|
ExternalChatAccount.suspended_at.is_(None),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
accounts = list(result.scalars())
|
||||||
|
|
||||||
|
for account in accounts:
|
||||||
|
token = account_token(account)
|
||||||
|
if not token:
|
||||||
|
continue
|
||||||
|
task = asyncio.create_task(
|
||||||
|
_byo_account_supervisor(int(account.id), token),
|
||||||
|
name=f"gateway-byo-telegram-{account.id}",
|
||||||
|
)
|
||||||
|
_tasks.add(task)
|
||||||
|
task.add_done_callback(_tasks.discard)
|
||||||
|
logger.info("Started BYO Telegram long-poll supervisor account_id=%s", account.id)
|
||||||
|
|
||||||
|
|
||||||
|
async def stop_byo_long_poll_supervisors() -> None:
|
||||||
|
"""Cancel and await all BYO long-poll supervisors."""
|
||||||
|
|
||||||
|
global _shutdown_event
|
||||||
|
if _shutdown_event is not None:
|
||||||
|
_shutdown_event.set()
|
||||||
|
tasks = list(_tasks)
|
||||||
|
for task in tasks:
|
||||||
|
task.cancel()
|
||||||
|
if tasks:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.warning("Timed out waiting for BYO Telegram long-poll supervisors to stop")
|
||||||
|
_tasks.clear()
|
||||||
|
_shutdown_event = None
|
||||||
|
|
||||||
55
surfsense_backend/app/gateway/inbox_worker.py
Normal file
55
surfsense_backend/app/gateway/inbox_worker.py
Normal file
|
|
@ -0,0 +1,55 @@
|
||||||
|
"""FastAPI lifespan worker for gateway inbox processing."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from contextlib import suppress
|
||||||
|
|
||||||
|
from app.gateway.inbox_processor import claim_next_inbound_event, process_inbound_event
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_task: asyncio.Task[None] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def _process_inbox_forever() -> None:
|
||||||
|
logger.info("Gateway inbox processor started in FastAPI process")
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
inbox_id = await claim_next_inbound_event()
|
||||||
|
if inbox_id is None:
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
continue
|
||||||
|
logger.info("Gateway processing inbox_id=%s", inbox_id)
|
||||||
|
await process_inbound_event(inbox_id)
|
||||||
|
logger.info("Gateway processed inbox_id=%s", inbox_id)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except RuntimeError as exc:
|
||||||
|
if str(exc) == "gateway_thread_busy":
|
||||||
|
logger.info("Gateway inbox_id busy; will retry from RECEIVED state")
|
||||||
|
else:
|
||||||
|
logger.exception("Gateway inbox processor failed one iteration")
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Gateway inbox processor failed one iteration")
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
|
async def start_gateway_inbox_worker() -> None:
|
||||||
|
global _task
|
||||||
|
if _task is not None and not _task.done():
|
||||||
|
return
|
||||||
|
_task = asyncio.create_task(_process_inbox_forever(), name="gateway-inbox-worker")
|
||||||
|
|
||||||
|
|
||||||
|
async def stop_gateway_inbox_worker() -> None:
|
||||||
|
global _task
|
||||||
|
if _task is None:
|
||||||
|
return
|
||||||
|
_task.cancel()
|
||||||
|
with suppress(TimeoutError, asyncio.CancelledError):
|
||||||
|
await asyncio.wait_for(_task, timeout=10)
|
||||||
|
_task = None
|
||||||
|
|
||||||
|
|
@ -1,18 +1,17 @@
|
||||||
"""Long-lived messaging gateway runner."""
|
"""Telegram BYO long-poll helper for FastAPI lifespan."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
|
|
||||||
from sqlalchemy import select, text
|
from sqlalchemy import text
|
||||||
|
|
||||||
from app.db import GatewayPlatform, GatewayPlatformAccount, async_session_maker, engine
|
from app.db import ExternalChatPlatform, ExternalChatAccount, async_session_maker, engine
|
||||||
from app.gateway.accounts import account_token
|
|
||||||
from app.gateway.inbox import persist_inbound_event, telegram_event_dedupe_key
|
from app.gateway.inbox import persist_inbound_event, telegram_event_dedupe_key
|
||||||
from app.gateway.inbox_processor import claim_next_inbound_event, process_inbound_event
|
|
||||||
from app.gateway.telegram.adapter import TelegramAdapter
|
from app.gateway.telegram.adapter import TelegramAdapter
|
||||||
|
from app.observability.metrics import record_gateway_byo_longpoll_running_delta
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -22,76 +21,45 @@ def _lock_key(token: str) -> int:
|
||||||
return int.from_bytes(digest[:8], "big", signed=True)
|
return int.from_bytes(digest[:8], "big", signed=True)
|
||||||
|
|
||||||
|
|
||||||
class GatewayRunner:
|
async def _run_telegram_account(account_id: int, token: str) -> None:
|
||||||
async def run(self) -> None:
|
async with engine.connect() as conn:
|
||||||
logger.info("Gateway runner started. Waiting for inbound events.")
|
lock_key = _lock_key(token)
|
||||||
tasks = [asyncio.create_task(self._process_inbox_forever())]
|
got_lock = await conn.scalar(
|
||||||
|
text("SELECT pg_try_advisory_lock(:key)"),
|
||||||
async with async_session_maker() as session:
|
{"key": lock_key},
|
||||||
result = await session.execute(
|
)
|
||||||
select(GatewayPlatformAccount).where(
|
if not got_lock:
|
||||||
GatewayPlatformAccount.platform == GatewayPlatform.TELEGRAM,
|
logger.warning("Another Telegram gateway runner is active; exiting")
|
||||||
GatewayPlatformAccount.is_system_account.is_(False),
|
return
|
||||||
GatewayPlatformAccount.suspended_at.is_(None),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
accounts = list(result.scalars())
|
|
||||||
|
|
||||||
for account in accounts:
|
|
||||||
token = account_token(account)
|
|
||||||
if not token:
|
|
||||||
continue
|
|
||||||
logger.info("Starting Telegram long-poll loop for account_id=%s", account.id)
|
|
||||||
tasks.append(asyncio.create_task(self._run_telegram_account(account.id, token)))
|
|
||||||
|
|
||||||
await asyncio.gather(*tasks)
|
|
||||||
|
|
||||||
async def _process_inbox_forever(self) -> None:
|
|
||||||
logger.info("Gateway inbox processor started")
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
inbox_id = await claim_next_inbound_event()
|
|
||||||
if inbox_id is None:
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
continue
|
|
||||||
logger.info("Gateway processing inbox_id=%s", inbox_id)
|
|
||||||
await process_inbound_event(inbox_id)
|
|
||||||
logger.info("Gateway processed inbox_id=%s", inbox_id)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Gateway inbox processor failed one iteration")
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
|
||||||
async def _run_telegram_account(self, account_id: int, token: str) -> None:
|
|
||||||
async with engine.connect() as conn:
|
|
||||||
got_lock = await conn.scalar(
|
|
||||||
text("SELECT pg_try_advisory_lock(:key)"),
|
|
||||||
{"key": _lock_key(token)},
|
|
||||||
)
|
|
||||||
if not got_lock:
|
|
||||||
logger.warning("Another Telegram gateway runner is active; exiting")
|
|
||||||
return
|
|
||||||
|
|
||||||
|
record_gateway_byo_longpoll_running_delta(1, account_id=account_id)
|
||||||
|
try:
|
||||||
adapter = TelegramAdapter(token)
|
adapter = TelegramAdapter(token)
|
||||||
async with async_session_maker() as session:
|
async with async_session_maker() as session:
|
||||||
account = await session.get(GatewayPlatformAccount, account_id)
|
account = await session.get(ExternalChatAccount, account_id)
|
||||||
offset = None
|
offset = None
|
||||||
if account is not None:
|
if account is not None:
|
||||||
offset = int((account.cursor_state or {}).get("last_update_id", 0)) + 1
|
offset = int((account.cursor_state or {}).get("last_update_id", 0)) + 1
|
||||||
|
|
||||||
async for update in adapter.fetch_updates(offset=offset):
|
async for update in adapter.fetch_updates(offset=offset):
|
||||||
|
request_id = f"gateway_{uuid.uuid4().hex[:16]}"
|
||||||
async with async_session_maker() as session:
|
async with async_session_maker() as session:
|
||||||
parsed = adapter.parse_inbound(update)
|
parsed = adapter.parse_inbound(update)
|
||||||
inbox_id = await persist_inbound_event(
|
inbox_id = await persist_inbound_event(
|
||||||
session,
|
session,
|
||||||
account_id=account_id,
|
account_id=account_id,
|
||||||
platform=GatewayPlatform.TELEGRAM,
|
platform=ExternalChatPlatform.TELEGRAM,
|
||||||
event_dedupe_key=telegram_event_dedupe_key(update["update_id"]),
|
event_dedupe_key=telegram_event_dedupe_key(update["update_id"]),
|
||||||
external_event_id=str(update["update_id"]),
|
external_event_id=str(update["update_id"]),
|
||||||
external_message_id=parsed.external_message_id,
|
external_message_id=parsed.external_message_id,
|
||||||
event_kind=parsed.event_kind,
|
event_kind=parsed.event_kind,
|
||||||
raw_payload=update,
|
raw_payload=update,
|
||||||
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
if inbox_id is not None:
|
if inbox_id is not None:
|
||||||
logger.debug("Persisted Telegram polling update inbox_id=%s", inbox_id)
|
logger.debug("Persisted Telegram polling update inbox_id=%s", inbox_id)
|
||||||
|
finally:
|
||||||
|
record_gateway_byo_longpoll_running_delta(-1, account_id=account_id)
|
||||||
|
await conn.execute(text("SELECT pg_advisory_unlock(:key)"), {"key": lock_key})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,16 +0,0 @@
|
||||||
"""Entrypoint for SERVICE_ROLE=gateway."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from app.gateway.runner import GatewayRunner
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
|
|
||||||
)
|
|
||||||
asyncio.run(GatewayRunner().run())
|
|
||||||
|
|
||||||
|
|
@ -140,11 +140,11 @@ start_worker() {
|
||||||
if [ -n "${CELERY_QUEUES}" ]; then
|
if [ -n "${CELERY_QUEUES}" ]; then
|
||||||
QUEUE_ARGS="--queues=${CELERY_QUEUES}"
|
QUEUE_ARGS="--queues=${CELERY_QUEUES}"
|
||||||
else
|
else
|
||||||
# When no queues specified, consume from BOTH the default queue and
|
# When no queues specified, consume from the default, connectors, and
|
||||||
# the connectors queue. Without --queues, Celery only consumes from
|
# gateway maintenance queues. Without --queues, Celery only consumes
|
||||||
# the default queue, leaving connector indexing tasks stuck.
|
# from the default queue, leaving connector/gateway maintenance tasks stuck.
|
||||||
DEFAULT_Q="${CELERY_TASK_DEFAULT_QUEUE:-surfsense}"
|
DEFAULT_Q="${CELERY_TASK_DEFAULT_QUEUE:-surfsense}"
|
||||||
QUEUE_ARGS="--queues=${DEFAULT_Q},${DEFAULT_Q}.connectors"
|
QUEUE_ARGS="--queues=${DEFAULT_Q},${DEFAULT_Q}.connectors,${DEFAULT_Q}.gateway"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo "Starting Celery Worker (autoscale=${CELERY_MAX_WORKERS},${CELERY_MIN_WORKERS}, max-tasks-per-child=${CELERY_MAX_TASKS_PER_CHILD}, queues=${CELERY_QUEUES:-all})..."
|
echo "Starting Celery Worker (autoscale=${CELERY_MAX_WORKERS},${CELERY_MIN_WORKERS}, max-tasks-per-child=${CELERY_MAX_TASKS_PER_CHILD}, queues=${CELERY_QUEUES:-all})..."
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,172 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.gateway import byo_long_poll
|
||||||
|
from app.gateway import runner
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarResult:
|
||||||
|
def __init__(self, rows):
|
||||||
|
self._rows = rows
|
||||||
|
|
||||||
|
def scalars(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self._rows)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionContext:
|
||||||
|
def __init__(self, session):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self.session
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(autouse=True)
|
||||||
|
async def cleanup_supervisors():
|
||||||
|
yield
|
||||||
|
await byo_long_poll.stop_byo_long_poll_supervisors()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_byo_long_poll_noops_when_flag_off(monkeypatch):
|
||||||
|
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_BYO_LONGPOLL_ENABLED", False)
|
||||||
|
|
||||||
|
await byo_long_poll.start_byo_long_poll_supervisors()
|
||||||
|
|
||||||
|
assert byo_long_poll._tasks == set()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_byo_long_poll_noops_when_no_byo_accounts(mocker, monkeypatch):
|
||||||
|
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_BYO_LONGPOLL_ENABLED", True)
|
||||||
|
session = mocker.AsyncMock()
|
||||||
|
session.execute.return_value = ScalarResult([])
|
||||||
|
monkeypatch.setattr(
|
||||||
|
byo_long_poll,
|
||||||
|
"async_session_maker",
|
||||||
|
lambda: SessionContext(session),
|
||||||
|
)
|
||||||
|
|
||||||
|
await byo_long_poll.start_byo_long_poll_supervisors()
|
||||||
|
|
||||||
|
assert byo_long_poll._tasks == set()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_byo_long_poll_spawns_one_supervisor_per_account(mocker, monkeypatch):
|
||||||
|
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_BYO_LONGPOLL_ENABLED", True)
|
||||||
|
accounts = [mocker.Mock(id=1), mocker.Mock(id=2)]
|
||||||
|
session = mocker.AsyncMock()
|
||||||
|
session.execute.return_value = ScalarResult(accounts)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
byo_long_poll,
|
||||||
|
"async_session_maker",
|
||||||
|
lambda: SessionContext(session),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(byo_long_poll, "account_token", lambda account: f"token-{account.id}")
|
||||||
|
|
||||||
|
async def forever(_account_id: int, _token: str) -> None:
|
||||||
|
await asyncio.Event().wait()
|
||||||
|
|
||||||
|
monkeypatch.setattr(byo_long_poll, "_byo_account_supervisor", forever)
|
||||||
|
|
||||||
|
await byo_long_poll.start_byo_long_poll_supervisors()
|
||||||
|
|
||||||
|
assert len(byo_long_poll._tasks) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_supervisor_retries_after_run_returns(mocker, monkeypatch):
|
||||||
|
byo_long_poll._shutdown_event = asyncio.Event()
|
||||||
|
run = mocker.AsyncMock(side_effect=[None, None])
|
||||||
|
monkeypatch.setattr(byo_long_poll, "_run_telegram_account", run)
|
||||||
|
sleep_count = 0
|
||||||
|
|
||||||
|
async def fake_sleep(_seconds: float) -> None:
|
||||||
|
nonlocal sleep_count
|
||||||
|
sleep_count += 1
|
||||||
|
if sleep_count >= 2:
|
||||||
|
assert byo_long_poll._shutdown_event is not None
|
||||||
|
byo_long_poll._shutdown_event.set()
|
||||||
|
|
||||||
|
monkeypatch.setattr(byo_long_poll, "_sleep_or_shutdown", fake_sleep)
|
||||||
|
|
||||||
|
await byo_long_poll._byo_account_supervisor(7, "token")
|
||||||
|
|
||||||
|
assert run.await_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_shutdown_cancels_running_supervisors(mocker, monkeypatch):
|
||||||
|
monkeypatch.setattr(byo_long_poll.config, "GATEWAY_BYO_LONGPOLL_ENABLED", True)
|
||||||
|
session = mocker.AsyncMock()
|
||||||
|
session.execute.return_value = ScalarResult([mocker.Mock(id=1)])
|
||||||
|
monkeypatch.setattr(
|
||||||
|
byo_long_poll,
|
||||||
|
"async_session_maker",
|
||||||
|
lambda: SessionContext(session),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(byo_long_poll, "account_token", lambda _account: "token")
|
||||||
|
|
||||||
|
async def forever(_account_id: int, _token: str) -> None:
|
||||||
|
await asyncio.Event().wait()
|
||||||
|
|
||||||
|
monkeypatch.setattr(byo_long_poll, "_byo_account_supervisor", forever)
|
||||||
|
|
||||||
|
await byo_long_poll.start_byo_long_poll_supervisors()
|
||||||
|
await byo_long_poll.stop_byo_long_poll_supervisors()
|
||||||
|
|
||||||
|
assert byo_long_poll._tasks == set()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_telegram_account_persists_for_fastapi_inbox_worker(mocker, monkeypatch):
|
||||||
|
class ConnectionContext:
|
||||||
|
async def __aenter__(self):
|
||||||
|
conn = mocker.AsyncMock()
|
||||||
|
conn.scalar.return_value = True
|
||||||
|
return conn
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
class EngineStub:
|
||||||
|
def connect(self):
|
||||||
|
return ConnectionContext()
|
||||||
|
|
||||||
|
class AdapterStub:
|
||||||
|
def __init__(self, _token: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def fetch_updates(self, *, offset: int | None):
|
||||||
|
yield {"update_id": 11, "message": {"message_id": 5}}
|
||||||
|
|
||||||
|
def parse_inbound(self, update):
|
||||||
|
return mocker.Mock(external_message_id="5", event_kind="message")
|
||||||
|
|
||||||
|
first_session = mocker.AsyncMock()
|
||||||
|
first_session.get.return_value = mocker.Mock(cursor_state={})
|
||||||
|
second_session = mocker.AsyncMock()
|
||||||
|
contexts = iter([SessionContext(first_session), SessionContext(second_session)])
|
||||||
|
monkeypatch.setattr(runner, "engine", EngineStub())
|
||||||
|
monkeypatch.setattr(runner, "async_session_maker", lambda: next(contexts))
|
||||||
|
monkeypatch.setattr(runner, "TelegramAdapter", AdapterStub)
|
||||||
|
persist = mocker.AsyncMock(return_value=42)
|
||||||
|
monkeypatch.setattr(runner, "persist_inbound_event", persist)
|
||||||
|
|
||||||
|
await runner._run_telegram_account(123, "token")
|
||||||
|
|
||||||
|
second_session.commit.assert_awaited_once()
|
||||||
|
persist.assert_awaited_once()
|
||||||
|
assert persist.await_args.kwargs["request_id"].startswith("gateway_")
|
||||||
|
|
||||||
45
surfsense_backend/tests/unit/gateway/test_inbox_worker.py
Normal file
45
surfsense_backend/tests/unit/gateway/test_inbox_worker.py
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.gateway import inbox_worker
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inbox_worker_claims_and_processes_in_fastapi_process(mocker, monkeypatch):
|
||||||
|
claim = mocker.AsyncMock(return_value=7)
|
||||||
|
process = mocker.AsyncMock(side_effect=asyncio.CancelledError)
|
||||||
|
monkeypatch.setattr(inbox_worker, "claim_next_inbound_event", claim)
|
||||||
|
monkeypatch.setattr(inbox_worker, "process_inbound_event", process)
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await inbox_worker._process_inbox_forever()
|
||||||
|
|
||||||
|
claim.assert_awaited_once()
|
||||||
|
process.assert_awaited_once_with(7)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_stop_gateway_inbox_worker(mocker, monkeypatch):
|
||||||
|
started = asyncio.Event()
|
||||||
|
stopped = asyncio.Event()
|
||||||
|
|
||||||
|
async def run_forever():
|
||||||
|
started.set()
|
||||||
|
try:
|
||||||
|
await asyncio.Event().wait()
|
||||||
|
finally:
|
||||||
|
stopped.set()
|
||||||
|
|
||||||
|
monkeypatch.setattr(inbox_worker, "_process_inbox_forever", run_forever)
|
||||||
|
inbox_worker._task = None
|
||||||
|
|
||||||
|
await inbox_worker.start_gateway_inbox_worker()
|
||||||
|
await asyncio.wait_for(started.wait(), timeout=1)
|
||||||
|
await inbox_worker.stop_gateway_inbox_worker()
|
||||||
|
|
||||||
|
assert stopped.is_set()
|
||||||
|
assert inbox_worker._task is None
|
||||||
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue