mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-02 19:55:18 +02:00
feat(gateway): add Discord mention intake supervisor
This commit is contained in:
parent
f8ff58bdce
commit
05eaa46c3a
2 changed files with 207 additions and 0 deletions
|
|
@ -41,6 +41,10 @@ from app.gateway.byo_long_poll import (
|
||||||
start_byo_long_poll_supervisors,
|
start_byo_long_poll_supervisors,
|
||||||
stop_byo_long_poll_supervisors,
|
stop_byo_long_poll_supervisors,
|
||||||
)
|
)
|
||||||
|
from app.gateway.discord.intake import (
|
||||||
|
start_discord_gateway_supervisor,
|
||||||
|
stop_discord_gateway_supervisor,
|
||||||
|
)
|
||||||
from app.gateway.inbox_worker import (
|
from app.gateway.inbox_worker import (
|
||||||
start_gateway_inbox_worker,
|
start_gateway_inbox_worker,
|
||||||
stop_gateway_inbox_worker,
|
stop_gateway_inbox_worker,
|
||||||
|
|
@ -607,10 +611,12 @@ async def lifespan(app: FastAPI):
|
||||||
log_system_snapshot("startup_complete")
|
log_system_snapshot("startup_complete")
|
||||||
await start_gateway_inbox_worker()
|
await start_gateway_inbox_worker()
|
||||||
await start_byo_long_poll_supervisors()
|
await start_byo_long_poll_supervisors()
|
||||||
|
await start_discord_gateway_supervisor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
await stop_discord_gateway_supervisor()
|
||||||
await stop_byo_long_poll_supervisors()
|
await stop_byo_long_poll_supervisors()
|
||||||
await stop_gateway_inbox_worker()
|
await stop_gateway_inbox_worker()
|
||||||
_stop_openrouter_background_refresh()
|
_stop_openrouter_background_refresh()
|
||||||
|
|
|
||||||
201
surfsense_backend/app/gateway/discord/intake.py
Normal file
201
surfsense_backend/app/gateway/discord/intake.py
Normal file
|
|
@ -0,0 +1,201 @@
|
||||||
|
"""FastAPI lifespan supervisor for Discord Gateway WebSocket intake."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from contextlib import suppress
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import discord
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.db import ExternalChatPlatform, async_session_maker
|
||||||
|
from app.gateway.accounts import get_discord_account_by_guild
|
||||||
|
from app.gateway.inbox import discord_message_dedupe_key, persist_inbound_event
|
||||||
|
from app.observability.metrics import record_gateway_inbox_write
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_task: asyncio.Task[None] | None = None
|
||||||
|
_client: discord.Client | None = None
|
||||||
|
_shutdown_event: asyncio.Event | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _message_reference_payload(message: discord.Message) -> dict[str, Any] | None:
|
||||||
|
if message.reference is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"message_id": str(message.reference.message_id)
|
||||||
|
if message.reference.message_id
|
||||||
|
else None,
|
||||||
|
"channel_id": str(message.reference.channel_id)
|
||||||
|
if message.reference.channel_id
|
||||||
|
else None,
|
||||||
|
"guild_id": str(message.reference.guild_id)
|
||||||
|
if message.reference.guild_id
|
||||||
|
else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_message(message: discord.Message, *, bot_user_id: str | None) -> dict[str, Any]:
|
||||||
|
guild = message.guild
|
||||||
|
channel = message.channel
|
||||||
|
thread_id = str(channel.id) if isinstance(channel, discord.Thread) else None
|
||||||
|
parent_id = str(channel.parent_id) if isinstance(channel, discord.Thread) else None
|
||||||
|
return {
|
||||||
|
"type": "message",
|
||||||
|
"bot_user_id": bot_user_id,
|
||||||
|
"event": {
|
||||||
|
"type": "message",
|
||||||
|
"id": str(message.id),
|
||||||
|
"guild_id": str(guild.id) if guild else None,
|
||||||
|
"guild_name": guild.name if guild else None,
|
||||||
|
"channel_id": parent_id or str(message.channel.id),
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"channel_name": getattr(channel, "name", None),
|
||||||
|
"content": message.content,
|
||||||
|
"author": {
|
||||||
|
"id": str(message.author.id),
|
||||||
|
"username": message.author.name,
|
||||||
|
"bot": message.author.bot,
|
||||||
|
},
|
||||||
|
"mentions": [
|
||||||
|
{"id": str(user.id), "username": user.name}
|
||||||
|
for user in message.mentions
|
||||||
|
],
|
||||||
|
"message_reference": _message_reference_payload(message),
|
||||||
|
"created_at": message.created_at.isoformat()
|
||||||
|
if message.created_at
|
||||||
|
else None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _persist_message(message: discord.Message, *, bot_user_id: str | None) -> None:
|
||||||
|
if message.guild is None:
|
||||||
|
return
|
||||||
|
guild_id = str(message.guild.id)
|
||||||
|
raw_payload = _serialize_message(message, bot_user_id=bot_user_id)
|
||||||
|
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
account = await get_discord_account_by_guild(session, guild_id=guild_id)
|
||||||
|
if account is None:
|
||||||
|
logger.info("Ignoring Discord message for uninstalled guild_id=%s", guild_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
inbox_id = await persist_inbound_event(
|
||||||
|
session,
|
||||||
|
account_id=account.id,
|
||||||
|
platform=ExternalChatPlatform.DISCORD,
|
||||||
|
event_dedupe_key=discord_message_dedupe_key(message.id),
|
||||||
|
external_event_id=str(message.id),
|
||||||
|
external_message_id=str(message.id),
|
||||||
|
event_kind="message",
|
||||||
|
raw_payload=raw_payload,
|
||||||
|
request_id=f"gateway_{uuid.uuid4().hex[:16]}",
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
record_gateway_inbox_write(platform="discord", dedup_skipped=inbox_id is None)
|
||||||
|
logger.info(
|
||||||
|
"Persisted Discord gateway message_id=%s guild_id=%s inbox_id=%s",
|
||||||
|
message.id,
|
||||||
|
guild_id,
|
||||||
|
inbox_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_client() -> discord.Client:
|
||||||
|
intents = discord.Intents.default()
|
||||||
|
intents.guilds = True
|
||||||
|
intents.messages = True
|
||||||
|
intents.message_content = True
|
||||||
|
client = discord.Client(intents=intents)
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_ready() -> None:
|
||||||
|
logger.info(
|
||||||
|
"Discord gateway connected as %s (%s)",
|
||||||
|
client.user,
|
||||||
|
getattr(client.user, "id", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_message(message: discord.Message) -> None:
|
||||||
|
if message.author.bot:
|
||||||
|
return
|
||||||
|
bot_user = client.user
|
||||||
|
if bot_user is None:
|
||||||
|
return
|
||||||
|
if message.author.id == bot_user.id:
|
||||||
|
return
|
||||||
|
bot_user_id = str(bot_user.id)
|
||||||
|
mention_ids = {str(user.id) for user in message.mentions}
|
||||||
|
if bot_user_id not in mention_ids:
|
||||||
|
return
|
||||||
|
logger.info(
|
||||||
|
"Received Discord gateway mention message_id=%s guild_id=%s channel_id=%s content_present=%s",
|
||||||
|
message.id,
|
||||||
|
getattr(message.guild, "id", None),
|
||||||
|
getattr(message.channel, "id", None),
|
||||||
|
bool(message.content),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await _persist_message(message, bot_user_id=bot_user_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Discord gateway failed to persist message_id=%s", message.id)
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_discord_gateway() -> None:
|
||||||
|
global _client
|
||||||
|
token = config.DISCORD_BOT_TOKEN
|
||||||
|
if not token:
|
||||||
|
logger.warning("Discord gateway enabled but DISCORD_BOT_TOKEN is not set")
|
||||||
|
return
|
||||||
|
|
||||||
|
while _shutdown_event is None or not _shutdown_event.is_set():
|
||||||
|
_client = _build_client()
|
||||||
|
try:
|
||||||
|
await _client.start(token)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Discord gateway WebSocket failed; retrying in 30s")
|
||||||
|
finally:
|
||||||
|
if _client is not None and not _client.is_closed():
|
||||||
|
await _client.close()
|
||||||
|
if _shutdown_event is not None and _shutdown_event.is_set():
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(_shutdown_event.wait(), timeout=30.0)
|
||||||
|
except (TimeoutError, AttributeError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
async def start_discord_gateway_supervisor() -> None:
|
||||||
|
global _shutdown_event, _task
|
||||||
|
if not config.GATEWAY_DISCORD_ENABLED:
|
||||||
|
return
|
||||||
|
if _task is not None and not _task.done():
|
||||||
|
return
|
||||||
|
_shutdown_event = asyncio.Event()
|
||||||
|
_task = asyncio.create_task(_run_discord_gateway(), name="gateway-discord-intake")
|
||||||
|
logger.info("Started Discord gateway intake supervisor")
|
||||||
|
|
||||||
|
|
||||||
|
async def stop_discord_gateway_supervisor() -> None:
|
||||||
|
global _client, _shutdown_event, _task
|
||||||
|
if _shutdown_event is not None:
|
||||||
|
_shutdown_event.set()
|
||||||
|
if _client is not None and not _client.is_closed():
|
||||||
|
await _client.close()
|
||||||
|
if _task is not None:
|
||||||
|
_task.cancel()
|
||||||
|
with suppress(TimeoutError, asyncio.CancelledError):
|
||||||
|
await asyncio.wait_for(_task, timeout=10)
|
||||||
|
_client = None
|
||||||
|
_task = None
|
||||||
|
_shutdown_event = None
|
||||||
Loading…
Add table
Add a link
Reference in a new issue