mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-01 11:56:25 +02:00
add Discord list channels, read messages, send message tools
This commit is contained in:
parent
07a5fac15d
commit
1de2517eae
5 changed files with 304 additions and 0 deletions
|
|
@ -0,0 +1,15 @@
|
||||||
|
from app.agents.new_chat.tools.discord.list_channels import (
|
||||||
|
create_list_discord_channels_tool,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.tools.discord.read_messages import (
|
||||||
|
create_read_discord_messages_tool,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.tools.discord.send_message import (
|
||||||
|
create_send_discord_message_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"create_list_discord_channels_tool",
|
||||||
|
"create_read_discord_messages_tool",
|
||||||
|
"create_send_discord_message_tool",
|
||||||
|
]
|
||||||
46
surfsense_backend/app/agents/new_chat/tools/discord/_auth.py
Normal file
46
surfsense_backend/app/agents/new_chat/tools/discord/_auth.py
Normal file
|
|
@ -0,0 +1,46 @@
|
||||||
|
"""Shared auth helper for Discord agent tools (REST API, not gateway bot)."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
|
from app.utils.oauth_security import TokenEncryption
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DISCORD_API = "https://discord.com/api/v10"
|
||||||
|
|
||||||
|
|
||||||
|
async def get_discord_connector(
|
||||||
|
db_session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str,
|
||||||
|
) -> SearchSourceConnector | None:
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
SearchSourceConnector.user_id == user_id,
|
||||||
|
SearchSourceConnector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_bot_token(connector: SearchSourceConnector) -> str:
|
||||||
|
"""Extract and decrypt the bot token from connector config."""
|
||||||
|
cfg = dict(connector.config)
|
||||||
|
if cfg.get("_token_encrypted") and config.SECRET_KEY:
|
||||||
|
enc = TokenEncryption(config.SECRET_KEY)
|
||||||
|
if cfg.get("bot_token"):
|
||||||
|
cfg["bot_token"] = enc.decrypt_token(cfg["bot_token"])
|
||||||
|
token = cfg.get("bot_token")
|
||||||
|
if not token:
|
||||||
|
raise ValueError("Discord bot token not found in connector config.")
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
def get_guild_id(connector: SearchSourceConnector) -> str | None:
|
||||||
|
return connector.config.get("guild_id")
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_list_discord_channels_tool(
|
||||||
|
db_session: AsyncSession | None = None,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def list_discord_channels() -> dict[str, Any]:
|
||||||
|
"""List text channels in the connected Discord server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status and a list of channels (id, name).
|
||||||
|
"""
|
||||||
|
if db_session is None or search_space_id is None or user_id is None:
|
||||||
|
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||||
|
|
||||||
|
try:
|
||||||
|
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||||
|
if not connector:
|
||||||
|
return {"status": "error", "message": "No Discord connector found."}
|
||||||
|
|
||||||
|
guild_id = get_guild_id(connector)
|
||||||
|
if not guild_id:
|
||||||
|
return {"status": "error", "message": "No guild ID in Discord connector config."}
|
||||||
|
|
||||||
|
token = get_bot_token(connector)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.get(
|
||||||
|
f"{DISCORD_API}/guilds/{guild_id}/channels",
|
||||||
|
headers={"Authorization": f"Bot {token}"},
|
||||||
|
timeout=15.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == 401:
|
||||||
|
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||||
|
if resp.status_code != 200:
|
||||||
|
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||||
|
|
||||||
|
# Type 0 = text channel
|
||||||
|
channels = [
|
||||||
|
{"id": ch["id"], "name": ch["name"]}
|
||||||
|
for ch in resp.json()
|
||||||
|
if ch.get("type") == 0
|
||||||
|
]
|
||||||
|
return {"status": "success", "guild_id": guild_id, "channels": channels, "total": len(channels)}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
|
if isinstance(e, GraphInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error("Error listing Discord channels: %s", e, exc_info=True)
|
||||||
|
return {"status": "error", "message": "Failed to list Discord channels."}
|
||||||
|
|
||||||
|
return list_discord_channels
|
||||||
|
|
@ -0,0 +1,80 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_read_discord_messages_tool(
|
||||||
|
db_session: AsyncSession | None = None,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def read_discord_messages(
|
||||||
|
channel_id: str,
|
||||||
|
limit: int = 25,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Read recent messages from a Discord text channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: The Discord channel ID (from list_discord_channels).
|
||||||
|
limit: Number of messages to fetch (default 25, max 50).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status and a list of messages including
|
||||||
|
id, author, content, timestamp.
|
||||||
|
"""
|
||||||
|
if db_session is None or search_space_id is None or user_id is None:
|
||||||
|
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||||
|
|
||||||
|
limit = min(limit, 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||||
|
if not connector:
|
||||||
|
return {"status": "error", "message": "No Discord connector found."}
|
||||||
|
|
||||||
|
token = get_bot_token(connector)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.get(
|
||||||
|
f"{DISCORD_API}/channels/{channel_id}/messages",
|
||||||
|
headers={"Authorization": f"Bot {token}"},
|
||||||
|
params={"limit": limit},
|
||||||
|
timeout=15.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == 401:
|
||||||
|
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||||
|
if resp.status_code == 403:
|
||||||
|
return {"status": "error", "message": "Bot lacks permission to read this channel."}
|
||||||
|
if resp.status_code != 200:
|
||||||
|
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"id": m["id"],
|
||||||
|
"author": m.get("author", {}).get("username", "Unknown"),
|
||||||
|
"content": m.get("content", ""),
|
||||||
|
"timestamp": m.get("timestamp", ""),
|
||||||
|
}
|
||||||
|
for m in resp.json()
|
||||||
|
]
|
||||||
|
|
||||||
|
return {"status": "success", "channel_id": channel_id, "messages": messages, "total": len(messages)}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
|
if isinstance(e, GraphInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error("Error reading Discord messages: %s", e, exc_info=True)
|
||||||
|
return {"status": "error", "message": "Failed to read Discord messages."}
|
||||||
|
|
||||||
|
return read_discord_messages
|
||||||
|
|
@ -0,0 +1,96 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.hitl import request_approval
|
||||||
|
|
||||||
|
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_send_discord_message_tool(
|
||||||
|
db_session: AsyncSession | None = None,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
):
|
||||||
|
@tool
|
||||||
|
async def send_discord_message(
|
||||||
|
channel_id: str,
|
||||||
|
content: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Send a message to a Discord text channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: The Discord channel ID (from list_discord_channels).
|
||||||
|
content: The message text (max 2000 characters).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status, message_id on success.
|
||||||
|
|
||||||
|
IMPORTANT:
|
||||||
|
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||||
|
"""
|
||||||
|
if db_session is None or search_space_id is None or user_id is None:
|
||||||
|
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||||
|
|
||||||
|
if len(content) > 2000:
|
||||||
|
return {"status": "error", "message": "Message exceeds Discord's 2000-character limit."}
|
||||||
|
|
||||||
|
try:
|
||||||
|
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||||
|
if not connector:
|
||||||
|
return {"status": "error", "message": "No Discord connector found."}
|
||||||
|
|
||||||
|
result = request_approval(
|
||||||
|
action_type="discord_send_message",
|
||||||
|
tool_name="send_discord_message",
|
||||||
|
params={"channel_id": channel_id, "content": content},
|
||||||
|
context={"connector_id": connector.id},
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.rejected:
|
||||||
|
return {"status": "rejected", "message": "User declined. Message was not sent."}
|
||||||
|
|
||||||
|
final_content = result.params.get("content", content)
|
||||||
|
final_channel = result.params.get("channel_id", channel_id)
|
||||||
|
|
||||||
|
token = get_bot_token(connector)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{DISCORD_API}/channels/{final_channel}/messages",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bot {token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json={"content": final_content},
|
||||||
|
timeout=15.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == 401:
|
||||||
|
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||||
|
if resp.status_code == 403:
|
||||||
|
return {"status": "error", "message": "Bot lacks permission to send messages in this channel."}
|
||||||
|
if resp.status_code not in (200, 201):
|
||||||
|
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||||
|
|
||||||
|
msg_data = resp.json()
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message_id": msg_data.get("id"),
|
||||||
|
"message": f"Message sent to channel {final_channel}.",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from langgraph.errors import GraphInterrupt
|
||||||
|
|
||||||
|
if isinstance(e, GraphInterrupt):
|
||||||
|
raise
|
||||||
|
logger.error("Error sending Discord message: %s", e, exc_info=True)
|
||||||
|
return {"status": "error", "message": "Failed to send Discord message."}
|
||||||
|
|
||||||
|
return send_discord_message
|
||||||
Loading…
Add table
Add a link
Reference in a new issue