diff --git a/surfsense_backend/app/gateway/agent_invoke.py b/surfsense_backend/app/gateway/agent_invoke.py index b195f3bce..7a2219b1d 100644 --- a/surfsense_backend/app/gateway/agent_invoke.py +++ b/surfsense_backend/app/gateway/agent_invoke.py @@ -11,10 +11,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.db import ExternalChatBinding, NewChatMessage from app.gateway.auth_invariant import assert_authorization_invariant -from app.gateway.base.translator import GatewayStreamEvent +from app.gateway.base.translator import BaseStreamTranslator, GatewayStreamEvent from app.gateway.bindings import get_or_create_thread_for_binding from app.gateway.hitl_filter import DEFAULT_HITL_TOOL_NAMES -from app.gateway.telegram.translator import TelegramStreamTranslator from app.gateway.thread_lock import acquire_thread_lock, release_thread_lock from app.observability.metrics import record_gateway_turn_latency from app.tasks.chat.stream_new_chat import stream_new_chat @@ -58,7 +57,8 @@ async def call_agent_for_gateway( session: AsyncSession, binding: ExternalChatBinding, user_text: str, - translator: TelegramStreamTranslator, + translator: BaseStreamTranslator, + platform_label: str = "telegram", request_id: str | None = None, ) -> None: user = await assert_authorization_invariant(session, binding) @@ -92,10 +92,10 @@ async def call_agent_for_gateway( NewChatMessage.thread_id == thread.id, NewChatMessage.source == "surfsense", ) - .values(source="telegram") + .values(source=platform_label) ) await session.commit() - record_gateway_turn_latency(0, platform="telegram") + record_gateway_turn_latency(0, platform=platform_label) finally: release_thread_lock(thread.id) diff --git a/surfsense_backend/app/gateway/inbox_processor.py b/surfsense_backend/app/gateway/inbox_processor.py index c40a6c47c..bdf768d61 100644 --- a/surfsense_backend/app/gateway/inbox_processor.py +++ b/surfsense_backend/app/gateway/inbox_processor.py @@ -15,26 +15,19 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from app.config import config from app.db import ( - ExternalChatBindingState, + ExternalChatAccount, ExternalChatBinding, + ExternalChatBindingState, ExternalChatEventStatus, ExternalChatInboundEvent, ExternalChatPeerKind, - ExternalChatAccount, NewChatThread, async_session_maker, ) -from app.gateway.accounts import account_token from app.gateway.agent_invoke import call_agent_for_gateway +from app.gateway.base.commands import command_name from app.gateway.bindings import get_or_create_thread_for_binding -from app.gateway.telegram.adapter import TelegramAdapter -from app.gateway.telegram.commands import ( - command_name, - handle_help_command, - handle_start_command, - send_unbound_onboarding, -) -from app.gateway.telegram.translator import TelegramStreamTranslator +from app.gateway.registry import resolve_platform_bundle from app.observability.metrics import record_gateway_inbox_processed logger = logging.getLogger(__name__) @@ -150,14 +143,15 @@ async def _dispatch_inbound_event( await session.commit() return - token = account_token(account) - if not token: + try: + bundle = resolve_platform_bundle(account) + except RuntimeError as exc: event.status = ExternalChatEventStatus.FAILED - event.last_error = "missing_telegram_token" + event.last_error = str(exc) await session.commit() return - adapter = TelegramAdapter(token) + adapter = bundle.adapter parsed = adapter.parse_inbound(event.raw_payload or {}) if parsed.external_peer_id is None: event.status = ExternalChatEventStatus.IGNORED @@ -179,7 +173,8 @@ async def _dispatch_inbound_event( binding = result.scalars().first() if parsed.external_peer_kind != ExternalChatPeerKind.DIRECT.value: - await adapter.leave_chat(external_peer_id=parsed.external_peer_id) + if hasattr(adapter, "leave_chat"): + await adapter.leave_chat(external_peer_id=parsed.external_peer_id) event.status = ExternalChatEventStatus.IGNORED event.last_error = "group_rejected" await session.commit() @@ -187,7 +182,7 @@ async def _dispatch_inbound_event( cmd = command_name(parsed.text) if cmd == "/start": - handled = await handle_start_command( + handled = await bundle.commands.handle_start_command( session=session, adapter=adapter, event=parsed ) await session.commit() @@ -195,23 +190,39 @@ async def _dispatch_inbound_event( return if binding is None: - await send_unbound_onboarding( - adapter=adapter, - event=parsed, - dashboard_url=_dashboard_url(), - ) - event.status = ExternalChatEventStatus.IGNORED - event.last_error = "unbound_chat" - await session.commit() - return + if bundle.auto_bind_owner and account.owner_user_id and account.owner_search_space_id: + binding = ExternalChatBinding( + account_id=account.id, + user_id=account.owner_user_id, + search_space_id=account.owner_search_space_id, + state=ExternalChatBindingState.BOUND, + external_peer_id=parsed.external_peer_id, + external_peer_kind=parsed.external_peer_kind, + external_display_name=parsed.display_name, + external_username=parsed.username, + external_metadata=parsed.metadata, + ) + session.add(binding) + await session.flush() + else: + await bundle.commands.send_unbound_onboarding( + adapter=adapter, + event=parsed, + dashboard_url=_dashboard_url(), + ) + event.status = ExternalChatEventStatus.IGNORED + event.last_error = "unbound_chat" + await session.commit() + return event.external_chat_binding_id = binding.id if cmd == "/help": - await handle_help_command(adapter=adapter, event=parsed) - event.status = ExternalChatEventStatus.PROCESSED - await session.commit() - return + handled = await bundle.commands.handle_help_command(adapter=adapter, event=parsed) + if handled: + event.status = ExternalChatEventStatus.PROCESSED + await session.commit() + return if cmd == "/new": binding.new_chat_thread_id = None await adapter.send_message( @@ -231,21 +242,19 @@ async def _dispatch_inbound_event( thread = await get_or_create_thread_for_binding(session, binding) await session.commit() - translator = TelegramStreamTranslator( - adapter=adapter, - external_peer_id=parsed.external_peer_id, - ) + translator = bundle.translator_factory(adapter, parsed) await call_agent_for_gateway( session=session, binding=binding, user_text=parsed.text, translator=translator, + platform_label=bundle.platform_label, request_id=event.request_id or f"gateway:{inbox_id}", ) thread = await session.get(NewChatThread, thread.id) if thread is not None: - thread.source = "telegram" + thread.source = bundle.platform_label await session.commit() diff --git a/surfsense_backend/app/gateway/registry.py b/surfsense_backend/app/gateway/registry.py new file mode 100644 index 000000000..db334b7f1 --- /dev/null +++ b/surfsense_backend/app/gateway/registry.py @@ -0,0 +1,111 @@ +"""Resolve gateway platform implementations from account rows.""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass + +from app.db import ExternalChatAccount, ExternalChatAccountMode, ExternalChatPlatform +from app.gateway.accounts import account_token +from app.gateway.base.adapter import BasePlatformAdapter, ParsedInboundEvent +from app.gateway.base.commands import BaseGatewayCommands +from app.gateway.base.translator import BaseStreamTranslator +from app.gateway.telegram.adapter import TelegramAdapter +from app.gateway.telegram.commands import TelegramGatewayCommands +from app.gateway.telegram.translator import TelegramStreamTranslator + +TranslatorFactory = Callable[ + [BasePlatformAdapter, ParsedInboundEvent], + BaseStreamTranslator, +] + + +@dataclass(frozen=True) +class PlatformBundle: + adapter: BasePlatformAdapter + translator_factory: TranslatorFactory + platform_label: str + commands: BaseGatewayCommands + auto_bind_owner: bool = False + + +def _telegram_translator_factory( + adapter: BasePlatformAdapter, + event: ParsedInboundEvent, +) -> BaseStreamTranslator: + if event.external_peer_id is None: + raise RuntimeError("missing_external_peer_id") + return TelegramStreamTranslator( + adapter=adapter, # type: ignore[arg-type] + external_peer_id=event.external_peer_id, + ) + + +def _whatsapp_cloud_translator_factory( + adapter: BasePlatformAdapter, + event: ParsedInboundEvent, +) -> BaseStreamTranslator: + if event.external_peer_id is None: + raise RuntimeError("missing_external_peer_id") + from app.gateway.whatsapp.translator import WhatsAppCloudStreamTranslator + + return WhatsAppCloudStreamTranslator( + adapter=adapter, + external_peer_id=event.external_peer_id, + inbound_message_id=event.external_message_id, + ) + + +def _whatsapp_baileys_translator_factory( + adapter: BasePlatformAdapter, + event: ParsedInboundEvent, +) -> BaseStreamTranslator: + if event.external_peer_id is None: + raise RuntimeError("missing_external_peer_id") + from app.gateway.whatsapp.translator_baileys import WhatsAppBaileysStreamTranslator + + return WhatsAppBaileysStreamTranslator( + adapter=adapter, + external_peer_id=event.external_peer_id, + ) + + +def resolve_platform_bundle(account: ExternalChatAccount) -> PlatformBundle: + if account.platform == ExternalChatPlatform.TELEGRAM: + token = account_token(account) + if not token: + raise RuntimeError("missing_telegram_token") + return PlatformBundle( + adapter=TelegramAdapter(token), + translator_factory=_telegram_translator_factory, + platform_label="telegram", + commands=TelegramGatewayCommands(), + ) + + if account.platform == ExternalChatPlatform.WHATSAPP: + if account.mode == ExternalChatAccountMode.CLOUD_SHARED: + from app.gateway.whatsapp.adapter_cloud import WhatsAppCloudAdapter + from app.gateway.whatsapp.commands import WhatsAppGatewayCommands + from app.gateway.whatsapp.credentials import ( + load_system_whatsapp_credentials, + ) + + return PlatformBundle( + adapter=WhatsAppCloudAdapter(load_system_whatsapp_credentials()), + translator_factory=_whatsapp_cloud_translator_factory, + platform_label="whatsapp", + commands=WhatsAppGatewayCommands(), + auto_bind_owner=False, + ) + if account.mode == ExternalChatAccountMode.SELF_HOST_BYO: + from app.gateway.whatsapp.adapter_baileys import WhatsAppBaileysAdapter + + return PlatformBundle( + adapter=WhatsAppBaileysAdapter(), + translator_factory=_whatsapp_baileys_translator_factory, + platform_label="whatsapp", + commands=BaseGatewayCommands(), + auto_bind_owner=True, + ) + + raise RuntimeError(f"unsupported_gateway_platform:{account.platform.value}:{account.mode.value}")