feat(gateway): handle Discord channel replies

This commit is contained in:
Anish Sarkar 2026-06-01 20:59:04 +05:30
parent bc8a285187
commit 5024b69e69
2 changed files with 152 additions and 0 deletions

View file

@ -0,0 +1,66 @@
"""Discord command/onboarding handlers."""
from __future__ import annotations
from app.gateway.base.adapter import ParsedInboundEvent
from app.gateway.base.commands import BaseGatewayCommands
from app.gateway.discord.adapter import DiscordAdapter
from app.gateway.ratelimit import acquire_token
HELP_TEXT = (
"SurfSense Discord commands:\n"
"`/new` - start a fresh SurfSense conversation for this Discord thread\n"
"`/help` - show this help\n\n"
"Mention the SurfSense bot in a Discord channel to ask your agent a question. "
"Discord search remains controlled by the Discord connector in SurfSense."
)
class DiscordGatewayCommands(BaseGatewayCommands):
async def handle_help_command(
self,
*,
adapter: DiscordAdapter,
event: ParsedInboundEvent,
) -> bool:
channel_id = event.metadata.get("channel_id")
message_id = event.metadata.get("message_id")
if not channel_id:
return True
await adapter.send_message(
external_peer_id=channel_id,
text=HELP_TEXT,
reply_to_message_id=message_id,
)
return True
async def send_unbound_onboarding(
self,
*,
adapter: DiscordAdapter,
event: ParsedInboundEvent,
dashboard_url: str,
) -> None:
channel_id = event.metadata.get("channel_id")
message_id = event.metadata.get("message_id")
guild_id = event.metadata.get("guild_id")
discord_user_id = event.metadata.get("discord_user_id")
if not channel_id or not message_id:
return
wait_ms = await acquire_token(
f"discord:onboarded:{guild_id}:{discord_user_id}",
capacity=1,
refill_per_sec=1 / 3600,
)
if wait_ms > 0:
return
await adapter.send_message(
external_peer_id=channel_id,
reply_to_message_id=message_id,
text=(
"Hi! Connect your Discord user to SurfSense before using the bot here: "
f"{dashboard_url}"
),
)

View file

@ -0,0 +1,86 @@
"""Translate agent stream events into Discord replies."""
from __future__ import annotations
import logging
from collections.abc import AsyncIterator
from app.gateway.base.adapter import PlatformSendResult
from app.gateway.base.formatting import split_text_message
from app.gateway.base.translator import BaseStreamTranslator, GatewayStreamEvent
from app.gateway.discord.adapter import DiscordAdapter
from app.gateway.ratelimit import wait_for_token
from app.observability.metrics import (
record_gateway_hitl_aborted,
record_gateway_outbound,
record_gateway_rate_limit_hit,
)
logger = logging.getLogger(__name__)
DISCORD_MAX_MESSAGE_CHARS = 1900
HITL_UNSUPPORTED_MESSAGE = (
"This action requires approval and is not yet supported from Discord. "
"Try again with a different request."
)
class DiscordStreamTranslator(BaseStreamTranslator):
def __init__(
self,
*,
adapter: DiscordAdapter,
channel_id: str,
reply_to_message_id: str | None,
) -> None:
self.adapter = adapter
self.channel_id = channel_id
self.reply_to_message_id = reply_to_message_id
self._buffer = ""
async def translate(self, events: AsyncIterator[GatewayStreamEvent]) -> None:
async for event in events:
if event.type in {"text-delta", "text_delta", "text"}:
self._buffer += str(event.data.get("text") or event.data.get("delta") or "")
elif event.type in {"data-interrupt-request", "interrupt"}:
await self._handle_hitl_interrupt()
return
elif event.type in {"finish", "done"}:
break
await self._flush_final()
async def _flush_final(self) -> None:
if not self._buffer:
return
for chunk in split_text_message(self._buffer, max_chars=DISCORD_MAX_MESSAGE_CHARS):
await self._send_text(chunk)
async def _send_text(self, text: str) -> PlatformSendResult:
await self._throttle()
try:
result = await self.adapter.send_message(
external_peer_id=self.channel_id,
text=text,
reply_to_message_id=self.reply_to_message_id,
)
except Exception:
record_gateway_outbound(platform="discord", kind="send", status="failed")
raise
record_gateway_outbound(platform="discord", kind="send", status="sent")
return result
async def _throttle(self) -> None:
chat_wait = await wait_for_token(
f"discord:channel:{self.channel_id}",
capacity=5,
refill_per_sec=1.0,
)
if chat_wait:
record_gateway_rate_limit_hit(bucket="discord:channel")
async def _handle_hitl_interrupt(self) -> None:
if self._buffer:
await self._flush_final()
await self._send_text(HITL_UNSUPPORTED_MESSAGE)
record_gateway_hitl_aborted(platform="discord")