mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-02 19:55:18 +02:00
feat(gateway): add Discord platform adapter
This commit is contained in:
parent
68da295b5d
commit
bc8a285187
4 changed files with 337 additions and 0 deletions
1
surfsense_backend/app/gateway/discord/__init__.py
Normal file
1
surfsense_backend/app/gateway/discord/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""Discord gateway platform integration."""
|
||||||
135
surfsense_backend/app/gateway/discord/adapter.py
Normal file
135
surfsense_backend/app/gateway/discord/adapter.py
Normal file
|
|
@ -0,0 +1,135 @@
|
||||||
|
"""Discord platform adapter for bot mentions and replies."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.gateway.base.adapter import (
|
||||||
|
BasePlatformAdapter,
|
||||||
|
ParsedInboundEvent,
|
||||||
|
PlatformSendResult,
|
||||||
|
)
|
||||||
|
from app.gateway.discord.client import DiscordGatewayClient
|
||||||
|
|
||||||
|
MENTION_RE = re.compile(r"<@!?\d+>\s*")
|
||||||
|
|
||||||
|
|
||||||
|
def discord_user_peer_id(guild_id: str, discord_user_id: str) -> str:
|
||||||
|
return f"discord_user:{guild_id}:{discord_user_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def discord_thread_peer_id(guild_id: str, channel_id: str, thread_key: str) -> str:
|
||||||
|
return f"discord_thread:{guild_id}:{channel_id}:{thread_key}"
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordAdapter(BasePlatformAdapter):
|
||||||
|
platform = "discord"
|
||||||
|
|
||||||
|
def __init__(self, bot_token: str, *, bot_user_id: str | None = None) -> None:
|
||||||
|
self.bot_user_id = bot_user_id
|
||||||
|
self.client = DiscordGatewayClient(bot_token)
|
||||||
|
|
||||||
|
def parse_inbound(self, raw_payload: dict[str, Any]) -> ParsedInboundEvent:
|
||||||
|
event = raw_payload.get("event") or raw_payload
|
||||||
|
event_kind = str(raw_payload.get("type") or event.get("type") or "message")
|
||||||
|
guild_id = str(event.get("guild_id") or "")
|
||||||
|
channel_id = str(event.get("channel_id") or "")
|
||||||
|
author = event.get("author") or {}
|
||||||
|
discord_user_id = str(author.get("id") or event.get("author_id") or "")
|
||||||
|
message_id = str(event.get("id") or event.get("message_id") or "")
|
||||||
|
bot_user_id = self.bot_user_id or str(raw_payload.get("bot_user_id") or "")
|
||||||
|
|
||||||
|
if not guild_id or not channel_id or not discord_user_id or not message_id:
|
||||||
|
return ParsedInboundEvent(
|
||||||
|
platform=self.platform,
|
||||||
|
event_kind=event_kind,
|
||||||
|
external_peer_id=None,
|
||||||
|
external_peer_kind="unknown",
|
||||||
|
external_message_id=message_id or None,
|
||||||
|
external_user_id=discord_user_id or None,
|
||||||
|
text=None,
|
||||||
|
raw_payload=raw_payload,
|
||||||
|
metadata={
|
||||||
|
"guild_id": guild_id,
|
||||||
|
"channel_id": channel_id,
|
||||||
|
"bot_user_id": bot_user_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
text = str(event.get("content") or "")
|
||||||
|
if bot_user_id:
|
||||||
|
text = text.replace(f"<@{bot_user_id}>", "")
|
||||||
|
text = text.replace(f"<@!{bot_user_id}>", "")
|
||||||
|
text = MENTION_RE.sub("", text).strip()
|
||||||
|
|
||||||
|
thread_key = str(
|
||||||
|
event.get("thread_id")
|
||||||
|
or (event.get("message_reference") or {}).get("message_id")
|
||||||
|
or message_id
|
||||||
|
)
|
||||||
|
thread_peer_id = discord_thread_peer_id(guild_id, channel_id, thread_key)
|
||||||
|
user_peer_id = discord_user_peer_id(guild_id, discord_user_id)
|
||||||
|
mentions = event.get("mentions") or []
|
||||||
|
mentions_bot = bool(
|
||||||
|
bot_user_id
|
||||||
|
and any(str(mention.get("id")) == bot_user_id for mention in mentions)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ParsedInboundEvent(
|
||||||
|
platform=self.platform,
|
||||||
|
event_kind=event_kind,
|
||||||
|
external_peer_id=thread_peer_id,
|
||||||
|
external_peer_kind="channel",
|
||||||
|
external_message_id=message_id,
|
||||||
|
external_user_id=discord_user_id,
|
||||||
|
text=text,
|
||||||
|
raw_payload=raw_payload,
|
||||||
|
display_name=event.get("channel_name"),
|
||||||
|
username=author.get("username") or discord_user_id,
|
||||||
|
metadata={
|
||||||
|
"guild_id": guild_id,
|
||||||
|
"channel_id": channel_id,
|
||||||
|
"discord_user_id": discord_user_id,
|
||||||
|
"message_id": message_id,
|
||||||
|
"thread_key": thread_key,
|
||||||
|
"bot_user_id": bot_user_id,
|
||||||
|
"discord_user_peer_id": user_peer_id,
|
||||||
|
"discord_thread_peer_id": thread_peer_id,
|
||||||
|
"mentions_bot": mentions_bot,
|
||||||
|
"is_dm": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send_message(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
external_peer_id: str,
|
||||||
|
text: str,
|
||||||
|
parse_mode: str | None = None,
|
||||||
|
reply_to_message_id: str | None = None,
|
||||||
|
) -> PlatformSendResult:
|
||||||
|
del parse_mode
|
||||||
|
return await self.client.send_message(
|
||||||
|
channel_id=external_peer_id,
|
||||||
|
content=text,
|
||||||
|
reply_to_message_id=reply_to_message_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def edit_message(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
external_peer_id: str,
|
||||||
|
external_message_id: str,
|
||||||
|
text: str,
|
||||||
|
parse_mode: str | None = None,
|
||||||
|
) -> PlatformSendResult:
|
||||||
|
del parse_mode
|
||||||
|
return await self.client.update_message(
|
||||||
|
channel_id=external_peer_id,
|
||||||
|
message_id=external_message_id,
|
||||||
|
content=text,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def validate_credentials(self) -> dict[str, Any]:
|
||||||
|
return await self.client.validate()
|
||||||
109
surfsense_backend/app/gateway/discord/client.py
Normal file
109
surfsense_backend/app/gateway/discord/client.py
Normal file
|
|
@ -0,0 +1,109 @@
|
||||||
|
"""Discord REST API client for gateway bot operations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.gateway.base.adapter import PlatformSendResult
|
||||||
|
|
||||||
|
DISCORD_API = "https://discord.com/api/v10"
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordGatewayClient:
|
||||||
|
def __init__(self, bot_token: str) -> None:
|
||||||
|
self.bot_token = bot_token
|
||||||
|
|
||||||
|
async def api_call(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
*,
|
||||||
|
payload: dict[str, Any] | None = None,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
retry_rate_limit: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||||
|
response = await client.request(
|
||||||
|
method,
|
||||||
|
f"{DISCORD_API}{path}",
|
||||||
|
json=payload,
|
||||||
|
params=params,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bot {self.bot_token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if response.status_code == 429 and retry_rate_limit:
|
||||||
|
data = response.json()
|
||||||
|
retry_after = float(data.get("retry_after") or 1.0)
|
||||||
|
await asyncio.sleep(min(retry_after, 5.0))
|
||||||
|
return await self.api_call(
|
||||||
|
method,
|
||||||
|
path,
|
||||||
|
payload=payload,
|
||||||
|
params=params,
|
||||||
|
retry_rate_limit=False,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
if not response.content:
|
||||||
|
return {}
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def send_message(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: str,
|
||||||
|
content: str,
|
||||||
|
reply_to_message_id: str | None = None,
|
||||||
|
) -> PlatformSendResult:
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"content": content,
|
||||||
|
"allowed_mentions": {"parse": []},
|
||||||
|
}
|
||||||
|
if reply_to_message_id:
|
||||||
|
payload["message_reference"] = {
|
||||||
|
"message_id": reply_to_message_id,
|
||||||
|
"channel_id": channel_id,
|
||||||
|
"fail_if_not_exists": False,
|
||||||
|
}
|
||||||
|
data = await self.api_call(
|
||||||
|
"POST",
|
||||||
|
f"/channels/{channel_id}/messages",
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
return PlatformSendResult(
|
||||||
|
external_message_id=str(data.get("id", "")),
|
||||||
|
raw_response=data,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_message(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: str,
|
||||||
|
message_id: str,
|
||||||
|
content: str,
|
||||||
|
) -> PlatformSendResult:
|
||||||
|
data = await self.api_call(
|
||||||
|
"PATCH",
|
||||||
|
f"/channels/{channel_id}/messages/{message_id}",
|
||||||
|
payload={"content": content, "allowed_mentions": {"parse": []}},
|
||||||
|
)
|
||||||
|
return PlatformSendResult(
|
||||||
|
external_message_id=str(data.get("id") or message_id),
|
||||||
|
raw_response=data,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def validate(self) -> dict[str, Any]:
|
||||||
|
data = await self.api_call("GET", "/users/@me")
|
||||||
|
return {
|
||||||
|
"ok": True,
|
||||||
|
"bot_user_id": data.get("id"),
|
||||||
|
"bot_username": data.get("username"),
|
||||||
|
"global_name": data.get("global_name"),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_guild(self, guild_id: str) -> dict[str, Any]:
|
||||||
|
return await self.api_call("GET", f"/guilds/{guild_id}")
|
||||||
92
surfsense_backend/tests/unit/gateway/test_discord_adapter.py
Normal file
92
surfsense_backend/tests/unit/gateway/test_discord_adapter.py
Normal file
|
|
@ -0,0 +1,92 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.gateway.base.adapter import PlatformSendResult
|
||||||
|
from app.gateway.discord.adapter import DiscordAdapter
|
||||||
|
|
||||||
|
|
||||||
|
def _discord_payload(content: str = "<@999> summarize this channel"):
|
||||||
|
return {
|
||||||
|
"type": "message",
|
||||||
|
"bot_user_id": "999",
|
||||||
|
"event": {
|
||||||
|
"type": "message",
|
||||||
|
"id": "111",
|
||||||
|
"guild_id": "222",
|
||||||
|
"guild_name": "SurfSense Guild",
|
||||||
|
"channel_id": "333",
|
||||||
|
"channel_name": "general",
|
||||||
|
"content": content,
|
||||||
|
"author": {"id": "444", "username": "anish", "bot": False},
|
||||||
|
"mentions": [{"id": "999", "username": "SurfSense"}],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_discord_adapter_parses_mention_and_strips_bot_mention():
|
||||||
|
adapter = DiscordAdapter("discord-token", bot_user_id="999")
|
||||||
|
|
||||||
|
parsed = adapter.parse_inbound(_discord_payload())
|
||||||
|
|
||||||
|
assert parsed.platform == "discord"
|
||||||
|
assert parsed.text == "summarize this channel"
|
||||||
|
assert parsed.external_peer_id == "discord_thread:222:333:111"
|
||||||
|
assert parsed.metadata["discord_user_peer_id"] == "discord_user:222:444"
|
||||||
|
assert parsed.metadata["discord_thread_peer_id"] == "discord_thread:222:333:111"
|
||||||
|
assert parsed.metadata["mentions_bot"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_discord_adapter_strips_nickname_mention():
|
||||||
|
adapter = DiscordAdapter("discord-token", bot_user_id="999")
|
||||||
|
|
||||||
|
parsed = adapter.parse_inbound(_discord_payload("<@!999> continue"))
|
||||||
|
|
||||||
|
assert parsed.text == "continue"
|
||||||
|
|
||||||
|
|
||||||
|
def test_discord_adapter_uses_message_reference_as_thread_key():
|
||||||
|
adapter = DiscordAdapter("discord-token", bot_user_id="999")
|
||||||
|
payload = _discord_payload("<@999> continue")
|
||||||
|
payload["event"]["id"] = "112"
|
||||||
|
payload["event"]["message_reference"] = {
|
||||||
|
"message_id": "111",
|
||||||
|
"channel_id": "333",
|
||||||
|
"guild_id": "222",
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed = adapter.parse_inbound(payload)
|
||||||
|
|
||||||
|
assert parsed.external_peer_id == "discord_thread:222:333:111"
|
||||||
|
assert parsed.metadata["message_id"] == "112"
|
||||||
|
assert parsed.metadata["thread_key"] == "111"
|
||||||
|
|
||||||
|
|
||||||
|
def test_discord_adapter_returns_missing_peer_for_incomplete_payload():
|
||||||
|
adapter = DiscordAdapter("discord-token", bot_user_id="999")
|
||||||
|
|
||||||
|
parsed = adapter.parse_inbound({"event": {"id": "111"}})
|
||||||
|
|
||||||
|
assert parsed.external_peer_id is None
|
||||||
|
assert parsed.external_peer_kind == "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discord_adapter_sends_message(mocker):
|
||||||
|
adapter = DiscordAdapter("discord-token", bot_user_id="999")
|
||||||
|
adapter.client.send_message = mocker.AsyncMock(
|
||||||
|
return_value=PlatformSendResult(external_message_id="555")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await adapter.send_message(
|
||||||
|
external_peer_id="333",
|
||||||
|
text="hello",
|
||||||
|
reply_to_message_id="111",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.external_message_id == "555"
|
||||||
|
adapter.client.send_message.assert_awaited_once_with(
|
||||||
|
channel_id="333",
|
||||||
|
content="hello",
|
||||||
|
reply_to_message_id="111",
|
||||||
|
)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue