mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-26 17:26:23 +02:00
add Discord list channels, read messages, send message tools
This commit is contained in:
parent
07a5fac15d
commit
1de2517eae
5 changed files with 304 additions and 0 deletions
|
|
@ -0,0 +1,15 @@
|
|||
from app.agents.new_chat.tools.discord.list_channels import (
|
||||
create_list_discord_channels_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.discord.read_messages import (
|
||||
create_read_discord_messages_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.discord.send_message import (
|
||||
create_send_discord_message_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_list_discord_channels_tool",
|
||||
"create_read_discord_messages_tool",
|
||||
"create_send_discord_message_tool",
|
||||
]
|
||||
46
surfsense_backend/app/agents/new_chat/tools/discord/_auth.py
Normal file
46
surfsense_backend/app/agents/new_chat/tools/discord/_auth.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
"""Shared auth helper for Discord agent tools (REST API, not gateway bot)."""
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.config import config
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DISCORD_API = "https://discord.com/api/v10"
|
||||
|
||||
|
||||
async def get_discord_connector(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> SearchSourceConnector | None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
def get_bot_token(connector: SearchSourceConnector) -> str:
|
||||
"""Extract and decrypt the bot token from connector config."""
|
||||
cfg = dict(connector.config)
|
||||
if cfg.get("_token_encrypted") and config.SECRET_KEY:
|
||||
enc = TokenEncryption(config.SECRET_KEY)
|
||||
if cfg.get("bot_token"):
|
||||
cfg["bot_token"] = enc.decrypt_token(cfg["bot_token"])
|
||||
token = cfg.get("bot_token")
|
||||
if not token:
|
||||
raise ValueError("Discord bot token not found in connector config.")
|
||||
return token
|
||||
|
||||
|
||||
def get_guild_id(connector: SearchSourceConnector) -> str | None:
|
||||
return connector.config.get("guild_id")
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_list_discord_channels_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def list_discord_channels() -> dict[str, Any]:
|
||||
"""List text channels in the connected Discord server.
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of channels (id, name).
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
guild_id = get_guild_id(connector)
|
||||
if not guild_id:
|
||||
return {"status": "error", "message": "No guild ID in Discord connector config."}
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{DISCORD_API}/guilds/{guild_id}/channels",
|
||||
headers={"Authorization": f"Bot {token}"},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||
if resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||
|
||||
# Type 0 = text channel
|
||||
channels = [
|
||||
{"id": ch["id"], "name": ch["name"]}
|
||||
for ch in resp.json()
|
||||
if ch.get("type") == 0
|
||||
]
|
||||
return {"status": "success", "guild_id": guild_id, "channels": channels, "total": len(channels)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error listing Discord channels: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to list Discord channels."}
|
||||
|
||||
return list_discord_channels
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_read_discord_messages_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def read_discord_messages(
|
||||
channel_id: str,
|
||||
limit: int = 25,
|
||||
) -> dict[str, Any]:
|
||||
"""Read recent messages from a Discord text channel.
|
||||
|
||||
Args:
|
||||
channel_id: The Discord channel ID (from list_discord_channels).
|
||||
limit: Number of messages to fetch (default 25, max 50).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of messages including
|
||||
id, author, content, timestamp.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||
|
||||
limit = min(limit, 50)
|
||||
|
||||
try:
|
||||
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{DISCORD_API}/channels/{channel_id}/messages",
|
||||
headers={"Authorization": f"Bot {token}"},
|
||||
params={"limit": limit},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||
if resp.status_code == 403:
|
||||
return {"status": "error", "message": "Bot lacks permission to read this channel."}
|
||||
if resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"id": m["id"],
|
||||
"author": m.get("author", {}).get("username", "Unknown"),
|
||||
"content": m.get("content", ""),
|
||||
"timestamp": m.get("timestamp", ""),
|
||||
}
|
||||
for m in resp.json()
|
||||
]
|
||||
|
||||
return {"status": "success", "channel_id": channel_id, "messages": messages, "total": len(messages)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error reading Discord messages: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to read Discord messages."}
|
||||
|
||||
return read_discord_messages
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
|
||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_send_discord_message_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def send_discord_message(
|
||||
channel_id: str,
|
||||
content: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a message to a Discord text channel.
|
||||
|
||||
Args:
|
||||
channel_id: The Discord channel ID (from list_discord_channels).
|
||||
content: The message text (max 2000 characters).
|
||||
|
||||
Returns:
|
||||
Dictionary with status, message_id on success.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||
|
||||
if len(content) > 2000:
|
||||
return {"status": "error", "message": "Message exceeds Discord's 2000-character limit."}
|
||||
|
||||
try:
|
||||
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
result = request_approval(
|
||||
action_type="discord_send_message",
|
||||
tool_name="send_discord_message",
|
||||
params={"channel_id": channel_id, "content": content},
|
||||
context={"connector_id": connector.id},
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {"status": "rejected", "message": "User declined. Message was not sent."}
|
||||
|
||||
final_content = result.params.get("content", content)
|
||||
final_channel = result.params.get("channel_id", channel_id)
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{DISCORD_API}/channels/{final_channel}/messages",
|
||||
headers={
|
||||
"Authorization": f"Bot {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"content": final_content},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||
if resp.status_code == 403:
|
||||
return {"status": "error", "message": "Bot lacks permission to send messages in this channel."}
|
||||
if resp.status_code not in (200, 201):
|
||||
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||
|
||||
msg_data = resp.json()
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": msg_data.get("id"),
|
||||
"message": f"Message sent to channel {final_channel}.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error sending Discord message: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to send Discord message."}
|
||||
|
||||
return send_discord_message
|
||||
Loading…
Add table
Add a link
Reference in a new issue