mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): colocate discord connector tools into subagent slice
Repoint the dead tools/__init__ shim at the live local impls and delete the dead shared/tools/discord twin (subagent already ran its local copies via tools/index.py). No runtime behavior change.
This commit is contained in:
parent
425e6e50a3
commit
c6525c4f52
6 changed files with 3 additions and 430 deletions
|
|
@ -1,12 +1,6 @@
|
||||||
from app.agents.shared.tools.discord.list_channels import (
|
from .list_channels import create_list_discord_channels_tool
|
||||||
create_list_discord_channels_tool,
|
from .read_messages import create_read_discord_messages_tool
|
||||||
)
|
from .send_message import create_send_discord_message_tool
|
||||||
from app.agents.shared.tools.discord.read_messages import (
|
|
||||||
create_read_discord_messages_tool,
|
|
||||||
)
|
|
||||||
from app.agents.shared.tools.discord.send_message import (
|
|
||||||
create_send_discord_message_tool,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"create_list_discord_channels_tool",
|
"create_list_discord_channels_tool",
|
||||||
|
|
|
||||||
|
|
@ -1,15 +0,0 @@
|
||||||
from app.agents.shared.tools.discord.list_channels import (
|
|
||||||
create_list_discord_channels_tool,
|
|
||||||
)
|
|
||||||
from app.agents.shared.tools.discord.read_messages import (
|
|
||||||
create_read_discord_messages_tool,
|
|
||||||
)
|
|
||||||
from app.agents.shared.tools.discord.send_message import (
|
|
||||||
create_send_discord_message_tool,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"create_list_discord_channels_tool",
|
|
||||||
"create_read_discord_messages_tool",
|
|
||||||
"create_send_discord_message_tool",
|
|
||||||
]
|
|
||||||
|
|
@ -1,43 +0,0 @@
|
||||||
"""Shared auth helper for Discord agent tools (REST API, not gateway bot)."""
|
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.future import select
|
|
||||||
|
|
||||||
from app.config import config
|
|
||||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
|
||||||
from app.utils.oauth_security import TokenEncryption
|
|
||||||
|
|
||||||
DISCORD_API = "https://discord.com/api/v10"
|
|
||||||
|
|
||||||
|
|
||||||
async def get_discord_connector(
|
|
||||||
db_session: AsyncSession,
|
|
||||||
search_space_id: int,
|
|
||||||
user_id: str,
|
|
||||||
) -> SearchSourceConnector | None:
|
|
||||||
result = await db_session.execute(
|
|
||||||
select(SearchSourceConnector).filter(
|
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
|
||||||
SearchSourceConnector.user_id == user_id,
|
|
||||||
SearchSourceConnector.connector_type
|
|
||||||
== SearchSourceConnectorType.DISCORD_CONNECTOR,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return result.scalars().first()
|
|
||||||
|
|
||||||
|
|
||||||
def get_bot_token(connector: SearchSourceConnector) -> str:
|
|
||||||
"""Extract and decrypt the bot token from connector config."""
|
|
||||||
cfg = dict(connector.config)
|
|
||||||
if cfg.get("_token_encrypted") and config.SECRET_KEY:
|
|
||||||
enc = TokenEncryption(config.SECRET_KEY)
|
|
||||||
if cfg.get("bot_token"):
|
|
||||||
cfg["bot_token"] = enc.decrypt_token(cfg["bot_token"])
|
|
||||||
token = cfg.get("bot_token")
|
|
||||||
if not token:
|
|
||||||
raise ValueError("Discord bot token not found in connector config.")
|
|
||||||
return token
|
|
||||||
|
|
||||||
|
|
||||||
def get_guild_id(connector: SearchSourceConnector) -> str | None:
|
|
||||||
return connector.config.get("guild_id")
|
|
||||||
|
|
@ -1,107 +0,0 @@
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.db import async_session_maker
|
|
||||||
|
|
||||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def create_list_discord_channels_tool(
|
|
||||||
db_session: AsyncSession | None = None,
|
|
||||||
search_space_id: int | None = None,
|
|
||||||
user_id: str | None = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Factory function to create the list_discord_channels tool.
|
|
||||||
|
|
||||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
|
||||||
:data:`async_session_maker` so the closure is safe to share across
|
|
||||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
|
||||||
session here would surface stale/closed sessions on cache hits.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db_session: Reserved for registry compatibility. Per-call sessions
|
|
||||||
are opened via :data:`async_session_maker` inside the tool body.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured list_discord_channels tool
|
|
||||||
"""
|
|
||||||
del db_session # per-call session — see docstring
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_discord_channels() -> dict[str, Any]:
|
|
||||||
"""List text channels in the connected Discord server.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with status and a list of channels (id, name).
|
|
||||||
"""
|
|
||||||
if search_space_id is None or user_id is None:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "Discord tool not properly configured.",
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with async_session_maker() as db_session:
|
|
||||||
connector = await get_discord_connector(
|
|
||||||
db_session, search_space_id, user_id
|
|
||||||
)
|
|
||||||
if not connector:
|
|
||||||
return {"status": "error", "message": "No Discord connector found."}
|
|
||||||
|
|
||||||
guild_id = get_guild_id(connector)
|
|
||||||
if not guild_id:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "No guild ID in Discord connector config.",
|
|
||||||
}
|
|
||||||
|
|
||||||
token = get_bot_token(connector)
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
resp = await client.get(
|
|
||||||
f"{DISCORD_API}/guilds/{guild_id}/channels",
|
|
||||||
headers={"Authorization": f"Bot {token}"},
|
|
||||||
timeout=15.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if resp.status_code == 401:
|
|
||||||
return {
|
|
||||||
"status": "auth_error",
|
|
||||||
"message": "Discord bot token is invalid.",
|
|
||||||
"connector_type": "discord",
|
|
||||||
}
|
|
||||||
if resp.status_code != 200:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": f"Discord API error: {resp.status_code}",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Type 0 = text channel
|
|
||||||
channels = [
|
|
||||||
{"id": ch["id"], "name": ch["name"]}
|
|
||||||
for ch in resp.json()
|
|
||||||
if ch.get("type") == 0
|
|
||||||
]
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"guild_id": guild_id,
|
|
||||||
"channels": channels,
|
|
||||||
"total": len(channels),
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
from langgraph.errors import GraphInterrupt
|
|
||||||
|
|
||||||
if isinstance(e, GraphInterrupt):
|
|
||||||
raise
|
|
||||||
logger.error("Error listing Discord channels: %s", e, exc_info=True)
|
|
||||||
return {"status": "error", "message": "Failed to list Discord channels."}
|
|
||||||
|
|
||||||
return list_discord_channels
|
|
||||||
|
|
@ -1,120 +0,0 @@
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.db import async_session_maker
|
|
||||||
|
|
||||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def create_read_discord_messages_tool(
|
|
||||||
db_session: AsyncSession | None = None,
|
|
||||||
search_space_id: int | None = None,
|
|
||||||
user_id: str | None = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Factory function to create the read_discord_messages tool.
|
|
||||||
|
|
||||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
|
||||||
:data:`async_session_maker` so the closure is safe to share across
|
|
||||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
|
||||||
session here would surface stale/closed sessions on cache hits.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db_session: Reserved for registry compatibility. Per-call sessions
|
|
||||||
are opened via :data:`async_session_maker` inside the tool body.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured read_discord_messages tool
|
|
||||||
"""
|
|
||||||
del db_session # per-call session — see docstring
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def read_discord_messages(
|
|
||||||
channel_id: str,
|
|
||||||
limit: int = 25,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Read recent messages from a Discord text channel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id: The Discord channel ID (from list_discord_channels).
|
|
||||||
limit: Number of messages to fetch (default 25, max 50).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with status and a list of messages including
|
|
||||||
id, author, content, timestamp.
|
|
||||||
"""
|
|
||||||
if search_space_id is None or user_id is None:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "Discord tool not properly configured.",
|
|
||||||
}
|
|
||||||
|
|
||||||
limit = min(limit, 50)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with async_session_maker() as db_session:
|
|
||||||
connector = await get_discord_connector(
|
|
||||||
db_session, search_space_id, user_id
|
|
||||||
)
|
|
||||||
if not connector:
|
|
||||||
return {"status": "error", "message": "No Discord connector found."}
|
|
||||||
|
|
||||||
token = get_bot_token(connector)
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
resp = await client.get(
|
|
||||||
f"{DISCORD_API}/channels/{channel_id}/messages",
|
|
||||||
headers={"Authorization": f"Bot {token}"},
|
|
||||||
params={"limit": limit},
|
|
||||||
timeout=15.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if resp.status_code == 401:
|
|
||||||
return {
|
|
||||||
"status": "auth_error",
|
|
||||||
"message": "Discord bot token is invalid.",
|
|
||||||
"connector_type": "discord",
|
|
||||||
}
|
|
||||||
if resp.status_code == 403:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "Bot lacks permission to read this channel.",
|
|
||||||
}
|
|
||||||
if resp.status_code != 200:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": f"Discord API error: {resp.status_code}",
|
|
||||||
}
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"id": m["id"],
|
|
||||||
"author": m.get("author", {}).get("username", "Unknown"),
|
|
||||||
"content": m.get("content", ""),
|
|
||||||
"timestamp": m.get("timestamp", ""),
|
|
||||||
}
|
|
||||||
for m in resp.json()
|
|
||||||
]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"channel_id": channel_id,
|
|
||||||
"messages": messages,
|
|
||||||
"total": len(messages),
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
from langgraph.errors import GraphInterrupt
|
|
||||||
|
|
||||||
if isinstance(e, GraphInterrupt):
|
|
||||||
raise
|
|
||||||
logger.error("Error reading Discord messages: %s", e, exc_info=True)
|
|
||||||
return {"status": "error", "message": "Failed to read Discord messages."}
|
|
||||||
|
|
||||||
return read_discord_messages
|
|
||||||
|
|
@ -1,136 +0,0 @@
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.agents.shared.tools.hitl import request_approval
|
|
||||||
from app.db import async_session_maker
|
|
||||||
|
|
||||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def create_send_discord_message_tool(
|
|
||||||
db_session: AsyncSession | None = None,
|
|
||||||
search_space_id: int | None = None,
|
|
||||||
user_id: str | None = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Factory function to create the send_discord_message tool.
|
|
||||||
|
|
||||||
The tool acquires its own short-lived ``AsyncSession`` per call via
|
|
||||||
:data:`async_session_maker` so the closure is safe to share across
|
|
||||||
HTTP requests by the compiled-agent cache. Capturing a per-request
|
|
||||||
session here would surface stale/closed sessions on cache hits.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db_session: Reserved for registry compatibility. Per-call sessions
|
|
||||||
are opened via :data:`async_session_maker` inside the tool body.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured send_discord_message tool
|
|
||||||
"""
|
|
||||||
del db_session # per-call session — see docstring
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def send_discord_message(
|
|
||||||
channel_id: str,
|
|
||||||
content: str,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Send a message to a Discord text channel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id: The Discord channel ID (from list_discord_channels).
|
|
||||||
content: The message text (max 2000 characters).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with status, message_id on success.
|
|
||||||
|
|
||||||
IMPORTANT:
|
|
||||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
|
||||||
"""
|
|
||||||
if search_space_id is None or user_id is None:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "Discord tool not properly configured.",
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(content) > 2000:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "Message exceeds Discord's 2000-character limit.",
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with async_session_maker() as db_session:
|
|
||||||
connector = await get_discord_connector(
|
|
||||||
db_session, search_space_id, user_id
|
|
||||||
)
|
|
||||||
if not connector:
|
|
||||||
return {"status": "error", "message": "No Discord connector found."}
|
|
||||||
|
|
||||||
result = request_approval(
|
|
||||||
action_type="discord_send_message",
|
|
||||||
tool_name="send_discord_message",
|
|
||||||
params={"channel_id": channel_id, "content": content},
|
|
||||||
context={"connector_id": connector.id},
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.rejected:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"message": "User declined. Message was not sent.",
|
|
||||||
}
|
|
||||||
|
|
||||||
final_content = result.params.get("content", content)
|
|
||||||
final_channel = result.params.get("channel_id", channel_id)
|
|
||||||
|
|
||||||
token = get_bot_token(connector)
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
resp = await client.post(
|
|
||||||
f"{DISCORD_API}/channels/{final_channel}/messages",
|
|
||||||
headers={
|
|
||||||
"Authorization": f"Bot {token}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
json={"content": final_content},
|
|
||||||
timeout=15.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if resp.status_code == 401:
|
|
||||||
return {
|
|
||||||
"status": "auth_error",
|
|
||||||
"message": "Discord bot token is invalid.",
|
|
||||||
"connector_type": "discord",
|
|
||||||
}
|
|
||||||
if resp.status_code == 403:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": "Bot lacks permission to send messages in this channel.",
|
|
||||||
}
|
|
||||||
if resp.status_code not in (200, 201):
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": f"Discord API error: {resp.status_code}",
|
|
||||||
}
|
|
||||||
|
|
||||||
msg_data = resp.json()
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"message_id": msg_data.get("id"),
|
|
||||||
"message": f"Message sent to channel {final_channel}.",
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
from langgraph.errors import GraphInterrupt
|
|
||||||
|
|
||||||
if isinstance(e, GraphInterrupt):
|
|
||||||
raise
|
|
||||||
logger.error("Error sending Discord message: %s", e, exc_info=True)
|
|
||||||
return {"status": "error", "message": "Failed to send Discord message."}
|
|
||||||
|
|
||||||
return send_discord_message
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue