feat(gateway): route Discord events through external chat

This commit is contained in:
Anish Sarkar 2026-06-01 20:59:13 +05:30
parent 5024b69e69
commit f8ff58bdce
2 changed files with 121 additions and 6 deletions

View file

@ -136,6 +136,8 @@ async def _resolve_binding_for_event(
) -> ExternalChatBinding | None:
if account.platform == ExternalChatPlatform.SLACK:
return await _resolve_slack_thread_binding(session, account, parsed)
if account.platform == ExternalChatPlatform.DISCORD:
return await _resolve_discord_thread_binding(session, account, parsed)
result = await session.execute(
select(ExternalChatBinding).where(
@ -209,6 +211,74 @@ async def _resolve_slack_thread_binding(
return thread_binding
async def _resolve_discord_thread_binding(
session: AsyncSession,
account: ExternalChatAccount,
parsed,
) -> ExternalChatBinding | None:
user_peer_id = parsed.metadata.get("discord_user_peer_id")
thread_peer_id = parsed.metadata.get("discord_thread_peer_id") or parsed.external_peer_id
if not user_peer_id or not thread_peer_id:
return None
user_result = await session.execute(
select(ExternalChatBinding).where(
ExternalChatBinding.account_id == account.id,
ExternalChatBinding.external_peer_id == user_peer_id,
ExternalChatBinding.state.in_(
[ExternalChatBindingState.BOUND, ExternalChatBindingState.SUSPENDED]
),
)
)
user_binding = user_result.scalars().first()
if user_binding is None:
return None
thread_result = await session.execute(
select(ExternalChatBinding).where(
ExternalChatBinding.account_id == account.id,
ExternalChatBinding.external_peer_id == thread_peer_id,
ExternalChatBinding.state.in_(
[ExternalChatBindingState.BOUND, ExternalChatBindingState.SUSPENDED]
),
)
)
thread_binding = thread_result.scalars().first()
if thread_binding is not None:
return thread_binding
thread_binding = ExternalChatBinding(
account_id=account.id,
user_id=user_binding.user_id,
search_space_id=user_binding.search_space_id,
state=ExternalChatBindingState.BOUND,
external_peer_id=thread_peer_id,
external_peer_kind=ExternalChatPeerKind.CHANNEL,
external_thread_id=parsed.metadata.get("thread_key"),
external_display_name=parsed.metadata.get("channel_id"),
external_username=parsed.external_user_id,
external_metadata={
"kind": "discord_thread",
"guild_id": parsed.metadata.get("guild_id"),
"channel_id": parsed.metadata.get("channel_id"),
"thread_key": parsed.metadata.get("thread_key"),
"discord_user_id": parsed.metadata.get("discord_user_id"),
"user_binding_id": user_binding.id,
},
)
session.add(thread_binding)
await session.flush()
return thread_binding
def _reply_target(parsed) -> tuple[str | None, str | None]:
if parsed.platform == "slack":
return parsed.metadata.get("channel_id"), parsed.metadata.get("thread_ts")
if parsed.platform == "discord":
return parsed.metadata.get("channel_id"), parsed.metadata.get("message_id")
return parsed.external_peer_id, None
async def _dispatch_inbound_event(
inbox_id: int,
session_maker: SessionMaker,
@ -245,7 +315,8 @@ async def _dispatch_inbound_event(
binding = await _resolve_binding_for_event(session, account, parsed)
if (
account.platform != ExternalChatPlatform.SLACK
account.platform
not in {ExternalChatPlatform.SLACK, ExternalChatPlatform.DISCORD}
and parsed.external_peer_kind != ExternalChatPeerKind.DIRECT.value
):
if hasattr(adapter, "leave_chat"):
@ -300,10 +371,13 @@ async def _dispatch_inbound_event(
return
if cmd == "/new":
binding.new_chat_thread_id = None
await adapter.send_message(
external_peer_id=parsed.external_peer_id,
text="Started a new SurfSense conversation.",
)
reply_peer_id, reply_message_id = _reply_target(parsed)
if reply_peer_id:
await adapter.send_message(
external_peer_id=reply_peer_id,
text="Started a new SurfSense conversation.",
reply_to_message_id=reply_message_id,
)
event.status = ExternalChatEventStatus.PROCESSED
await session.commit()
return

View file

@ -6,7 +6,11 @@ from collections.abc import Callable
from dataclasses import dataclass
from app.db import ExternalChatAccount, ExternalChatAccountMode, ExternalChatPlatform
from app.gateway.accounts import account_token, slack_account_credentials
from app.gateway.accounts import (
account_token,
discord_account_credentials,
slack_account_credentials,
)
from app.gateway.base.adapter import BasePlatformAdapter, ParsedInboundEvent
from app.gateway.base.commands import BaseGatewayCommands
from app.gateway.base.translator import BaseStreamTranslator
@ -87,6 +91,23 @@ def _slack_translator_factory(
)
def _discord_translator_factory(
adapter: BasePlatformAdapter,
event: ParsedInboundEvent,
) -> BaseStreamTranslator:
channel_id = event.metadata.get("channel_id")
message_id = event.metadata.get("message_id")
if not channel_id:
raise RuntimeError("missing_discord_channel_metadata")
from app.gateway.discord.translator import DiscordStreamTranslator
return DiscordStreamTranslator(
adapter=adapter, # type: ignore[arg-type]
channel_id=channel_id,
reply_to_message_id=message_id,
)
def resolve_platform_bundle(account: ExternalChatAccount) -> PlatformBundle:
if account.platform == ExternalChatPlatform.TELEGRAM:
token = account_token(account)
@ -145,4 +166,24 @@ def resolve_platform_bundle(account: ExternalChatAccount) -> PlatformBundle:
auto_bind_owner=False,
)
if account.platform == ExternalChatPlatform.DISCORD:
from app.gateway.discord.adapter import DiscordAdapter
from app.gateway.discord.commands import DiscordGatewayCommands
credentials = discord_account_credentials(account)
bot_token = credentials.get("bot_token")
if not bot_token:
raise RuntimeError("missing_discord_bot_token")
cursor_state = account.cursor_state or {}
return PlatformBundle(
adapter=DiscordAdapter(
bot_token,
bot_user_id=cursor_state.get("bot_user_id"),
),
translator_factory=_discord_translator_factory,
platform_label="discord",
commands=DiscordGatewayCommands(),
auto_bind_owner=False,
)
raise RuntimeError(f"unsupported_gateway_platform:{account.platform.value}:{account.mode.value}")