diff --git a/surfsense_backend/app/gateway/agent_invoke.py b/surfsense_backend/app/gateway/agent_invoke.py
new file mode 100644
index 000000000..b0cccddaa
--- /dev/null
+++ b/surfsense_backend/app/gateway/agent_invoke.py
@@ -0,0 +1,108 @@
+"""Invoke SurfSense chat agent for gateway channels."""
+
+from __future__ import annotations
+
+import json
+import logging
+import time
+from collections.abc import AsyncIterator
+
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.db import GatewayConversationBinding
+from app.gateway.auth_invariant import assert_authorization_invariant
+from app.gateway.base.translator import GatewayStreamEvent
+from app.gateway.bindings import get_or_create_thread_for_binding
+from app.gateway.hitl_filter import DEFAULT_HITL_TOOL_NAMES
+from app.gateway.telegram.translator import TelegramStreamTranslator
+from app.gateway.thread_lock import acquire_thread_lock, release_thread_lock
+from app.observability.metrics import record_gateway_turn_latency
+from app.tasks.chat.stream_new_chat import stream_new_chat
+
+logger = logging.getLogger(__name__)
+
+# SSE event types forwarded to the translator with their full JSON payload.
+_PASSTHROUGH_EVENT_TYPES = frozenset({"text-end", "finish", "data-interrupt-request"})
+
+
+async def _events_from_sse(chunks: AsyncIterator[str]) -> AsyncIterator[GatewayStreamEvent]:
+    """Translate raw SSE text chunks into ``GatewayStreamEvent`` objects.
+
+    Only ``data:`` lines are considered. ``[DONE]`` becomes a ``done`` event,
+    ``text-delta`` is reduced to its ``delta`` field, and ``text-end``,
+    ``finish`` and ``data-interrupt-request`` are forwarded with their whole
+    payload. Payloads that are not JSON objects and unrecognised event types
+    are skipped.
+
+    NOTE(review): each chunk is parsed on its own, so a ``data:`` line split
+    across two chunks would be dropped -- confirm the upstream never does so.
+    """
+    async for chunk in chunks:
+        for raw_line in chunk.splitlines():
+            line = raw_line.strip()
+            if not line.startswith("data:"):
+                continue
+            payload = line.removeprefix("data:").strip()
+            if payload == "[DONE]":
+                yield GatewayStreamEvent(type="done")
+                continue
+            try:
+                data = json.loads(payload)
+            except json.JSONDecodeError:
+                continue
+            if not isinstance(data, dict):
+                # Valid JSON that is not an object (list, string, null) has
+                # no ``type`` field, and ``data.get`` would raise on it.
+                continue
+            event_type = str(data.get("type") or "")
+            if event_type == "text-delta":
+                yield GatewayStreamEvent(
+                    type="text-delta", data={"delta": data.get("delta", "")}
+                )
+            elif event_type in _PASSTHROUGH_EVENT_TYPES:
+                yield GatewayStreamEvent(type=event_type, data=data)
+
+
+async def call_agent_for_gateway(
+    *,
+    session: AsyncSession,
+    binding: GatewayConversationBinding,
+    user_text: str,
+    translator: TelegramStreamTranslator,
+    request_id: str | None = None,
+) -> None:
+    """Run one agent turn for ``binding`` and stream it through ``translator``.
+
+    Raises:
+        RuntimeError: with message ``"gateway_thread_busy"`` when the lock
+            for the conversation thread could not be acquired.
+    """
+    user = await assert_authorization_invariant(session, binding)
+    thread = await get_or_create_thread_for_binding(session, binding)
+    await session.commit()
+
+    # NOTE(review): ``acquire_thread_lock`` talks to Redis synchronously, so
+    # this call blocks the event loop for one round trip.
+    if not acquire_thread_lock(thread.id):
+        raise RuntimeError("gateway_thread_busy")
+
+    started_at = time.monotonic()
+    try:
+        stream = stream_new_chat(
+            user_query=user_text,
+            search_space_id=binding.search_space_id,
+            chat_id=thread.id,
+            user_id=str(user.id),
+            needs_history_bootstrap=thread.needs_history_bootstrap,
+            thread_visibility=thread.visibility,
+            current_user_display_name=user.display_name or "A team member",
+            disabled_tools=sorted(DEFAULT_HITL_TOOL_NAMES),
+            request_id=request_id or "gateway",
+        )
+        await translator.translate(_events_from_sse(stream))
+        # Report the measured turn duration rather than a constant 0.
+        # NOTE(review): the metric is assumed to take seconds -- confirm.
+        record_gateway_turn_latency(time.monotonic() - started_at, platform="telegram")
+    finally:
+        release_thread_lock(thread.id)
+
diff --git a/surfsense_backend/app/gateway/inbox_processor.py b/surfsense_backend/app/gateway/inbox_processor.py
new file mode 100644
index 000000000..3e3f962b7
--- /dev/null
+++ b/surfsense_backend/app/gateway/inbox_processor.py
@@ -0,0 +1,309 @@
+"""Long-lived gateway inbox processing.
+
+This module owns the agent-turn execution path for messaging gateways. It is
+intentionally independent of Celery so LangGraph, async Postgres, Redis, and
+Telegram clients all run on one stable event loop in ``GatewayRunner``.
+"""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Callable
+from datetime import UTC, datetime
+
+from sqlalchemy import select, update
+from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
+
+from app.config import config
+from app.db import (
+    GatewayBindingState,
+    GatewayConversationBinding,
+    GatewayEventStatus,
+    GatewayInboundEvent,
+    GatewayPeerKind,
+    GatewayPlatformAccount,
+    NewChatThread,
+    async_session_maker,
+)
+from app.gateway.accounts import account_token
+from app.gateway.agent_invoke import call_agent_for_gateway
+from app.gateway.bindings import get_or_create_thread_for_binding
+from app.gateway.telegram.adapter import TelegramAdapter
+from app.gateway.telegram.commands import (
+    command_name,
+    handle_help_command,
+    handle_start_command,
+    send_unbound_onboarding,
+)
+from app.gateway.telegram.translator import TelegramStreamTranslator
+from app.observability.metrics import record_gateway_inbox_processed
+
+logger = logging.getLogger(__name__)
+
+# Zero-argument factory whose result is used as ``async with ... as session``.
+SessionMaker = async_sessionmaker[AsyncSession] | Callable[[], AsyncSession]
+
+
+def _dashboard_url() -> str:
+    """Return the configured frontend URL, or ``/dashboard`` when unset."""
+    return config.NEXT_FRONTEND_URL or "/dashboard"
+
+
+async def claim_next_inbound_event(
+    session_maker: SessionMaker = async_session_maker,
+) -> int | None:
+    """Claim the oldest ``RECEIVED`` inbox event and return its id.
+
+    The row is selected ``FOR UPDATE SKIP LOCKED``, moved to ``PROCESSING``
+    and its ``attempt_count`` is incremented. Returns ``None`` when no
+    unlocked ``RECEIVED`` row exists.
+    """
+
+    async with session_maker() as session:
+        result = await session.execute(
+            select(GatewayInboundEvent)
+            .where(GatewayInboundEvent.status == GatewayEventStatus.RECEIVED)
+            .order_by(GatewayInboundEvent.received_at.asc())
+            .with_for_update(skip_locked=True)
+            .limit(1)
+        )
+        event = result.scalars().first()
+        if event is None:
+            return None
+        # Read the id before committing so it never depends on the session's
+        # expire-on-commit setting.
+        event_id = int(event.id)
+        event.status = GatewayEventStatus.PROCESSING
+        event.attempt_count += 1
+        await session.commit()
+        return event_id
+
+
+async def process_inbound_event(
+    inbox_id: int,
+    session_maker: SessionMaker = async_session_maker,
+) -> None:
+    """Process one gateway inbox row and mark its terminal status.
+
+    Rows already ``PROCESSED`` or ``IGNORED`` (and rows that are missing or
+    currently locked by another transaction) are left untouched. A
+    ``gateway_thread_busy`` error puts the row back to ``RECEIVED``; any
+    other exception marks it ``FAILED`` and is re-raised.
+    """
+
+    async with session_maker() as session:
+        result = await session.execute(
+            select(GatewayInboundEvent)
+            .where(GatewayInboundEvent.id == inbox_id)
+            .with_for_update(skip_locked=True)
+        )
+        event = result.scalars().first()
+        if event is None or event.status in {
+            GatewayEventStatus.PROCESSED,
+            GatewayEventStatus.IGNORED,
+        }:
+            return
+        if event.status == GatewayEventStatus.RECEIVED:
+            event.status = GatewayEventStatus.PROCESSING
+            event.attempt_count += 1
+            await session.commit()
+
+    try:
+        await _dispatch_inbound_event(inbox_id, session_maker)
+    except RuntimeError as exc:
+        if str(exc) == "gateway_thread_busy":
+            # Another turn holds the thread lock: requeue instead of failing.
+            async with session_maker() as session:
+                await session.execute(
+                    update(GatewayInboundEvent)
+                    .where(GatewayInboundEvent.id == inbox_id)
+                    .values(
+                        status=GatewayEventStatus.RECEIVED,
+                        last_error="gateway_thread_busy",
+                    )
+                )
+                await session.commit()
+            return
+        await _mark_failed(inbox_id, str(exc), session_maker)
+        raise
+    except Exception as exc:
+        await _mark_failed(inbox_id, str(exc), session_maker)
+        raise
+
+    async with session_maker() as session:
+        event = await session.get(GatewayInboundEvent, inbox_id)
+        # The dispatcher may already have set IGNORED/FAILED/PROCESSED itself;
+        # only rows still in PROCESSING are finalised here.
+        if event is not None and event.status == GatewayEventStatus.PROCESSING:
+            # Capture the label before commit so the metric call never needs
+            # to reload an expired attribute.
+            platform = event.platform.value
+            event.status = GatewayEventStatus.PROCESSED
+            event.processed_at = datetime.now(UTC)
+            await session.commit()
+            record_gateway_inbox_processed(platform=platform, status="processed")
+
+
+async def _mark_failed(
+    inbox_id: int,
+    error: str,
+    session_maker: SessionMaker,
+) -> None:
+    """Set the inbox row to ``FAILED`` and store ``error`` as ``last_error``."""
+    async with session_maker() as session:
+        await session.execute(
+            update(GatewayInboundEvent)
+            .where(GatewayInboundEvent.id == inbox_id)
+            .values(status=GatewayEventStatus.FAILED, last_error=error)
+        )
+        await session.commit()
+
+
+async def _dispatch_inbound_event(
+    inbox_id: int,
+    session_maker: SessionMaker,
+) -> None:
+    """Route one inbox event: reject, answer a command, or run an agent turn.
+
+    Early exits record their reason in ``event.last_error`` and set the row
+    to ``IGNORED`` (or ``FAILED`` for a missing token). Exceptions from the
+    adapter or the agent propagate to the caller.
+    """
+    async with session_maker() as session:
+        event = await session.get(GatewayInboundEvent, inbox_id)
+        if event is None:
+            return
+        account = await session.get(GatewayPlatformAccount, event.account_id)
+        if account is None:
+            event.status = GatewayEventStatus.IGNORED
+            event.last_error = "account_missing"
+            await session.commit()
+            return
+
+        token = account_token(account)
+        if not token:
+            event.status = GatewayEventStatus.FAILED
+            event.last_error = "missing_telegram_token"
+            await session.commit()
+            return
+
+        adapter = TelegramAdapter(token)
+        parsed = adapter.parse_inbound(event.raw_payload or {})
+        if parsed.external_peer_id is None:
+            event.status = GatewayEventStatus.IGNORED
+            event.last_error = "missing_external_peer_id"
+            await session.commit()
+            return
+
+        _update_account_cursor(account, parsed.metadata.get("update_id"))
+
+        # Only direct chats are served. Checked before the binding lookup so
+        # rejected group messages do not cost a query.
+        if parsed.external_peer_kind != GatewayPeerKind.DIRECT.value:
+            await adapter.leave_chat(external_peer_id=parsed.external_peer_id)
+            event.status = GatewayEventStatus.IGNORED
+            event.last_error = "group_rejected"
+            await session.commit()
+            return
+
+        result = await session.execute(
+            select(GatewayConversationBinding).where(
+                GatewayConversationBinding.account_id == account.id,
+                GatewayConversationBinding.external_peer_id == parsed.external_peer_id,
+                GatewayConversationBinding.state.in_(
+                    [GatewayBindingState.BOUND, GatewayBindingState.SUSPENDED]
+                ),
+            )
+        )
+        binding = result.scalars().first()
+
+        cmd = command_name(parsed.text)
+        if cmd == "/start":
+            handled = await handle_start_command(
+                session=session, adapter=adapter, event=parsed
+            )
+            await session.commit()
+            if handled:
+                return
+
+        if binding is None:
+            await send_unbound_onboarding(
+                adapter=adapter,
+                event=parsed,
+                dashboard_url=_dashboard_url(),
+            )
+            event.status = GatewayEventStatus.IGNORED
+            event.last_error = "unbound_chat"
+            await session.commit()
+            return
+
+        event.binding_id = binding.id
+
+        if cmd == "/help":
+            await handle_help_command(adapter=adapter, event=parsed)
+            event.status = GatewayEventStatus.PROCESSED
+            await session.commit()
+            return
+        if cmd == "/new":
+            # Presumably makes ``get_or_create_thread_for_binding`` start a
+            # fresh thread for the next message -- verify against bindings.
+            binding.active_thread_id = None
+            await adapter.send_message(
+                external_peer_id=parsed.external_peer_id,
+                text="Started a new SurfSense conversation.",
+            )
+            event.status = GatewayEventStatus.PROCESSED
+            await session.commit()
+            return
+
+        if not parsed.text:
+            event.status = GatewayEventStatus.IGNORED
+            event.last_error = "empty_message"
+            await session.commit()
+            return
+
+        thread = await get_or_create_thread_for_binding(session, binding)
+        # Keep the id in a local: the ORM object may be expired by the commits
+        # that happen during the agent turn.
+        thread_id = thread.id
+        await session.commit()
+
+        translator = TelegramStreamTranslator(
+            adapter=adapter,
+            external_peer_id=parsed.external_peer_id,
+        )
+        await call_agent_for_gateway(
+            session=session,
+            binding=binding,
+            user_text=parsed.text,
+            translator=translator,
+            request_id=f"gateway:{inbox_id}",
+        )
+
+        thread = await session.get(NewChatThread, thread_id)
+        if thread is not None:
+            thread.source = "telegram"
+            await session.commit()
+
+
+def _update_account_cursor(account: GatewayPlatformAccount, update_id: object) -> None:
+    """Raise ``account.cursor_state["last_update_id"]`` to ``update_id`` if larger.
+
+    A ``None`` or non-integer ``update_id`` leaves the cursor unchanged; an
+    unreadable stored value is treated as 0.
+    """
+    if update_id is None:
+        return
+    try:
+        incoming = int(update_id)
+    except (TypeError, ValueError):
+        logger.warning("Ignoring non-integer gateway update_id: %r", update_id)
+        return
+    state = account.cursor_state or {}
+    try:
+        previous = int(state.get("last_update_id", 0))
+    except (TypeError, ValueError):
+        previous = 0
+    # Assign a fresh dict instead of mutating in place (presumably so the ORM
+    # registers the JSON column as changed -- verify against the model).
+    account.cursor_state = {**state, "last_update_id": max(previous, incoming)}
diff --git a/surfsense_backend/app/gateway/ratelimit.py b/surfsense_backend/app/gateway/ratelimit.py
new file mode 100644
index 000000000..fbcbd16b8
--- /dev/null
+++ b/surfsense_backend/app/gateway/ratelimit.py
@@ -0,0 +1,168 @@
+"""Redis token-bucket rate limiter for gateway outbound traffic."""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import math
+import time
+from dataclasses import dataclass
+
+import redis.asyncio as aioredis
+
+from app.config import config
+from app.observability.metrics import record_gateway_redis_fallback
+
+logger = logging.getLogger(__name__)
+
+# KEYS[1]: bucket hash. ARGV: capacity, refill rate (tokens/s), now (s), consume.
+# Returns 0 when the tokens were taken, otherwise the wait in milliseconds.
+_TOKEN_BUCKET_LUA = """
+local capacity = tonumber(ARGV[1])
+local refill_rate = tonumber(ARGV[2])
+local now = tonumber(ARGV[3])
+local consume = tonumber(ARGV[4])
+
+local bucket = redis.call('HMGET', KEYS[1], 'tokens', 'last_refill')
+local tokens = tonumber(bucket[1]) or capacity
+local last_refill = tonumber(bucket[2]) or now
+
+local elapsed = math.max(0, now - last_refill)
+tokens = math.min(capacity, tokens + (elapsed * refill_rate))
+
+if tokens >= consume then
+    tokens = tokens - consume
+    redis.call('HMSET', KEYS[1], 'tokens', tokens, 'last_refill', now)
+    redis.call('EXPIRE', KEYS[1], 3600)
+    return 0
+else
+    redis.call('HMSET', KEYS[1], 'tokens', tokens, 'last_refill', now)
+    redis.call('EXPIRE', KEYS[1], 3600)
+    if refill_rate <= 0 then
+        return 1000
+    end
+    local needed = consume - tokens
+    return math.ceil((needed / refill_rate) * 1000)
+end
+"""
+
+_redis_client: aioredis.Redis | None = None
+
+
+@dataclass
+class _MemoryBucket:
+    """Token-bucket state for one scope in the in-process fallback."""
+
+    tokens: float
+    last_refill: float
+
+
+_memory_buckets: dict[str, _MemoryBucket] = {}
+_memory_lock = asyncio.Lock()
+
+
+def _redis() -> aioredis.Redis:
+    """Return the lazily created, module-wide async Redis client."""
+    global _redis_client
+    if _redis_client is None:
+        _redis_client = aioredis.from_url(config.REDIS_APP_URL, decode_responses=True)
+    return _redis_client
+
+
+async def _memory_fallback_acquire(
+    scope: str,
+    capacity: int,
+    refill_per_sec: float,
+    consume: float,
+) -> int:
+    """Per-process token bucket used when Redis is unreachable.
+
+    Returns 0 when ``consume`` tokens were taken, otherwise the number of
+    milliseconds until enough tokens will have been refilled (at least 1, so
+    a denial can never be mistaken for the ``0`` success value).
+    """
+    now = time.time()
+    async with _memory_lock:
+        bucket = _memory_buckets.get(scope)
+        if bucket is None:
+            bucket = _MemoryBucket(tokens=float(capacity), last_refill=now)
+            _memory_buckets[scope] = bucket
+
+        elapsed = max(0.0, now - bucket.last_refill)
+        bucket.tokens = min(float(capacity), bucket.tokens + elapsed * refill_per_sec)
+        bucket.last_refill = now
+
+        if bucket.tokens >= consume:
+            bucket.tokens -= consume
+            return 0
+
+        if refill_per_sec <= 0:
+            return 1000
+        needed = consume - bucket.tokens
+        # Round up like the Lua script does; truncating could return 0 for a
+        # sub-millisecond deficit and report success without taking a token.
+        return max(1, math.ceil((needed / refill_per_sec) * 1000))
+
+
+async def acquire_token(
+    scope: str,
+    *,
+    capacity: int,
+    refill_per_sec: float,
+    consume: float = 1.0,
+) -> int:
+    """Return 0 if allowed, otherwise milliseconds to wait.
+
+    Redis is the primary coordination mechanism. If Redis is unavailable,
+    fall back to per-process memory so the gateway degrades instead of failing
+    closed during a short Redis outage.
+    """
+
+    redis_key = f"gateway:bucket:{scope}"
+    try:
+        wait_ms = await _redis().eval(
+            _TOKEN_BUCKET_LUA,
+            1,
+            redis_key,
+            capacity,
+            refill_per_sec,
+            time.time(),
+            consume,
+        )
+        return int(wait_ms)
+    except (aioredis.RedisError, OSError) as exc:
+        logger.warning("Redis rate limiter unavailable; using memory fallback: %s", exc)
+        record_gateway_redis_fallback()
+        return await _memory_fallback_acquire(scope, capacity, refill_per_sec, consume)
+
+
+async def wait_for_token(
+    scope: str,
+    *,
+    capacity: int,
+    refill_per_sec: float,
+    consume: float = 1.0,
+) -> int:
+    """Block until ``consume`` tokens were taken; return the milliseconds slept.
+
+    A denied attempt does not reserve anything, so after sleeping the bucket
+    is asked again instead of letting the caller proceed without a token.
+    A request that can never succeed (``consume`` above ``capacity`` or a
+    non-positive refill rate) sleeps once and returns, as before.
+    """
+    unsatisfiable = consume > capacity or refill_per_sec <= 0
+    slept_ms = 0
+    while True:
+        wait_ms = await acquire_token(
+            scope,
+            capacity=capacity,
+            refill_per_sec=refill_per_sec,
+            consume=consume,
+        )
+        if wait_ms <= 0:
+            return slept_ms
+        await asyncio.sleep(wait_ms / 1000)
+        slept_ms += wait_ms
+        if unsatisfiable:
+            return slept_ms
+
diff --git a/surfsense_backend/app/gateway/thread_lock.py b/surfsense_backend/app/gateway/thread_lock.py
new file mode 100644
index 000000000..82733bb69
--- /dev/null
+++ b/surfsense_backend/app/gateway/thread_lock.py
@@ -0,0 +1,55 @@
+"""Redis-backed distributed locks for gateway conversation turns."""
+
+from __future__ import annotations
+
+import logging
+
+import redis
+
+from app.config import config
+from app.observability.metrics import record_gateway_thread_lock_contention
+
+logger = logging.getLogger(__name__)
+
+_redis_client: redis.Redis | None = None
+
+
+def _redis() -> redis.Redis:
+    """Return the lazily created, module-wide synchronous Redis client."""
+    global _redis_client
+    if _redis_client is None:
+        _redis_client = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
+    return _redis_client
+
+
+def _lock_key(thread_id: int) -> str:
+    """Return the Redis key that guards ``thread_id``."""
+    return f"gateway:thread_lock:{thread_id}"
+
+
+def acquire_thread_lock(thread_id: int, ttl: int = 60) -> bool:
+    """Try to take the lock for ``thread_id``; return whether it was acquired.
+
+    Issues ``SET key "1" NX EX ttl`` and records a contention metric when the
+    key already exists.
+
+    NOTE(review): nothing visible here renews the key, so a turn that runs
+    longer than ``ttl`` seconds would lose the lock mid-flight -- confirm.
+    """
+    acquired = bool(_redis().set(_lock_key(thread_id), "1", nx=True, ex=ttl))
+    if not acquired:
+        record_gateway_thread_lock_contention()
+    return acquired
+
+
+def release_thread_lock(thread_id: int) -> None:
+    """Delete the lock key for ``thread_id``; Redis errors are logged, not raised.
+
+    NOTE(review): the delete is unconditional and the stored value carries no
+    owner token, so after a TTL expiry this may remove another holder's lock.
+    """
+    try:
+        _redis().delete(_lock_key(thread_id))
+    except redis.RedisError as exc:
+        logger.warning("Failed to release gateway thread lock for %s: %s", thread_id, exc)
+