diff --git a/surfsense_backend/app/gateway/discord/__init__.py b/surfsense_backend/app/gateway/discord/__init__.py new file mode 100644 index 000000000..1dd0edc96 --- /dev/null +++ b/surfsense_backend/app/gateway/discord/__init__.py @@ -0,0 +1 @@ +"""Discord gateway platform integration.""" diff --git a/surfsense_backend/app/gateway/discord/adapter.py b/surfsense_backend/app/gateway/discord/adapter.py new file mode 100644 index 000000000..60db895fe --- /dev/null +++ b/surfsense_backend/app/gateway/discord/adapter.py @@ -0,0 +1,135 @@ +"""Discord platform adapter for bot mentions and replies.""" + +from __future__ import annotations + +import re +from typing import Any + +from app.gateway.base.adapter import ( + BasePlatformAdapter, + ParsedInboundEvent, + PlatformSendResult, +) +from app.gateway.discord.client import DiscordGatewayClient + +MENTION_RE = re.compile(r"<@!?\d+>\s*") + + +def discord_user_peer_id(guild_id: str, discord_user_id: str) -> str: + return f"discord_user:{guild_id}:{discord_user_id}" + + +def discord_thread_peer_id(guild_id: str, channel_id: str, thread_key: str) -> str: + return f"discord_thread:{guild_id}:{channel_id}:{thread_key}" + + +class DiscordAdapter(BasePlatformAdapter): + platform = "discord" + + def __init__(self, bot_token: str, *, bot_user_id: str | None = None) -> None: + self.bot_user_id = bot_user_id + self.client = DiscordGatewayClient(bot_token) + + def parse_inbound(self, raw_payload: dict[str, Any]) -> ParsedInboundEvent: + event = raw_payload.get("event") or raw_payload + event_kind = str(raw_payload.get("type") or event.get("type") or "message") + guild_id = str(event.get("guild_id") or "") + channel_id = str(event.get("channel_id") or "") + author = event.get("author") or {} + discord_user_id = str(author.get("id") or event.get("author_id") or "") + message_id = str(event.get("id") or event.get("message_id") or "") + bot_user_id = self.bot_user_id or str(raw_payload.get("bot_user_id") or "") + + if not guild_id or not channel_id or not discord_user_id or not message_id: + return ParsedInboundEvent( + platform=self.platform, + event_kind=event_kind, + external_peer_id=None, + external_peer_kind="unknown", + external_message_id=message_id or None, + external_user_id=discord_user_id or None, + text=None, + raw_payload=raw_payload, + metadata={ + "guild_id": guild_id, + "channel_id": channel_id, + "bot_user_id": bot_user_id, + }, + ) + + text = str(event.get("content") or "") + if bot_user_id: + text = text.replace(f"<@{bot_user_id}>", "") + text = text.replace(f"<@!{bot_user_id}>", "") + text = MENTION_RE.sub("", text).strip() + + thread_key = str( + event.get("thread_id") + or (event.get("message_reference") or {}).get("message_id") + or message_id + ) + thread_peer_id = discord_thread_peer_id(guild_id, channel_id, thread_key) + user_peer_id = discord_user_peer_id(guild_id, discord_user_id) + mentions = event.get("mentions") or [] + mentions_bot = bool( + bot_user_id + and any(str(mention.get("id")) == bot_user_id for mention in mentions) + ) + + return ParsedInboundEvent( + platform=self.platform, + event_kind=event_kind, + external_peer_id=thread_peer_id, + external_peer_kind="channel", + external_message_id=message_id, + external_user_id=discord_user_id, + text=text, + raw_payload=raw_payload, + display_name=event.get("channel_name"), + username=author.get("username") or discord_user_id, + metadata={ + "guild_id": guild_id, + "channel_id": channel_id, + "discord_user_id": discord_user_id, + "message_id": message_id, + "thread_key": thread_key, + "bot_user_id": bot_user_id, + "discord_user_peer_id": user_peer_id, + "discord_thread_peer_id": thread_peer_id, + "mentions_bot": mentions_bot, + "is_dm": False, + }, + ) + + async def send_message( + self, + *, + external_peer_id: str, + text: str, + parse_mode: str | None = None, + reply_to_message_id: str | None = None, + ) -> PlatformSendResult: + del parse_mode + return await self.client.send_message( + channel_id=external_peer_id, + content=text, + reply_to_message_id=reply_to_message_id, + ) + + async def edit_message( + self, + *, + external_peer_id: str, + external_message_id: str, + text: str, + parse_mode: str | None = None, + ) -> PlatformSendResult: + del parse_mode + return await self.client.update_message( + channel_id=external_peer_id, + message_id=external_message_id, + content=text, + ) + + async def validate_credentials(self) -> dict[str, Any]: + return await self.client.validate() diff --git a/surfsense_backend/app/gateway/discord/client.py b/surfsense_backend/app/gateway/discord/client.py new file mode 100644 index 000000000..206abaa5f --- /dev/null +++ b/surfsense_backend/app/gateway/discord/client.py @@ -0,0 +1,109 @@ +"""Discord REST API client for gateway bot operations.""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import httpx + +from app.gateway.base.adapter import PlatformSendResult + +DISCORD_API = "https://discord.com/api/v10" + + +class DiscordGatewayClient: + def __init__(self, bot_token: str) -> None: + self.bot_token = bot_token + + async def api_call( + self, + method: str, + path: str, + *, + payload: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + retry_rate_limit: bool = True, + ) -> dict[str, Any]: + async with httpx.AsyncClient(timeout=20.0) as client: + response = await client.request( + method, + f"{DISCORD_API}{path}", + json=payload, + params=params, + headers={ + "Authorization": f"Bot {self.bot_token}", + "Content-Type": "application/json", + }, + ) + if response.status_code == 429 and retry_rate_limit: + data = response.json() + retry_after = float(data.get("retry_after") or 1.0) + await asyncio.sleep(min(retry_after, 5.0)) + return await self.api_call( + method, + path, + payload=payload, + params=params, + retry_rate_limit=False, + ) + response.raise_for_status() + if not response.content: + return {} + return response.json() + + async def send_message( + self, + *, + channel_id: str, + content: str, + reply_to_message_id: str | None = None, + ) -> PlatformSendResult: + payload: dict[str, Any] = { + "content": content, + "allowed_mentions": {"parse": []}, + } + if reply_to_message_id: + payload["message_reference"] = { + "message_id": reply_to_message_id, + "channel_id": channel_id, + "fail_if_not_exists": False, + } + data = await self.api_call( + "POST", + f"/channels/{channel_id}/messages", + payload=payload, + ) + return PlatformSendResult( + external_message_id=str(data.get("id", "")), + raw_response=data, + ) + + async def update_message( + self, + *, + channel_id: str, + message_id: str, + content: str, + ) -> PlatformSendResult: + data = await self.api_call( + "PATCH", + f"/channels/{channel_id}/messages/{message_id}", + payload={"content": content, "allowed_mentions": {"parse": []}}, + ) + return PlatformSendResult( + external_message_id=str(data.get("id") or message_id), + raw_response=data, + ) + + async def validate(self) -> dict[str, Any]: + data = await self.api_call("GET", "/users/@me") + return { + "ok": True, + "bot_user_id": data.get("id"), + "bot_username": data.get("username"), + "global_name": data.get("global_name"), + } + + async def get_guild(self, guild_id: str) -> dict[str, Any]: + return await self.api_call("GET", f"/guilds/{guild_id}") diff --git a/surfsense_backend/tests/unit/gateway/test_discord_adapter.py b/surfsense_backend/tests/unit/gateway/test_discord_adapter.py new file mode 100644 index 000000000..c6790f20b --- /dev/null +++ b/surfsense_backend/tests/unit/gateway/test_discord_adapter.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import pytest + +from app.gateway.base.adapter import PlatformSendResult +from app.gateway.discord.adapter import DiscordAdapter + + +def _discord_payload(content: str = "<@999> summarize this channel"): + return { + "type": "message", + "bot_user_id": "999", + "event": { + "type": "message", + "id": "111", + "guild_id": "222", + "guild_name": "SurfSense Guild", + "channel_id": "333", + "channel_name": "general", + "content": content, + "author": {"id": "444", "username": "anish", "bot": False}, + "mentions": [{"id": "999", "username": "SurfSense"}], + }, + } + + +def test_discord_adapter_parses_mention_and_strips_bot_mention(): + adapter = DiscordAdapter("discord-token", bot_user_id="999") + + parsed = adapter.parse_inbound(_discord_payload()) + + assert parsed.platform == "discord" + assert parsed.text == "summarize this channel" + assert parsed.external_peer_id == "discord_thread:222:333:111" + assert parsed.metadata["discord_user_peer_id"] == "discord_user:222:444" + assert parsed.metadata["discord_thread_peer_id"] == "discord_thread:222:333:111" + assert parsed.metadata["mentions_bot"] is True + + +def test_discord_adapter_strips_nickname_mention(): + adapter = DiscordAdapter("discord-token", bot_user_id="999") + + parsed = adapter.parse_inbound(_discord_payload("<@!999> continue")) + + assert parsed.text == "continue" + + +def test_discord_adapter_uses_message_reference_as_thread_key(): + adapter = DiscordAdapter("discord-token", bot_user_id="999") + payload = _discord_payload("<@999> continue") + payload["event"]["id"] = "112" + payload["event"]["message_reference"] = { + "message_id": "111", + "channel_id": "333", + "guild_id": "222", + } + + parsed = adapter.parse_inbound(payload) + + assert parsed.external_peer_id == "discord_thread:222:333:111" + assert parsed.metadata["message_id"] == "112" + assert parsed.metadata["thread_key"] == "111" + + +def test_discord_adapter_returns_missing_peer_for_incomplete_payload(): + adapter = DiscordAdapter("discord-token", bot_user_id="999") + + parsed = adapter.parse_inbound({"event": {"id": "111"}}) + + assert parsed.external_peer_id is None + assert parsed.external_peer_kind == "unknown" + + +@pytest.mark.asyncio +async def test_discord_adapter_sends_message(mocker): + adapter = DiscordAdapter("discord-token", bot_user_id="999") + adapter.client.send_message = mocker.AsyncMock( + return_value=PlatformSendResult(external_message_id="555") + ) + + result = await adapter.send_message( + external_peer_id="333", + text="hello", + reply_to_message_id="111", + ) + + assert result.external_message_id == "555" + adapter.client.send_message.assert_awaited_once_with( + channel_id="333", + content="hello", + reply_to_message_id="111", + )