diff --git a/surfsense_backend/app/gateway/__init__.py b/surfsense_backend/app/gateway/__init__.py new file mode 100644 index 000000000..5cf91505b --- /dev/null +++ b/surfsense_backend/app/gateway/__init__.py @@ -0,0 +1,2 @@ +"""Messaging gateway infrastructure for external chat channels.""" + diff --git a/surfsense_backend/app/gateway/accounts.py b/surfsense_backend/app/gateway/accounts.py new file mode 100644 index 000000000..727d616c1 --- /dev/null +++ b/surfsense_backend/app/gateway/accounts.py @@ -0,0 +1,54 @@ +"""Gateway account helpers.""" + +from __future__ import annotations + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import ( + GatewayAccountMode, + GatewayHealthStatus, + GatewayPlatform, + GatewayPlatformAccount, +) +from app.utils.oauth_security import TokenEncryption + + +def account_token(account: GatewayPlatformAccount) -> str | None: + if account.is_system_account and account.platform == GatewayPlatform.TELEGRAM: + return config.TELEGRAM_SHARED_BOT_TOKEN + if not account.encrypted_credentials: + return None + return TokenEncryption(config.SECRET_KEY or "").decrypt_token( + account.encrypted_credentials + ) + + +async def get_or_create_system_telegram_account( + session: AsyncSession, +) -> GatewayPlatformAccount: + result = await session.execute( + select(GatewayPlatformAccount).where( + GatewayPlatformAccount.platform == GatewayPlatform.TELEGRAM, + GatewayPlatformAccount.is_system_account.is_(True), + ) + ) + account = result.scalars().first() + if account is not None: + return account + account = GatewayPlatformAccount( + platform=GatewayPlatform.TELEGRAM, + mode=GatewayAccountMode.CLOUD_SHARED, + is_system_account=True, + account_metadata={ + "bot_username": config.TELEGRAM_SHARED_BOT_USERNAME, + "webhook_secret": config.TELEGRAM_WEBHOOK_SECRET, + }, + cursor_state={}, + health_status=GatewayHealthStatus.UNKNOWN, + ) + session.add(account) + await session.flush() + return account + diff --git a/surfsense_backend/app/gateway/auth_invariant.py b/surfsense_backend/app/gateway/auth_invariant.py new file mode 100644 index 000000000..414c69c5c --- /dev/null +++ b/surfsense_backend/app/gateway/auth_invariant.py @@ -0,0 +1,55 @@ +"""Authorization invariants for gateway-routed turns.""" + +from __future__ import annotations + +from fastapi import HTTPException +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import GatewayConversationBinding, Permission, User +from app.gateway.bindings import suspend_binding +from app.observability.metrics import record_gateway_auth_invariant_failure +from app.utils.rbac import check_permission, check_search_space_access + + +class GatewaySuspendedError(RuntimeError): + def __init__(self, reason: str) -> None: + self.reason = reason + super().__init__(reason) + + +async def _fail( + session: AsyncSession, + binding: GatewayConversationBinding, + reason: str, +) -> None: + suspend_binding(binding, reason) + record_gateway_auth_invariant_failure(cause=reason) + await session.flush() + raise GatewaySuspendedError(reason) + + +async def assert_authorization_invariant( + session: AsyncSession, + binding: GatewayConversationBinding, +) -> User: + if binding.state != "bound": + await _fail(session, binding, "binding_not_bound") + + user = await session.get(User, binding.user_id) + if user is None: + await _fail(session, binding, "owner_missing") + + try: + await check_search_space_access(session, user, binding.search_space_id) + await check_permission( + session, + user, + binding.search_space_id, + Permission.CHATS_CREATE.value, + "Gateway owner no longer has permission to chat in this search space", + ) + except HTTPException as exc: + await _fail(session, binding, f"rbac_{exc.status_code}") + + return user + diff --git a/surfsense_backend/app/gateway/base/__init__.py b/surfsense_backend/app/gateway/base/__init__.py new file mode 100644 index 000000000..962d068b6 --- /dev/null +++ b/surfsense_backend/app/gateway/base/__init__.py @@ -0,0 +1,2 @@ +"""Base gateway interfaces.""" + diff --git a/surfsense_backend/app/gateway/base/adapter.py b/surfsense_backend/app/gateway/base/adapter.py new file mode 100644 index 000000000..caf351c05 --- /dev/null +++ b/surfsense_backend/app/gateway/base/adapter.py @@ -0,0 +1,70 @@ +"""Platform adapter interfaces for messaging gateways.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(frozen=True) +class ParsedInboundEvent: + platform: str + event_kind: str + external_peer_id: str | None + external_peer_kind: str + external_message_id: str | None + external_user_id: str | None + text: str | None + raw_payload: dict[str, Any] + display_name: str | None = None + username: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class PlatformSendResult: + external_message_id: str + raw_response: dict[str, Any] = field(default_factory=dict) + + +class BasePlatformAdapter(ABC): + platform: str + + @abstractmethod + def parse_inbound(self, raw_payload: dict[str, Any]) -> ParsedInboundEvent: + """Parse a provider webhook/update into the gateway's normalized shape.""" + + @abstractmethod + async def send_message( + self, + *, + external_peer_id: str, + text: str, + parse_mode: str | None = None, + reply_to_message_id: str | None = None, + ) -> PlatformSendResult: + """Send a new platform message.""" + + @abstractmethod + async def edit_message( + self, + *, + external_peer_id: str, + external_message_id: str, + text: str, + parse_mode: str | None = None, + ) -> PlatformSendResult: + """Edit an existing platform message.""" + + @abstractmethod + async def validate_credentials(self) -> dict[str, Any]: + """Validate configured credentials and return account metadata.""" + + async def fetch_updates(self, *, offset: int | None) -> AsyncIterator[dict[str, Any]]: + """Yield provider updates for long-polling adapters.""" + if False: + yield {} # pragma: no cover + raise NotImplementedError("This adapter does not support long-polling") + diff --git a/surfsense_backend/app/gateway/base/identity.py b/surfsense_backend/app/gateway/base/identity.py new file mode 100644 index 000000000..608ae41c1 --- /dev/null +++ b/surfsense_backend/app/gateway/base/identity.py @@ -0,0 +1,19 @@ +"""Gateway identity helpers.""" + +from __future__ import annotations + +import hashlib + + +def normalize_external_peer_id(value: str | int | None) -> str | None: + if value is None: + return None + return str(value).strip() + + +def hash_external_id(value: str | int | None) -> str | None: + normalized = normalize_external_peer_id(value) + if not normalized: + return None + return hashlib.sha256(normalized.encode("utf-8")).hexdigest() + diff --git a/surfsense_backend/app/gateway/base/translator.py b/surfsense_backend/app/gateway/base/translator.py new file mode 100644 index 000000000..af72188e9 --- /dev/null +++ b/surfsense_backend/app/gateway/base/translator.py @@ -0,0 +1,28 @@ +"""Base stream translator for platform-specific outbound UX.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(frozen=True) +class GatewayStreamEvent: + """Small provider-neutral event shape consumed by translators. + + The existing chat stack emits Vercel/assistant-ui events. Gateway code + normalizes the subset it needs into this shape before handing it to the + platform translator. + """ + + type: str + data: dict[str, Any] = field(default_factory=dict) + + +class BaseStreamTranslator(ABC): + @abstractmethod + async def translate(self, events: AsyncIterator[GatewayStreamEvent]) -> None: + """Consume agent stream events and emit platform messages.""" + diff --git a/surfsense_backend/app/gateway/bindings.py b/surfsense_backend/app/gateway/bindings.py new file mode 100644 index 000000000..6f2b641f7 --- /dev/null +++ b/surfsense_backend/app/gateway/bindings.py @@ -0,0 +1,62 @@ +"""Gateway binding helpers.""" + +from __future__ import annotations + +from datetime import UTC, datetime + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import ( + ChatVisibility, + GatewayBindingState, + GatewayConversationBinding, + NewChatThread, +) + + +async def get_or_create_thread_for_binding( + session: AsyncSession, + binding: GatewayConversationBinding, +) -> NewChatThread: + if binding.active_thread_id is not None: + result = await session.execute( + select(NewChatThread).where(NewChatThread.id == binding.active_thread_id) + ) + thread = result.scalars().first() + if thread is not None and not thread.archived: + return thread + + thread = NewChatThread( + title="Telegram chat", + search_space_id=binding.search_space_id, + created_by_id=binding.user_id, + visibility=ChatVisibility.PRIVATE, + source="telegram", + binding_id=binding.id, + ) + session.add(thread) + await session.flush() + binding.active_thread_id = thread.id + return thread + + +def suspend_binding(binding: GatewayConversationBinding, reason: str) -> None: + now = datetime.now(UTC) + binding.state = GatewayBindingState.SUSPENDED + binding.suspended_at = now + binding.suspended_reason = reason + + +def revoke_binding(binding: GatewayConversationBinding) -> None: + now = datetime.now(UTC) + binding.state = GatewayBindingState.REVOKED + binding.revoked_at = now + binding.active_thread_id = None + + +def resume_binding(binding: GatewayConversationBinding) -> None: + binding.state = GatewayBindingState.BOUND + binding.suspended_at = None + binding.suspended_reason = None + diff --git a/surfsense_backend/app/gateway/hitl_filter.py b/surfsense_backend/app/gateway/hitl_filter.py new file mode 100644 index 000000000..e3acc6d42 --- /dev/null +++ b/surfsense_backend/app/gateway/hitl_filter.py @@ -0,0 +1,35 @@ +"""Filter approval-required tools from gateway agent invocations.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +DEFAULT_HITL_TOOL_NAMES = { + "delete_document", + "delete_folder", + "delete_note", + "delete_report", + "delete_connector", + "send_email", + "share_chat", +} + + +def _tool_name(tool: Any) -> str | None: + if isinstance(tool, str): + return tool + return getattr(tool, "name", None) or getattr(tool, "__name__", None) + + +def filter_hitl_tools( + toolkit: Iterable[Any] | None, + *, + blocked_names: set[str] | None = None, +) -> list[Any] | None: + """Return a toolkit with known approval-required tools removed.""" + if toolkit is None: + return None + blocked = blocked_names or DEFAULT_HITL_TOOL_NAMES + return [tool for tool in toolkit if (_tool_name(tool) or "") not in blocked] + diff --git a/surfsense_backend/app/gateway/inbox.py b/surfsense_backend/app/gateway/inbox.py new file mode 100644 index 000000000..c98ee5977 --- /dev/null +++ b/surfsense_backend/app/gateway/inbox.py @@ -0,0 +1,44 @@ +"""Durable gateway inbox helpers.""" + +from __future__ import annotations + +from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import GatewayInboundEvent, GatewayPlatform + + +def telegram_event_dedupe_key(update_id: int | str) -> str: + return f"update:{update_id}" + + +async def persist_inbound_event( + session: AsyncSession, + *, + account_id: int, + platform: GatewayPlatform, + event_dedupe_key: str, + event_kind: str, + raw_payload: dict, + external_event_id: str | None = None, + external_message_id: str | None = None, +) -> int | None: + stmt = ( + insert(GatewayInboundEvent) + .values( + account_id=account_id, + platform=platform, + event_dedupe_key=event_dedupe_key, + external_event_id=external_event_id, + external_message_id=external_message_id, + event_kind=event_kind, + raw_payload=raw_payload, + ) + .on_conflict_do_nothing( + index_elements=["account_id", "event_dedupe_key"], + ) + .returning(GatewayInboundEvent.id) + ) + result = await session.execute(stmt) + return result.scalar_one_or_none() + diff --git a/surfsense_backend/app/gateway/pairing.py b/surfsense_backend/app/gateway/pairing.py new file mode 100644 index 000000000..55232022e --- /dev/null +++ b/surfsense_backend/app/gateway/pairing.py @@ -0,0 +1,54 @@ +"""Pairing code lifecycle for gateway bindings.""" + +from __future__ import annotations + +import secrets +from datetime import UTC, datetime, timedelta + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import GatewayBindingState, GatewayConversationBinding + +PAIRING_CODE_TTL = timedelta(minutes=10) + + +def generate_pairing_code() -> str: + return secrets.token_urlsafe(6) + + +def pairing_expires_at() -> datetime: + return datetime.now(UTC) + PAIRING_CODE_TTL + + +async def redeem_pairing_code( + session: AsyncSession, + *, + code: str, + external_peer_id: str, + external_peer_kind: str, + external_display_name: str | None, + external_username: str | None, + external_metadata: dict | None = None, +) -> GatewayConversationBinding | None: + result = await session.execute( + select(GatewayConversationBinding).where( + GatewayConversationBinding.pairing_code == code, + GatewayConversationBinding.state == GatewayBindingState.PENDING, + GatewayConversationBinding.pairing_code_expires_at > datetime.now(UTC), + ) + ) + binding = result.scalars().first() + if binding is None: + return None + + binding.state = GatewayBindingState.BOUND + binding.pairing_code = None + binding.pairing_code_expires_at = None + binding.external_peer_id = external_peer_id + binding.external_peer_kind = external_peer_kind + binding.external_display_name = external_display_name + binding.external_username = external_username + binding.external_metadata = external_metadata or {} + return binding + diff --git a/surfsense_backend/tests/unit/gateway/test_hitl_filter.py b/surfsense_backend/tests/unit/gateway/test_hitl_filter.py new file mode 100644 index 000000000..90f94b6ab --- /dev/null +++ b/surfsense_backend/tests/unit/gateway/test_hitl_filter.py @@ -0,0 +1,15 @@ +from app.gateway.hitl_filter import filter_hitl_tools + + +class Tool: + def __init__(self, name: str) -> None: + self.name = name + + +def test_filter_hitl_tools_removes_known_approval_tools(): + tools = [Tool("delete_document"), Tool("search"), "send_email", "summarize"] + + filtered = filter_hitl_tools(tools) + + assert [getattr(tool, "name", tool) for tool in filtered] == ["search", "summarize"] + diff --git a/surfsense_backend/tests/unit/gateway/test_pairing.py b/surfsense_backend/tests/unit/gateway/test_pairing.py new file mode 100644 index 000000000..c50bd6b7c --- /dev/null +++ b/surfsense_backend/tests/unit/gateway/test_pairing.py @@ -0,0 +1,41 @@ +from datetime import UTC, datetime, timedelta + +import pytest + +from app.db import GatewayBindingState +from app.gateway.pairing import generate_pairing_code, redeem_pairing_code + + +def test_generate_pairing_code_is_short_display_token(): + code = generate_pairing_code() + + assert len(code) >= 8 + assert "\n" not in code + + +@pytest.mark.asyncio +async def test_redeem_pairing_code_binds_pending_row(mocker): + binding = mocker.Mock() + binding.state = GatewayBindingState.PENDING + binding.pairing_code_expires_at = datetime.now(UTC) + timedelta(minutes=1) + scalars = mocker.Mock() + scalars.first.return_value = binding + result = mocker.Mock() + result.scalars.return_value = scalars + session = mocker.AsyncMock() + session.execute.return_value = result + + redeemed = await redeem_pairing_code( + session, + code="abc", + external_peer_id="telegram:123", + external_peer_kind="direct", + external_display_name="Anish", + external_username="anish", + ) + + assert redeemed is binding + assert binding.state == GatewayBindingState.BOUND + assert binding.external_peer_id == "telegram:123" + assert binding.pairing_code is None +