feat(gateway): add Discord platform adapter

This commit is contained in:
Anish Sarkar 2026-06-01 20:58:50 +05:30
parent 68da295b5d
commit bc8a285187
4 changed files with 337 additions and 0 deletions

View file

@ -0,0 +1 @@
"""Discord gateway platform integration."""

View file

@ -0,0 +1,135 @@
"""Discord platform adapter for bot mentions and replies."""
from __future__ import annotations
import re
from typing import Any
from app.gateway.base.adapter import (
BasePlatformAdapter,
ParsedInboundEvent,
PlatformSendResult,
)
from app.gateway.discord.client import DiscordGatewayClient
MENTION_RE = re.compile(r"<@!?\d+>\s*")
def discord_user_peer_id(guild_id: str, discord_user_id: str) -> str:
return f"discord_user:{guild_id}:{discord_user_id}"
def discord_thread_peer_id(guild_id: str, channel_id: str, thread_key: str) -> str:
return f"discord_thread:{guild_id}:{channel_id}:{thread_key}"
class DiscordAdapter(BasePlatformAdapter):
platform = "discord"
def __init__(self, bot_token: str, *, bot_user_id: str | None = None) -> None:
self.bot_user_id = bot_user_id
self.client = DiscordGatewayClient(bot_token)
def parse_inbound(self, raw_payload: dict[str, Any]) -> ParsedInboundEvent:
event = raw_payload.get("event") or raw_payload
event_kind = str(raw_payload.get("type") or event.get("type") or "message")
guild_id = str(event.get("guild_id") or "")
channel_id = str(event.get("channel_id") or "")
author = event.get("author") or {}
discord_user_id = str(author.get("id") or event.get("author_id") or "")
message_id = str(event.get("id") or event.get("message_id") or "")
bot_user_id = self.bot_user_id or str(raw_payload.get("bot_user_id") or "")
if not guild_id or not channel_id or not discord_user_id or not message_id:
return ParsedInboundEvent(
platform=self.platform,
event_kind=event_kind,
external_peer_id=None,
external_peer_kind="unknown",
external_message_id=message_id or None,
external_user_id=discord_user_id or None,
text=None,
raw_payload=raw_payload,
metadata={
"guild_id": guild_id,
"channel_id": channel_id,
"bot_user_id": bot_user_id,
},
)
text = str(event.get("content") or "")
if bot_user_id:
text = text.replace(f"<@{bot_user_id}>", "")
text = text.replace(f"<@!{bot_user_id}>", "")
text = MENTION_RE.sub("", text).strip()
thread_key = str(
event.get("thread_id")
or (event.get("message_reference") or {}).get("message_id")
or message_id
)
thread_peer_id = discord_thread_peer_id(guild_id, channel_id, thread_key)
user_peer_id = discord_user_peer_id(guild_id, discord_user_id)
mentions = event.get("mentions") or []
mentions_bot = bool(
bot_user_id
and any(str(mention.get("id")) == bot_user_id for mention in mentions)
)
return ParsedInboundEvent(
platform=self.platform,
event_kind=event_kind,
external_peer_id=thread_peer_id,
external_peer_kind="channel",
external_message_id=message_id,
external_user_id=discord_user_id,
text=text,
raw_payload=raw_payload,
display_name=event.get("channel_name"),
username=author.get("username") or discord_user_id,
metadata={
"guild_id": guild_id,
"channel_id": channel_id,
"discord_user_id": discord_user_id,
"message_id": message_id,
"thread_key": thread_key,
"bot_user_id": bot_user_id,
"discord_user_peer_id": user_peer_id,
"discord_thread_peer_id": thread_peer_id,
"mentions_bot": mentions_bot,
"is_dm": False,
},
)
async def send_message(
self,
*,
external_peer_id: str,
text: str,
parse_mode: str | None = None,
reply_to_message_id: str | None = None,
) -> PlatformSendResult:
del parse_mode
return await self.client.send_message(
channel_id=external_peer_id,
content=text,
reply_to_message_id=reply_to_message_id,
)
async def edit_message(
self,
*,
external_peer_id: str,
external_message_id: str,
text: str,
parse_mode: str | None = None,
) -> PlatformSendResult:
del parse_mode
return await self.client.update_message(
channel_id=external_peer_id,
message_id=external_message_id,
content=text,
)
async def validate_credentials(self) -> dict[str, Any]:
return await self.client.validate()

View file

@ -0,0 +1,109 @@
"""Discord REST API client for gateway bot operations."""
from __future__ import annotations
import asyncio
from typing import Any
import httpx
from app.gateway.base.adapter import PlatformSendResult
DISCORD_API = "https://discord.com/api/v10"
class DiscordGatewayClient:
def __init__(self, bot_token: str) -> None:
self.bot_token = bot_token
async def api_call(
self,
method: str,
path: str,
*,
payload: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
retry_rate_limit: bool = True,
) -> dict[str, Any]:
async with httpx.AsyncClient(timeout=20.0) as client:
response = await client.request(
method,
f"{DISCORD_API}{path}",
json=payload,
params=params,
headers={
"Authorization": f"Bot {self.bot_token}",
"Content-Type": "application/json",
},
)
if response.status_code == 429 and retry_rate_limit:
data = response.json()
retry_after = float(data.get("retry_after") or 1.0)
await asyncio.sleep(min(retry_after, 5.0))
return await self.api_call(
method,
path,
payload=payload,
params=params,
retry_rate_limit=False,
)
response.raise_for_status()
if not response.content:
return {}
return response.json()
async def send_message(
self,
*,
channel_id: str,
content: str,
reply_to_message_id: str | None = None,
) -> PlatformSendResult:
payload: dict[str, Any] = {
"content": content,
"allowed_mentions": {"parse": []},
}
if reply_to_message_id:
payload["message_reference"] = {
"message_id": reply_to_message_id,
"channel_id": channel_id,
"fail_if_not_exists": False,
}
data = await self.api_call(
"POST",
f"/channels/{channel_id}/messages",
payload=payload,
)
return PlatformSendResult(
external_message_id=str(data.get("id", "")),
raw_response=data,
)
async def update_message(
self,
*,
channel_id: str,
message_id: str,
content: str,
) -> PlatformSendResult:
data = await self.api_call(
"PATCH",
f"/channels/{channel_id}/messages/{message_id}",
payload={"content": content, "allowed_mentions": {"parse": []}},
)
return PlatformSendResult(
external_message_id=str(data.get("id") or message_id),
raw_response=data,
)
async def validate(self) -> dict[str, Any]:
data = await self.api_call("GET", "/users/@me")
return {
"ok": True,
"bot_user_id": data.get("id"),
"bot_username": data.get("username"),
"global_name": data.get("global_name"),
}
async def get_guild(self, guild_id: str) -> dict[str, Any]:
return await self.api_call("GET", f"/guilds/{guild_id}")

View file

@ -0,0 +1,92 @@
from __future__ import annotations
import pytest
from app.gateway.base.adapter import PlatformSendResult
from app.gateway.discord.adapter import DiscordAdapter
def _discord_payload(content: str = "<@999> summarize this channel"):
return {
"type": "message",
"bot_user_id": "999",
"event": {
"type": "message",
"id": "111",
"guild_id": "222",
"guild_name": "SurfSense Guild",
"channel_id": "333",
"channel_name": "general",
"content": content,
"author": {"id": "444", "username": "anish", "bot": False},
"mentions": [{"id": "999", "username": "SurfSense"}],
},
}
def test_discord_adapter_parses_mention_and_strips_bot_mention():
adapter = DiscordAdapter("discord-token", bot_user_id="999")
parsed = adapter.parse_inbound(_discord_payload())
assert parsed.platform == "discord"
assert parsed.text == "summarize this channel"
assert parsed.external_peer_id == "discord_thread:222:333:111"
assert parsed.metadata["discord_user_peer_id"] == "discord_user:222:444"
assert parsed.metadata["discord_thread_peer_id"] == "discord_thread:222:333:111"
assert parsed.metadata["mentions_bot"] is True
def test_discord_adapter_strips_nickname_mention():
adapter = DiscordAdapter("discord-token", bot_user_id="999")
parsed = adapter.parse_inbound(_discord_payload("<@!999> continue"))
assert parsed.text == "continue"
def test_discord_adapter_uses_message_reference_as_thread_key():
adapter = DiscordAdapter("discord-token", bot_user_id="999")
payload = _discord_payload("<@999> continue")
payload["event"]["id"] = "112"
payload["event"]["message_reference"] = {
"message_id": "111",
"channel_id": "333",
"guild_id": "222",
}
parsed = adapter.parse_inbound(payload)
assert parsed.external_peer_id == "discord_thread:222:333:111"
assert parsed.metadata["message_id"] == "112"
assert parsed.metadata["thread_key"] == "111"
def test_discord_adapter_returns_missing_peer_for_incomplete_payload():
adapter = DiscordAdapter("discord-token", bot_user_id="999")
parsed = adapter.parse_inbound({"event": {"id": "111"}})
assert parsed.external_peer_id is None
assert parsed.external_peer_kind == "unknown"
@pytest.mark.asyncio
async def test_discord_adapter_sends_message(mocker):
adapter = DiscordAdapter("discord-token", bot_user_id="999")
adapter.client.send_message = mocker.AsyncMock(
return_value=PlatformSendResult(external_message_id="555")
)
result = await adapter.send_message(
external_peer_id="333",
text="hello",
reply_to_message_id="111",
)
assert result.external_message_id == "555"
adapter.client.send_message.assert_awaited_once_with(
channel_id="333",
content="hello",
reply_to_message_id="111",
)